Skip to content
Snippets Groups Projects
Commit c8a9d95c authored by Seince, Maxime's avatar Seince, Maxime
Browse files

Upload New File

parent bb639d86
No related branches found
No related tags found
No related merge requests found
import matplotlib.pyplot as plt
def plot_learning_curves(avg_train_losses, avg_val_losses, parameters, is_save = False, save_path = None) :
# Plot the learning curves
fig = plt.figure(figsize=(12,6))
plt.plot(range(1,len(avg_train_losses)+1), avg_train_losses, label='Training Loss', color='black')
plt.plot([i for i in range(0, len(avg_train_losses), parameters['eval_frequency'])], avg_val_losses,label = 'Validation Loss', color='red')
# Plot at which epoch early stopping started
min_index = (len(avg_val_losses) - 6) * parameters['eval_frequency']
plt.axvline(min_index, linestyle = '--', label = 'Beginning of Early Stopping', color = 'green')
# Customizing the plot
plt.title('Learning curves')
plt.xlabel('Epochs')
plt.ylabel('Losses')
plt.legend(loc='upper right')
plt.grid(True)
plt.show()
if is_save and (save_path != None):
fig.savefig(save_path) #'/media/data/MSc_students_accounts/maxime/figures/Learning_Curves_Unet_baseline.png'
def plot_model_features(model, training_loader, device):
fig = plt.tight_layout()
activation = {}
def get_activation(name):
def hook(model, input, output):
activation[name] = output.detach()
return hook
vis_labels = ['inc', 'down1', 'down2', 'down3', 'down4', 'up1', 'up2', 'up3', 'up4'] #'layer4', 'layer5', 'layer6']
for l in vis_labels:
getattr(model, l).register_forward_hook(get_activation(l))
one_batch = next(iter(training_loader))
data = one_batch[0][0, ...]#['mri_slice']['data']
data = data.to(device = device)#, dtype = dtype)
output = model(data)
for idx, l in enumerate(vis_labels):
act = activation[l].squeeze()
# only showing the first 16 channels
ncols, nrows = 8, 2
fig, axarr = plt.subplots(nrows, ncols, figsize=(15,5))
fig.suptitle(l)
count = 0
for i in range(nrows):
for j in range(ncols):
axarr[i, j].imshow(act[count].cpu())
axarr[i, j].axis('off')
count += 1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment