Commit 6e49635f authored by Sun Jin Kim

cleaned up imports in /autoaugment_learners

parent 555ef1b4
# The parent class for all other autoaugment learners
import torch
import numpy as np
from MetaAugment.main import *
import MetaAugment.child_networks as cn
import torchvision.transforms as transforms
from MetaAugment.autoaugment_learners.autoaugment import *
import torch.nn as nn
import torch.optim as optim
from MetaAugment.main import train_child_network, create_toy
from MetaAugment.autoaugment_learners.autoaugment import AutoAugment
import torchvision.transforms.autoaugment as torchaa
from torchvision.transforms import functional as F, InterpolationMode
import torchvision.transforms as transforms
from pprint import pprint
......
import torch
import numpy as np
import torchvision.transforms as transforms
import torchvision.transforms.autoaugment as torchaa
from torchvision.transforms import functional as F, InterpolationMode
from MetaAugment.main import *
import MetaAugment.child_networks as cn
from MetaAugment.autoaugment_learners.autoaugment import *
from MetaAugment.autoaugment_learners.aa_learner import *
from MetaAugment.autoaugment_learners.aa_learner import aa_learner
from MetaAugment.controller_networks.rnn_controller import RNNModel
from pprint import pprint
......
@@ -33,6 +28,7 @@ augmentation_space = [
("Invert", False),
]
class gru_learner(aa_learner):
# Uses a GRU controller which is updated via Proximal Policy Optimization
# It is the same model used in
......
@@ -50,8 +46,9 @@ class gru_learner(aa_learner):
'''
super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m)
# TODO: We should probably use a different way to store results than self.history
self.history = []
# input_size of the RNNModel can be chosen arbitrarily, as we never feed it any real input.
self.controller = RNNModel(mode='GRU', input_size=1, hidden_size=40, num_layers=1,
bias=True, output_size=fun_num+p_bins+m_bins)
def generate_new_policy(self):
......
@@ -68,25 +65,7 @@
(("ShearY", 0.5, 8), ("Invert", 0.7, None)),
]
'''
new_policy = []
for _ in range(self.sp_num): # generate sp_num subpolicies for each policy
ops = []
# generate 2 operations for each subpolicy
for i in range(2):
# if our agent uses discrete representations of probability and magnitude
if self.discrete_p_m:
new_op = self.generate_new_discrete_operation()
else:
new_op = self.generate_new_continuous_operation()
new_op = self.translate_operation_tensor(new_op)
ops.append(new_op)
new_subpolicy = tuple(ops)
new_policy.append(new_subpolicy)
return new_policy
new_policy = self.controller(input=torch.rand(1))
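A minimal sketch (not part of this commit) of how one flat controller output of length fun_num + p_bins + m_bins might be decoded into the ("Name", probability, magnitude) tuples shown in the docstring above; the bin sizes and operation names below are hypothetical stand-ins:
import torch

fun_num, p_bins, m_bins = 3, 11, 10                    # hypothetical bin sizes
fun_names = ["ShearY", "Rotate", "Invert"]             # stand-in for the augmentation_space names
logits = torch.randn(fun_num + p_bins + m_bins)        # stand-in for one controller output

# a real learner would softmax and sample each head rather than take the argmax
fun_idx = torch.argmax(logits[:fun_num]).item()                 # which augmentation to apply
p_idx = torch.argmax(logits[fun_num:fun_num + p_bins]).item()   # probability bin
m_idx = torch.argmax(logits[fun_num + p_bins:]).item()          # magnitude bin

operation = (fun_names[fun_idx], p_idx / (p_bins - 1), m_idx)   # e.g. ("Rotate", 0.7, 2)
print(operation)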
def learn(self, train_dataset, test_dataset, child_network_architecture, toy_flag):
......
@@ -114,13 +93,15 @@ if __name__=='__main__':
# We can initialize the train_dataset with its transform as None.
# Later on, we will change this object's transform attribute to the policy
# that we want to test
train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False,
train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True,
transform=None)
test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True,
transform=torchvision.transforms.ToTensor())
child_network = cn.lenet
rs_learner = randomsearch_learner(discrete_p_m=False)
rs_learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
pprint(rs_learner.history)
\ No newline at end of file
learner = gru_learner(discrete_p_m=False)
print(learner.generate_new_policy())
breakpoint()
learner.learn(train_dataset, test_dataset, child_network, toy_flag=True)
pprint(learner.history)
\ No newline at end of file
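Below is a small illustration, not code from this commit, of the pattern described in the __main__ comments above: the dataset is built once with transform=None, and its transform attribute is later overwritten with each candidate policy before the child network is trained. The RandomRotation here is a hypothetical placeholder for a learned subpolicy:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True,
                               transform=None)

# hypothetical placeholder for a learned subpolicy under test
candidate_policy = transforms.Compose([transforms.RandomRotation(10),
                                       transforms.ToTensor()])

# swap the transform in place; the next pass over the dataset applies the candidate policy
train_dataset.transform = candidate_policy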
import torch
import numpy as np
import torchvision.transforms as transforms
import torchvision.transforms.autoaugment as torchaa
from torchvision.transforms import functional as F, InterpolationMode
from MetaAugment.main import *
import MetaAugment.child_networks as cn
from MetaAugment.autoaugment_learners.autoaugment import *
from MetaAugment.autoaugment_learners.aa_learner import *
from MetaAugment.autoaugment_learners.aa_learner import aa_learner
from pprint import pprint
......
@@ -43,9 +38,7 @@ class randomsearch_learner(aa_learner):
m_bins: number of bins we divide the magnitude space into
'''
super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m)
# TODO: We should probably use a different way to store results than self.history
self.history = []
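A rough sketch, not taken from this diff, of how a random-search learner could draw one discrete operation, assuming one-hot vectors over the function, probability-bin and magnitude-bin dimensions (the real generate_new_discrete_operation may differ):
import numpy as np

fun_num, p_bins, m_bins = 14, 11, 10                          # hypothetical bin sizes
fun_onehot = np.eye(fun_num)[np.random.randint(fun_num)]      # pick an augmentation uniformly at random
p_onehot = np.eye(p_bins)[np.random.randint(p_bins)]          # pick a probability bin
m_onehot = np.eye(m_bins)[np.random.randint(m_bins)]          # pick a magnitude bin
operation = np.concatenate([fun_onehot, p_onehot, m_onehot])  # length fun_num + p_bins + m_bins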
def generate_new_discrete_operation(self):
'''
......
......
@@ -2,11 +2,9 @@ import numpy as np
import torch
torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms.autoaugment as autoaugment
#import MetaAugment.AutoAugmentDemo.ops as ops #
# code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py
......
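The collapsed file above borrows its training code from the linked LeNet5-MNIST script; a generic, hypothetical sketch of a child-network training loop in that style (standard PyTorch only, the name and signature are assumptions, not the repository's actual train_child_network) looks roughly like:
import torch
import torch.nn as nn
import torch.optim as optim

def train_child_network_sketch(model, train_loader, test_loader, epochs=1, device='cpu'):
    # hypothetical stand-in, not the repository's train_child_network
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-2)
    for _ in range(epochs):
        model.train()
        for images, labels in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(images.to(device)), labels.to(device))
            loss.backward()
            optimizer.step()
    # the resulting test accuracy is what an autoaugment learner would use as its reward
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            preds = model(images.to(device)).argmax(dim=1)
            correct += (preds == labels.to(device)).sum().item()
            total += labels.size(0)
    return correct / total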