From 567998d8f780bdac3412d729a6d183a99832f7fb Mon Sep 17 00:00:00 2001 From: Sun Jin Kim <sk2521@ic.ac.uk> Date: Mon, 25 Apr 2022 18:48:50 +0100 Subject: [PATCH] fully implement `exclude_method` in aa_learners --- .../autoaugment_learners/aa_learner.py | 54 +++++++++---------- .../autoaugment_learners/evo_learner.py | 30 +++++------ .../autoaugment_learners/gru_learner.py | 24 ++------- .../randomsearch_learner.py | 42 +++++---------- .../autoaugment_learners/ucb_learner.py | 14 +++-- 5 files changed, 66 insertions(+), 98 deletions(-) diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index 44a86aa3..561222a5 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -1,4 +1,3 @@ -from numpy import isin import torch import torch.nn as nn import torch.optim as optim @@ -7,31 +6,10 @@ from MetaAugment.autoaugment_learners.autoaugment import AutoAugment import torchvision.transforms as transforms -from pprint import pprint -import matplotlib.pyplot as plt import copy import types -# We will use this augmentation_space temporarily. Later on we will need to -# make sure we are able to add other image functions if the users want. -augmentation_space = [ - # (function_name, do_we_need_to_specify_magnitude) - ("ShearX", True), - ("ShearY", True), - ("TranslateX", True), - ("TranslateY", True), - ("Rotate", True), - ("Brightness", True), - ("Color", True), - ("Contrast", True), - ("Sharpness", True), - ("Posterize", True), - ("Solarize", True), - ("AutoContrast", False), - ("Equalize", False), - ("Invert", False), - ] class aa_learner: @@ -96,8 +74,30 @@ class aa_learner: # TODO: We should probably use a different way to store results than self.history self.history = [] - self.augmentation_space = [x for x in augmentation_space if x not in exclude_method] - self.fun_num = len(augmentation_space) + + # this is the full augmentation space. We take out some image functions + # if the user specifies so in the exclude_method parameter + augmentation_space = [ + # (function_name, do_we_need_to_specify_magnitude) + ("ShearX", True), + ("ShearY", True), + ("TranslateX", True), + ("TranslateY", True), + ("Rotate", True), + ("Brightness", True), + ("Color", True), + ("Contrast", True), + ("Sharpness", True), + ("Posterize", True), + ("Solarize", True), + ("AutoContrast", False), + ("Equalize", False), + ("Invert", False), + ] + self.exclude_method = exclude_method + self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method] + + self.fun_num = len(self.augmentation_space) self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num +2 @@ -174,7 +174,7 @@ class aa_learner: prob_idx = torch.multinomial(prob_t, 1).item() # 0 <= p <= 10 mag = torch.multinomial(mag_t, 1).item() # 0 <= m <= 9 - function = augmentation_space[fun_idx][0] + function = self.augmentation_space[fun_idx][0] prob = prob_idx/(self.p_bins-1) indices = (fun_idx, prob_idx, mag) @@ -203,13 +203,13 @@ class aa_learner: prob = round(prob, 1) # round to nearest first decimal digit mag = round(mag) # round to nearest integer - function = augmentation_space[fun_idx][0] + function = self.augmentation_space[fun_idx][0] assert 0 <= prob <= 1, prob assert 0 <= mag <= self.m_bins-1, (mag, self.m_bins) # if the image function does not require a magnitude, we set the magnitude to None - if augmentation_space[fun_idx][1] == True: # if the image function has a magnitude + if self.augmentation_space[fun_idx][1] == True: # if the image function has a magnitude operation = (function, prob, mag) else: operation = (function, prob, None) diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py index 682061ef..34cc2d44 100644 --- a/MetaAugment/autoaugment_learners/evo_learner.py +++ b/MetaAugment/autoaugment_learners/evo_learner.py @@ -1,4 +1,3 @@ -from cgi import test import torch torch.manual_seed(0) import torch.nn as nn @@ -7,19 +6,18 @@ import pygad.torchga as torchga import copy import torch -from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space -import MetaAugment.child_networks as cn +from MetaAugment.autoaugment_learners.aa_learner import aa_learner class evo_learner(aa_learner): def __init__(self, # search space settings - discrete_p_m=False, - exclude_method=[], sp_num=5, p_bins=10, m_bins=10, + discrete_p_m=False, + exclude_method=[], # child network settings learning_rate=1e-1, max_epochs=float('inf'), @@ -32,16 +30,18 @@ class evo_learner(aa_learner): controller=None ): - super().__init__(sp_num=sp_num, - p_bins=p_bins, - m_bins=m_bins, - discrete_p_m=discrete_p_m, - batch_size=batch_size, - toy_size=toy_size, - learning_rate=learning_rate, - max_epochs=max_epochs, - early_stop_num=early_stop_num, - exclude_method=exclude_method) + super().__init__( + sp_num=sp_num, + p_bins=p_bins, + m_bins=m_bins, + discrete_p_m=discrete_p_m, + batch_size=batch_size, + toy_size=toy_size, + learning_rate=learning_rate, + max_epochs=max_epochs, + early_stop_num=early_stop_num, + exclude_method=exclude_method + ) self.num_solutions = num_solutions self.controller = controller diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py index 6955257f..5c15a4a4 100644 --- a/MetaAugment/autoaugment_learners/gru_learner.py +++ b/MetaAugment/autoaugment_learners/gru_learner.py @@ -9,25 +9,6 @@ import pickle -# We will use this augmentation_space temporarily. Later on we will need to -# make sure we are able to add other image functions if the users want. -augmentation_space = [ - # (function_name, do_we_need_to_specify_magnitude) - ("ShearX", True), - ("ShearY", True), - ("TranslateX", True), - ("TranslateY", True), - ("Rotate", True), - ("Brightness", True), - ("Color", True), - ("Contrast", True), - ("Sharpness", True), - ("Posterize", True), - ("Solarize", True), - ("AutoContrast", False), - ("Equalize", False), - ("Invert", False), - ] class gru_learner(aa_learner): @@ -50,6 +31,7 @@ class gru_learner(aa_learner): p_bins=11, m_bins=10, discrete_p_m=False, + exclude_method=[], # hyperparameters for when training the child_network batch_size=8, toy_size=1, @@ -85,7 +67,9 @@ class gru_learner(aa_learner): toy_size=toy_size, learning_rate=learning_rate, max_epochs=max_epochs, - early_stop_num=early_stop_num,) + early_stop_num=early_stop_num, + exclude_method=exclude_method, + ) # GRU-specific attributes that aren't in general aa_learner's self.alpha = alpha diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py index 71d8bc1a..2c35fb80 100644 --- a/MetaAugment/autoaugment_learners/randomsearch_learner.py +++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py @@ -10,25 +10,7 @@ import pickle -# We will use this augmentation_space temporarily. Later on we will need to -# make sure we are able to add other image functions if the users want. -augmentation_space = [ - # (function_name, do_we_need_to_specify_magnitude) - ("ShearX", True), - ("ShearY", True), - ("TranslateX", True), - ("TranslateY", True), - ("Rotate", True), - ("Brightness", True), - ("Color", True), - ("Contrast", True), - ("Sharpness", True), - ("Posterize", True), - ("Solarize", True), - ("AutoContrast", False), - ("Equalize", False), - ("Invert", False), - ] + class randomsearch_learner(aa_learner): """ @@ -41,6 +23,7 @@ class randomsearch_learner(aa_learner): p_bins=11, m_bins=10, discrete_p_m=True, + exclude_method=[], # hyperparameters for when training the child_network batch_size=8, toy_size=1, @@ -49,15 +32,18 @@ class randomsearch_learner(aa_learner): early_stop_num=30, ): - super().__init__(sp_num=sp_num, - p_bins=p_bins, - m_bins=m_bins, - discrete_p_m=discrete_p_m, - batch_size=batch_size, - toy_size=toy_size, - learning_rate=learning_rate, - max_epochs=max_epochs, - early_stop_num=early_stop_num,) + super().__init__( + sp_num=sp_num, + p_bins=p_bins, + m_bins=m_bins, + discrete_p_m=discrete_p_m, + batch_size=batch_size, + toy_size=toy_size, + learning_rate=learning_rate, + max_epochs=max_epochs, + early_stop_num=early_stop_num, + exclude_method=exclude_method + ) def generate_new_discrete_operation(self): diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py index aa6cd685..fdf735be 100644 --- a/MetaAugment/autoaugment_learners/ucb_learner.py +++ b/MetaAugment/autoaugment_learners/ucb_learner.py @@ -1,15 +1,9 @@ import numpy as np -import torch -import torch.nn as nn -import torch.optim as optim -import torchvision from tqdm import trange from ..child_networks import * -from ..main import train_child_network from .randomsearch_learner import randomsearch_learner -from .aa_learner import augmentation_space class ucb_learner(randomsearch_learner): @@ -23,6 +17,7 @@ class ucb_learner(randomsearch_learner): p_bins=11, m_bins=10, discrete_p_m=True, + exclude_method=[], # hyperparameters for when training the child_network batch_size=8, toy_size=1, @@ -33,7 +28,8 @@ class ucb_learner(randomsearch_learner): num_policies=100 ): - super().__init__(sp_num=sp_num, + super().__init__( + sp_num=sp_num, p_bins=p_bins, m_bins=m_bins, discrete_p_m=discrete_p_m, @@ -41,7 +37,9 @@ class ucb_learner(randomsearch_learner): toy_size=toy_size, learning_rate=learning_rate, max_epochs=max_epochs, - early_stop_num=early_stop_num,) + early_stop_num=early_stop_num, + exclude_method=exclude_method, + ) -- GitLab