import tqdm
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchio as tio
import monai
import utils
import losses
import data
from models import U_Net
class Supervised_Model:
    def __init__(self, parameters):
        self.parameters = parameters
        self.device = parameters['device']
        self.loss_function = self.parameters['evaluation_loss'].to(self.device)
        self.model = U_Net.UNet(self.parameters['n_channels'],
                                self.parameters['n_classes'],
                                self.parameters['n_features_map_supervised']).to(self.device)
        # Optionally restore pretrained weights
        if self.parameters['weights_supervised_load_path'] is not None:
            self.model.load_state_dict(torch.load(self.parameters['weights_supervised_load_path']))
            print('Supervised Model Loaded')
    def load_model(self, weights_load_path):
        if weights_load_path is not None:
            self.model.load_state_dict(torch.load(weights_load_path))
    def save_best_model(self, validation_losses, save_path):
        # Save a checkpoint whenever the latest validation loss matches or beats the previous best
        if validation_losses[-1] <= np.min(validation_losses[:-1]):
            torch.save(self.model.state_dict(), save_path)

    def early_stopping(self, validation_losses):
        # Stop when the best validation loss has not improved for more than 6 evaluations
        return len(validation_losses) - validation_losses.index(np.min(validation_losses)) > 6
    def run_training(self, training_loader, validation_loader):
        avg_train_losses = []
        avg_val_losses = []
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.parameters['learning_rate_supervised'])
        for epoch in range(self.parameters['num_epochs']):
            train_loss = []
            val_loss = []
            # Training
            self.model.train()
            with tqdm.tqdm(training_loader, unit='batch', disable=self.parameters['tqdm_disabled']) as tepoch:
                for batch_index, batch in enumerate(tepoch):
                    inputs = batch[0].squeeze(1).float().to(self.device)
                    labels = batch[1].squeeze(1).float().to(self.device)
                    labels = losses.transform_mask_for_dice_loss(labels, batch).to(self.device)
                    logits = self.model(inputs)
                    batch_loss_training = self.loss_function(logits, labels)
                    train_loss.append(batch_loss_training.item())
                    optimizer.zero_grad()
                    batch_loss_training.backward(retain_graph=True)
                    optimizer.step()
                    # Logging
                    tepoch.set_description(f"Epoch {epoch}")
                    tepoch.set_postfix(training_loss=f'{batch_loss_training.item()}')
            avg_train_losses.append(np.average(train_loss))
            # Validation
            if epoch % self.parameters['eval_frequency'] == 0:
                self.model.eval()
                with torch.no_grad():
                    with tqdm.tqdm(validation_loader, unit='batch', disable=self.parameters['tqdm_disabled']) as tepoch:
                        for batch_index_val, batch_val in enumerate(tepoch):
                            inputs_val = batch_val[0].squeeze(1).float().to(self.device)
                            labels_val = batch_val[1].squeeze(1).float().to(self.device)
                            labels_val = losses.transform_mask_for_dice_loss(labels_val, batch_val).to(self.device)
                            logits_val = self.model(inputs_val)
                            batch_loss_validation = self.loss_function(logits_val, labels_val)
                            val_loss.append(batch_loss_validation.item())
                            # Logging
                            tepoch.set_description(f"Epoch {epoch}")
                            tepoch.set_postfix(validation_loss=f'{batch_loss_validation.item()}')
                avg_val_losses.append(np.average(val_loss))
                # Always save the first checkpoint, then only keep the best one
                if len(avg_val_losses) == 1:
                    torch.save(self.model.state_dict(), self.parameters['save_path_supervised_model'])
                else:
                    self.save_best_model(avg_val_losses, self.parameters['save_path_supervised_model'])
                if self.early_stopping(avg_val_losses):
                    print(f'Supervised Training Early Stopping: epoch {epoch}')
                    break
        return avg_train_losses, avg_val_losses
    def run_test(self, testing_loader):
        test_losses = []
        self.model.eval()
        with torch.no_grad():
            with tqdm.tqdm(testing_loader, unit='batch', disable=self.parameters['tqdm_disabled']) as tepoch:
                for batch_index, batch in enumerate(tepoch):
                    inputs = batch[0].squeeze(1).float().to(self.device)
                    labels = batch[1].squeeze(1).float().to(self.device)
                    labels = losses.transform_mask_for_dice_loss(labels, batch).to(self.device)
                    logits = self.model(inputs)
                    batch_loss_testing = self.loss_function(logits, labels)
                    test_losses.append(batch_loss_testing.item())
        return test_losses
    def run_test_volume(self, testing_loader_volume):
        test_losses = []
        # Per-class Dice loss (reduction='none') so per-structure scores can be reported
        loss_function = monai.losses.DiceLoss(include_background=True,
                                              to_onehot_y=False,
                                              reduction='none',
                                              softmax=True)
        self.model.eval()
        with torch.no_grad():
            with tqdm.tqdm(testing_loader_volume, unit='batch', disable=self.parameters['tqdm_disabled']) as tepoch:
                for batch_index, batch in enumerate(tepoch):
                    inputs = batch[0].squeeze(0).float().to(self.device)
                    labels = batch[1].permute(0, 2, 3, 4, 1)
                    labels = losses.transform_mask_for_dice_loss_3D(labels, batch).to(self.device)
                    logits = self.model(inputs)
                    # Rearrange the per-slice predictions into a single volume for the 3D Dice loss
                    logits = logits.permute(1, 2, 3, 0).unsqueeze(0)
                    batch_loss_testing = loss_function(logits, labels).mean(dim=0)
                    test_losses.append(batch_loss_testing[:, 0, 0].cpu())
        test_losses_detailed = torch.stack(test_losses).mean(dim=0)
        test_losses = torch.mean(test_losses_detailed)
        return test_losses, test_losses_detailed
    def run_detailed_test(self, testing_loader):
        test_losses = []
        loss_detailed = monai.losses.DiceLoss(include_background=True,
                                              to_onehot_y=False,
                                              reduction='none',
                                              softmax=True).to(self.parameters['device'])
        self.model.eval()
        with torch.no_grad():
            with tqdm.tqdm(testing_loader, unit='batch', disable=self.parameters['tqdm_disabled']) as tepoch:
                for batch_index, batch in enumerate(tepoch):
                    inputs = batch[0].squeeze(1).float().to(self.device)
                    labels = batch[1].squeeze(1).float().to(self.device)
                    labels = losses.transform_mask_for_dice_loss(labels, batch).to(self.device)
                    # This class defines no finetuning_layer, so the U-Net output is evaluated directly
                    logits = self.model(inputs)
                    batch_loss_testing = loss_detailed(logits, labels).mean(dim=0)
                    test_losses.append(batch_loss_testing[:, 0, 0].cpu())
        return test_losses
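
# Minimal usage sketch (not part of the original file). The keys below mirror the
# parameters this class reads; the concrete values, the 'evaluation_loss' choice and
# the data loaders are illustrative assumptions, not the authors' configuration.
#
# parameters = {
#     'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
#     'evaluation_loss': monai.losses.DiceLoss(softmax=True),   # assumed loss choice
#     'n_channels': 1,                                          # assumed value
#     'n_classes': 4,                                           # assumed value
#     'n_features_map_supervised': 32,                          # assumed value
#     'weights_supervised_load_path': None,
#     'learning_rate_supervised': 1e-4,                         # assumed value
#     'num_epochs': 100,                                        # assumed value
#     'eval_frequency': 1,                                      # assumed value
#     'tqdm_disabled': False,
#     'save_path_supervised_model': 'supervised_unet.pt',       # assumed path
# }
# supervised = Supervised_Model(parameters)
# # training_loader / validation_loader are assumed to come from the project's data module
# train_losses, val_losses = supervised.run_training(training_loader, validation_loader)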