Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
cc215
Cardiac_Multi_view_segmentation
Commits
a8edd375
Commit
a8edd375
authored
Jul 28, 2021
by
cc215
💬
Browse files
install
parent
05aa3706
Changes
3
Hide whitespace changes
Inline
Side-by-side
image_transformer/adv_compose_transform.py
View file @
a8edd375
...
...
@@ -149,7 +149,6 @@ class ComposeAdversarialTransform(object):
dist
=
self
.
loss_fn
(
pred
=
adv_output
,
reference
=
init_output
.
detach
(),
mask
=
None
)
mask
=
torch
.
ones_like
(
adv_output
)
dist
=
1
/
len
(
chain_of_transforms
)
*
dist
# model.train()
return
dist
,
adv_data
,
adv_output
,
warped_back_adv_output
...
...
@@ -243,7 +242,7 @@ class ComposeAdversarialTransform(object):
def
optimizing_transform
(
self
,
model
,
data
,
init_output
,
power_iterations
,
n_iter
=
1
):
## optimize each transform with one forward pass.
set_grad
(
model
,
requires_grad
=
False
)
#
model.eval()
model
.
eval
()
for
i
in
range
(
n_iter
):
self
.
make_learnable_transformation
(
power_iterations
=
power_iterations
,
chain_of_transforms
=
self
.
chain_of_transforms
)
augmented_data
=
self
.
forward
(
data
)
...
...
@@ -253,14 +252,14 @@ class ComposeAdversarialTransform(object):
if
self
.
require_bi_loss
:
warped_back_prediction
=
self
.
backward
(
perturbed_output
)
mask
=
torch
.
ones_like
(
perturbed_output
,
device
=
augmented_data
.
device
)
mask
.
requires_grad
=
Fals
e
mask
.
requires_grad
=
Tru
e
with
torch
.
no_grad
():
forward_mask
=
self
.
predict_forward
(
mask
)
backward_mask
=
self
.
predict_backward
(
forward_mask
)
forward_reference
=
self
.
predict_forward
(
init_output
.
detach
())
dist
=
0.5
*
(
self
.
loss_fn
(
pred
=
warped_back_prediction
,
reference
=
init_output
.
detach
(),
mask
=
backward_mask
.
detach
()
)
+
self
.
loss_fn
(
pred
=
perturbed_output
,
reference
=
forward_reference
,
mask
=
forward_mask
.
detach
()
))
dist
=
0.5
*
(
self
.
loss_fn
(
pred
=
warped_back_prediction
,
reference
=
init_output
.
detach
(),
mask
=
backward_mask
)
+
self
.
loss_fn
(
pred
=
perturbed_output
,
reference
=
forward_reference
,
mask
=
forward_mask
))
else
:
print
(
'here'
)
dist
=
self
.
loss_fn
(
pred
=
perturbed_output
,
reference
=
init_output
.
detach
(),
mask
=
None
)
...
...
@@ -280,7 +279,7 @@ class ComposeAdversarialTransform(object):
transform
.
rescale_parameters
(
power_iteration
=
power_iteration
)
transform
.
eval
()
transforms
.
append
(
transform
)
#
model.train()
model
.
train
()
set_grad
(
model
,
requires_grad
=
True
)
return
transforms
...
...
@@ -288,7 +287,7 @@ class ComposeAdversarialTransform(object):
def
optimizing_transform_independent
(
self
,
data
,
model
,
init_output
,
power_iterations
,
lazy_load
=
False
,
n_iter
=
1
):
## optimize each transform individually.
#
model.eval()
model
.
eval
()
set_grad
(
model
,
requires_grad
=
False
)
new_transforms
=
[]
...
...
@@ -299,16 +298,16 @@ class ComposeAdversarialTransform(object):
augmented_data
=
transform
.
forward
(
data
)
perturbed_output
=
model
(
augmented_data
)
if
transform
.
is_geometric
()
>
0
:
warped_back_prediction
=
self
.
backward
(
perturbed_output
)
warped_back_prediction
=
transform
.
backward
(
perturbed_output
)
mask
=
torch
.
ones_like
(
perturbed_output
,
device
=
augmented_data
.
device
)
mask
.
requires_grad
=
Fals
e
mask
.
requires_grad
=
Tru
e
with
torch
.
no_grad
():
forward_mask
=
self
.
predict_forward
(
mask
)
backward_mask
=
self
.
predict_backward
(
forward_mask
)
forward_reference
=
self
.
predict_forward
(
init_output
.
detach
())
dist
=
0.5
*
(
self
.
loss_fn
(
pred
=
warped_back_prediction
,
reference
=
init_output
.
detach
()
,
mask
=
backward_mask
.
detach
()
)
+
self
.
loss_fn
(
pred
=
perturbed_output
,
reference
=
forward_reference
,
mask
=
forward_mask
.
detach
()
))
forward_mask
=
transform
.
predict_forward
(
mask
)
backward_mask
=
transform
.
predict_backward
(
forward_mask
)
forward_reference
=
transform
.
predict_forward
(
init_output
.
detach
())
backward_forward_reference
=
transform
.
predict_backward
(
forward_reference
)
dist
=
0.5
*
(
self
.
loss_fn
(
pred
=
warped_back_prediction
,
reference
=
backward_forward_reference
,
mask
=
backward_mask
)
+
self
.
loss_fn
(
pred
=
perturbed_output
,
reference
=
forward_reference
,
mask
=
forward_mask
))
else
:
dist
=
self
.
loss_fn
(
pred
=
perturbed_output
,
reference
=
init_output
.
detach
(),
mask
=
None
)
# print ('{} dist {} '.format(str(i),dist.item()))
...
...
@@ -318,7 +317,7 @@ class ComposeAdversarialTransform(object):
transform
.
rescale_parameters
(
power_iteration
=
power_iteration
)
transform
.
eval
()
new_transforms
.
append
(
transform
)
#
model.train()
model
.
train
()
set_grad
(
model
,
requires_grad
=
True
)
return
new_transforms
...
...
image_transformer/adv_morph.py
View file @
a8edd375
...
...
@@ -266,7 +266,7 @@ if __name__ == "__main__":
{
'epsilon'
:
1.5
,
'xi'
:
0.5
,
'data_size'
:[
10
,
1
,
128
,
128
],
'vector_size'
:[
4
,
4
],
'vector_size'
:[
128
//
8
,
128
//
8
],
'interpolator_mode'
:
'bilinear'
},
...
...
model/base_segmentation_model.py
View file @
a8edd375
...
...
@@ -114,13 +114,16 @@ class SegmentationModel(nn.Module):
def
forward
(
self
,
input
):
pred
=
self
.
model
.
forward
(
input
)
return
pred
def
eval
(
self
):
self
.
model
.
eval
()
if
self
.
use_ema
:
# First save original parameters before replacing with EMA version
self
.
ema
.
store
(
self
.
model
.
parameters
())
# Copy EMA parameters to model
self
.
ema
.
copy_to
(
self
.
model
.
parameters
())
self
.
model
.
eval
()
def
get_loss
(
self
,
pred
,
targets
=
None
,
loss_type
=
'cross_entropy'
):
if
not
targets
is
None
:
...
...
@@ -134,7 +137,7 @@ class SegmentationModel(nn.Module):
if
not
if_testing
:
self
.
model
.
train
()
if
self
.
use_ema
:
self
.
ema
.
restore
(
model
.
parameters
())
self
.
ema
.
restore
(
self
.
model
.
parameters
())
else
:
self
.
eval
()
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment