From 6e49635fa37c1d537ee9be7edc33390560519d3d Mon Sep 17 00:00:00 2001 From: Sun Jin Kim <sk2521@ic.ac.uk> Date: Thu, 7 Apr 2022 22:52:41 +0900 Subject: [PATCH] cleaned up imports in /autoaugment_learners --- .../autoaugment_learners/aa_learner.py | 14 +++--- .../autoaugment_learners/gru_learner.py | 47 ++++++------------- .../randomsearch_learner.py | 11 +---- MetaAugment/main.py | 2 - 4 files changed, 22 insertions(+), 52 deletions(-) diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py index 4b7cd86a..60d463c2 100644 --- a/MetaAugment/autoaugment_learners/aa_learner.py +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -1,14 +1,12 @@ -# The parent class for all other autoaugment learners`` +# The parent class for all other autoaugment learners import torch -import numpy as np -from MetaAugment.main import * -import MetaAugment.child_networks as cn -import torchvision.transforms as transforms -from MetaAugment.autoaugment_learners.autoaugment import * +import torch.nn as nn +import torch.optim as optim +from MetaAugment.main import train_child_network, create_toy +from MetaAugment.autoaugment_learners.autoaugment import AutoAugment -import torchvision.transforms.autoaugment as torchaa -from torchvision.transforms import functional as F, InterpolationMode +import torchvision.transforms as transforms from pprint import pprint diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py index 58970e68..709fdabc 100644 --- a/MetaAugment/autoaugment_learners/gru_learner.py +++ b/MetaAugment/autoaugment_learners/gru_learner.py @@ -1,13 +1,8 @@ import torch -import numpy as np -import torchvision.transforms as transforms -import torchvision.transforms.autoaugment as torchaa -from torchvision.transforms import functional as F, InterpolationMode -from MetaAugment.main import * import MetaAugment.child_networks as cn -from MetaAugment.autoaugment_learners.autoaugment import * -from MetaAugment.autoaugment_learners.aa_learner import * +from MetaAugment.autoaugment_learners.aa_learner import aa_learner +from MetaAugment.controller_networks.rnn_controller import RNNModel from pprint import pprint @@ -33,6 +28,7 @@ augmentation_space = [ ("Invert", False), ] + class gru_learner(aa_learner): # Uses a GRU controller which is updated via Proximal Polixy Optimization # It is the same model use in @@ -50,8 +46,9 @@ class gru_learner(aa_learner): ''' super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m) - # TODO: We should probably use a different way to store results than self.history - self.history = [] + # input_size of the RNNModel can be chosen arbitrarily as we don't put any inputs in it. + self.controller = RNNModel(mode='GRU', input_size=1, hidden_size=40, num_layers=1, + bias=True, output_size=fun_num+p_bins+m_bins) def generate_new_policy(self): @@ -68,25 +65,7 @@ class gru_learner(aa_learner): (("ShearY", 0.5, 8), ("Invert", 0.7, None)), ] ''' - new_policy = [] - - for _ in range(self.sp_num): # generate sp_num subpolicies for each policy - ops = [] - # generate 2 operations for each subpolicy - for i in range(2): - # if our agent uses discrete representations of probability and magnitude - if self.discrete_p_m: - new_op = self.generate_new_discrete_operation() - else: - new_op = self.generate_new_continuous_operation() - new_op = self.translate_operation_tensor(new_op) - ops.append(new_op) - - new_subpolicy = tuple(ops) - - new_policy.append(new_subpolicy) - - return new_policy + new_policy = self.controller(input=torch.rand(1)) def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag): @@ -114,13 +93,15 @@ if __name__=='__main__': # We can initialize the train_dataset with its transform as None. # Later on, we will change this object's transform attribute to the policy # that we want to test - train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, + train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=None) - test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, + test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=torchvision.transforms.ToTensor()) child_network = cn.lenet - rs_learner = randomsearch_learner(discrete_p_m=False) - rs_learner.learn(train_dataset, test_dataset, child_network, toy_flag=True) - pprint(rs_learner.history) \ No newline at end of file + learner = gru_learner(discrete_p_m=False) + print(learner.generate_new_policy()) + breakpoint() + learner.learn(train_dataset, test_dataset, child_network, toy_flag=True) + pprint(learner.history) \ No newline at end of file diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py index 3657224f..02798927 100644 --- a/MetaAugment/autoaugment_learners/randomsearch_learner.py +++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py @@ -1,13 +1,8 @@ import torch import numpy as np -import torchvision.transforms as transforms -import torchvision.transforms.autoaugment as torchaa -from torchvision.transforms import functional as F, InterpolationMode -from MetaAugment.main import * import MetaAugment.child_networks as cn -from MetaAugment.autoaugment_learners.autoaugment import * -from MetaAugment.autoaugment_learners.aa_learner import * +from MetaAugment.autoaugment_learners.aa_learner import aa_learner from pprint import pprint @@ -43,9 +38,7 @@ class randomsearch_learner(aa_learner): m_bins: number of bins we divide the magnitude space ''' super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m) - - # TODO: We should probably use a different way to store results than self.history - self.history = [] + def generate_new_discrete_operation(self): ''' diff --git a/MetaAugment/main.py b/MetaAugment/main.py index 0fd76bcf..5b0e04e4 100644 --- a/MetaAugment/main.py +++ b/MetaAugment/main.py @@ -2,11 +2,9 @@ import numpy as np import torch torch.manual_seed(0) import torch.nn as nn -import torch.nn.functional as F import torch.optim as optim import torchvision import torchvision.datasets as datasets -import torchvision.transforms.autoaugment as autoaugment #import MetaAugment.AutoAugmentDemo.ops as ops # # code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py -- GitLab