Commit a8edd375 authored by cc215

install

parent 05aa3706
@@ -149,7 +149,6 @@ class ComposeAdversarialTransform(object):
         dist = self.loss_fn(pred=adv_output, reference=init_output.detach(), mask=None)
         mask = torch.ones_like(adv_output)
         dist = 1/len(chain_of_transforms)*dist
-        # model.train()
         return dist, adv_data, adv_output, warped_back_adv_output
@@ -243,7 +242,7 @@ class ComposeAdversarialTransform(object):
     def optimizing_transform(self, model, data, init_output, power_iterations, n_iter=1):
         ## optimize each transform with one forward pass.
         set_grad(model, requires_grad=False)
-        # model.eval()
+        model.eval()
         for i in range(n_iter):
             self.make_learnable_transformation(power_iterations=power_iterations, chain_of_transforms=self.chain_of_transforms)
             augmented_data = self.forward(data)
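Note on the change above, which promotes model.eval() from a comment to live code: set_grad stops gradients from flowing into the segmentation network while the transform parameters are optimized, but only eval() also freezes batch-norm running statistics and disables dropout, so the clean prediction init_output and the perturbed prediction come from the same deterministic network. A minimal sketch of the helper, assuming set_grad simply toggles requires_grad (its definition is outside this diff):

    import torch.nn as nn

    def set_grad(module: nn.Module, requires_grad: bool = False) -> None:
        # Toggle gradient tracking on every parameter so the backward pass
        # only populates gradients of the adversarial transform parameters.
        for p in module.parameters():
            p.requires_grad = requires_grad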
@@ -253,14 +252,14 @@ class ComposeAdversarialTransform(object):
             if self.require_bi_loss:
                 warped_back_prediction = self.backward(perturbed_output)
                 mask = torch.ones_like(perturbed_output, device=augmented_data.device)
-                mask.requires_grad = False
+                mask.requires_grad = True
                 with torch.no_grad():
                     forward_mask = self.predict_forward(mask)
                     backward_mask = self.predict_backward(forward_mask)
                     forward_reference = self.predict_forward(init_output.detach())
-                dist = 0.5*(self.loss_fn(pred=warped_back_prediction, reference=init_output.detach(), mask=backward_mask.detach())
-                            + self.loss_fn(pred=perturbed_output, reference=forward_reference, mask=forward_mask.detach()))
+                dist = 0.5*(self.loss_fn(pred=warped_back_prediction, reference=init_output.detach(), mask=backward_mask)
+                            + self.loss_fn(pred=perturbed_output, reference=forward_reference, mask=forward_mask))
             else:
                 print('here')
                 dist = self.loss_fn(pred=perturbed_output, reference=init_output.detach(), mask=None)
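In the bidirectional branch above, a tensor of ones is pushed through the same forward and backward warps as the predictions; wherever the ones survive the round trip, the pixel stayed inside the field of view, and only those pixels should count toward the consistency distance. Because the masks are produced under torch.no_grad(), they have requires_grad=False anyway, so the dropped .detach() calls were no-ops (and the flip of mask.requires_grad to True appears equally inert). A hypothetical masked distance in the spirit of loss_fn; the name masked_mse and the choice of mean squared error are illustrative assumptions, not the repo's actual loss:

    import torch

    def masked_mse(pred: torch.Tensor, reference: torch.Tensor,
                   mask: torch.Tensor = None) -> torch.Tensor:
        # Per-pixel squared error, averaged only where the warp kept the
        # pixel inside the image; clamp guards against an all-zero mask.
        err = (pred - reference) ** 2
        if mask is None:
            return err.mean()
        return (err * mask).sum() / mask.sum().clamp(min=1.0)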
@@ -280,7 +279,7 @@ class ComposeAdversarialTransform(object):
             transform.rescale_parameters(power_iteration=power_iteration)
             transform.eval()
             transforms.append(transform)
-        # model.train()
+        model.train()
         set_grad(model, requires_grad=True)
         return transforms
@@ -288,7 +287,7 @@ class ComposeAdversarialTransform(object):
     def optimizing_transform_independent(self, data, model, init_output, power_iterations, lazy_load=False, n_iter=1):
         ## optimize each transform individually.
-        # model.eval()
+        model.eval()
         set_grad(model, requires_grad=False)
         new_transforms = []
@@ -299,16 +298,16 @@ class ComposeAdversarialTransform(object):
             augmented_data = transform.forward(data)
             perturbed_output = model(augmented_data)
             if transform.is_geometric() > 0:
-                warped_back_prediction = self.backward(perturbed_output)
+                warped_back_prediction = transform.backward(perturbed_output)
                 mask = torch.ones_like(perturbed_output, device=augmented_data.device)
-                mask.requires_grad = False
+                mask.requires_grad = True
                 with torch.no_grad():
-                    forward_mask = self.predict_forward(mask)
-                    backward_mask = self.predict_backward(forward_mask)
-                    forward_reference = self.predict_forward(init_output.detach())
-                dist = 0.5*(self.loss_fn(pred=warped_back_prediction, reference=init_output.detach(), mask=backward_mask.detach())
-                            + self.loss_fn(pred=perturbed_output, reference=forward_reference, mask=forward_mask.detach()))
+                    forward_mask = transform.predict_forward(mask)
+                    backward_mask = transform.predict_backward(forward_mask)
+                    forward_reference = transform.predict_forward(init_output.detach())
+                    backward_forward_reference = transform.predict_backward(forward_reference)
+                dist = 0.5*(self.loss_fn(pred=warped_back_prediction, reference=backward_forward_reference, mask=backward_mask)
+                            + self.loss_fn(pred=perturbed_output, reference=forward_reference, mask=forward_mask))
             else:
                 dist = self.loss_fn(pred=perturbed_output, reference=init_output.detach(), mask=None)
             # print('{} dist {}'.format(str(i), dist.item()))
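Two fixes land in the hunk above: each transform now warps predictions back with its own transform.backward (previously the composed self.backward was applied even when optimizing a single transform in isolation), and the reference is sent through the same forward-then-backward round trip (backward_forward_reference) before comparison, so both operands suffer identical interpolation blur and border cropping. A toy, runnable illustration of why the round trip matters, assuming the geometric transforms resample with F.grid_sample the way this toy translation does:

    import torch
    import torch.nn.functional as F

    def shift_grid(n, h, w, dx):
        # Sampling grid that reads each output pixel from x + dx in
        # normalized coordinates, i.e. a horizontal translation.
        ys, xs = torch.meshgrid(torch.linspace(-1, 1, h),
                                torch.linspace(-1, 1, w), indexing='ij')
        grid = torch.stack([xs + dx, ys], dim=-1)
        return grid.unsqueeze(0).expand(n, h, w, 2)

    x = torch.rand(1, 1, 8, 8)
    fwd = F.grid_sample(x, shift_grid(1, 8, 8, 0.25), align_corners=True)
    rtn = F.grid_sample(fwd, shift_grid(1, 8, 8, -0.25), align_corners=True)
    # rtn differs from x near the border (zero padding) and is slightly
    # blurred everywhere (two bilinear resamplings): comparing against a
    # reference that took the same round trip cancels both artifacts.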
@@ -318,7 +317,7 @@ class ComposeAdversarialTransform(object):
             transform.rescale_parameters(power_iteration=power_iteration)
             transform.eval()
             new_transforms.append(transform)
-        # model.train()
+        model.train()
         set_grad(model, requires_grad=True)
         return new_transforms
@@ -266,7 +266,7 @@ if __name__ == "__main__":
         {'epsilon': 1.5,
          'xi': 0.5,
          'data_size': [10, 1, 128, 128],
-         'vector_size': [4, 4],
+         'vector_size': [128//8, 128//8],
          'interpolator_mode': 'bilinear'
         },
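The control-point grid of the deformation transform grows from 4x4 to 128//8 x 128//8 = 16x16, and spelling it as 128//8 makes the coupling to the 128x128 data_size explicit. A hypothetical helper that derives the value instead of hard-coding it; default_vector_size and the downscale factor are illustrative names, not part of the repo:

    def default_vector_size(data_size, downscale=8):
        # One control point per downscale x downscale patch of the image.
        n, c, h, w = data_size
        return [h // downscale, w // downscale]

    assert default_vector_size([10, 1, 128, 128]) == [16, 16]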
@@ -114,13 +114,16 @@ class SegmentationModel(nn.Module):
     def forward(self, input):
         pred = self.model.forward(input)
         return pred
     def eval(self):
-        self.model.eval()
+        if self.use_ema:
+            # First save original parameters before replacing with EMA version
+            self.ema.store(self.model.parameters())
+            # Copy EMA parameters to model
+            self.ema.copy_to(self.model.parameters())
+        self.model.eval()
     def get_loss(self, pred, targets=None, loss_type='cross_entropy'):
         if not targets is None:
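eval() now swaps in the exponential-moving-average weights before switching to evaluation mode. The store / copy_to call names match the torch_ema ExponentialMovingAverage API; whether the repo actually uses that library is an assumption based on those names. A self-contained sketch of the full swap, with nn.Linear standing in for the real network:

    import torch.nn as nn
    from torch_ema import ExponentialMovingAverage  # assumed EMA backend

    model = nn.Linear(4, 2)                  # stand-in for self.model
    ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

    # During training, after every optimizer step:
    ema.update()

    # At evaluation time, mirroring the eval() method above:
    ema.store(model.parameters())    # stash the live training weights
    ema.copy_to(model.parameters())  # load the smoothed EMA weights
    model.eval()
    # ... validate ...
    ema.restore(model.parameters())  # put the training weights back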
@@ -134,7 +137,7 @@ class SegmentationModel(nn.Module):
         if not if_testing:
             self.model.train()
             if self.use_ema:
-                self.ema.restore(model.parameters())
+                self.ema.restore(self.model.parameters())
         else:
             self.eval()
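The hunk above fixes a genuine NameError: restore was called on an undefined name model instead of self.model, so switching back to training mode with EMA enabled would have crashed. If the EMA object does follow torch_ema's API, the same store/copy_to/restore dance is also available as a context manager, which makes this class of slip impossible; continuing the sketch above:

    with ema.average_parameters():
        model.eval()
        # ... validate with EMA weights ...
    # on exit the live training weights are restored automatically
    model.train()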