Skip to content
Snippets Groups Projects
Commit 083be641 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

move all datasets to /datasets/mnist folder. MAKE SURE TO USE PARAMETER...

move all datasets to /datasets/mnist folder. MAKE SURE TO USE PARAMETER DOWNLOAD=TRUE TO DOWNLOAD MNIST DATA AGAIN
parent c5af94a8
No related branches found
No related tags found
No related merge requests found
...@@ -90,8 +90,8 @@ def train_model(transform_idx, p): ...@@ -90,8 +90,8 @@ def train_model(transform_idx, p):
batch_size = 32 batch_size = 32
n_samples = 0.005 n_samples = 0.005
train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=False, transform=transform_train) train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=transform_train)
test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=False, transform=torchvision.transforms.ToTensor()) test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
# create toy dataset from above uploaded data # create toy dataset from above uploaded data
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01) train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
...@@ -142,8 +142,8 @@ def callback_generation(ga_instance): ...@@ -142,8 +142,8 @@ def callback_generation(ga_instance):
# ORGANISING DATA # ORGANISING DATA
# transforms = ['RandomResizedCrop', 'RandomHorizontalFlip', 'RandomVerticalCrop', 'RandomRotation'] # transforms = ['RandomResizedCrop', 'RandomHorizontalFlip', 'RandomVerticalCrop', 'RandomRotation']
train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=torchvision.transforms.ToTensor()) train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=torchvision.transforms.ToTensor()) test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=torchvision.transforms.ToTensor())
n_samples = 0.02 n_samples = 0.02
# shuffle and take first n_samples %age of training dataset # shuffle and take first n_samples %age of training dataset
shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist()) shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist())
......
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` ```
import numpy as np import numpy as np
import torch import torch
torch.manual_seed(0) torch.manual_seed(0)
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import torch.utils.data as data_utils import torch.utils.data as data_utils
import torchvision import torchvision
import torchvision.datasets as datasets import torchvision.datasets as datasets
import child_networks import child_networks
from main import * from main import *
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` ```
"""Randomly generate 10 policies""" """Randomly generate 10 policies"""
"""Each policy has 5 sub-policies""" """Each policy has 5 sub-policies"""
"""For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes""" """For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes"""
def generate_policies(num_policies, num_sub_policies):
    """Randomly generate augmentation policies.

    Every policy consists of ``num_sub_policies`` sub-policies, and every
    sub-policy is a row of six numbers::

        [transform_a, transform_b, prob_a, prob_b, magnitude_a, magnitude_b]

    * transformations are coded 0 = rotate, 1 = shear, 2 = scale, and the
      two transformations of a sub-policy are always different;
    * probabilities are drawn from {0.0, 0.1, ..., 1.0};
    * magnitudes are multiples of 5 in [-20, 20] for rotate/shear and
      one of {0.5, 0.6, ..., 1.4} for scale.

    Args:
        num_policies: number of policies to generate.
        num_sub_policies: number of sub-policies per policy.

    Returns:
        A float array of shape ``(num_policies, num_sub_policies, 6)``.
    """
    table = np.zeros([num_policies, num_sub_policies, 6])
    # np.ndindex walks the (policy, sub_policy) grid in row-major order,
    # i.e. in the same order as two nested ``for`` loops.
    for p_idx, s_idx in np.ndindex(num_policies, num_sub_policies):
        row = table[p_idx, s_idx]  # view into ``table``; writes land in place

        # two distinct transformation codes (re-draw the second on a clash)
        first = np.random.randint(0, 3)
        second = np.random.randint(0, 3)
        while second == first:
            second = np.random.randint(0, 3)
        row[0], row[1] = first, second

        # one application probability per transformation
        row[2] = np.random.randint(0, 11) / 10
        row[3] = np.random.randint(0, 11) / 10

        # one magnitude per transformation; the range depends on the code
        for slot, code in enumerate((first, second)):
            if code == 2:
                row[slot + 4] = np.random.randint(5, 15) / 10
            else:
                row[slot + 4] = np.random.randint(-4, 5) * 5
    return table
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` ```
"""Pick policy and sub-policy""" """Pick policy and sub-policy"""
"""Each row of data should have a different sub-policy but for now, this will do""" """Each row of data should have a different sub-policy but for now, this will do"""
def sample_sub_policy(policies, policy, num_sub_policies):
    """Pick a random sub-policy of ``policy`` and sample its transformations.

    A sub-policy row is laid out as
    ``[transform_a, transform_b, prob_a, prob_b, magnitude_a, magnitude_b]``
    with transformation codes 0 = rotate, 1 = shear, 2 = scale. Each
    transformation present in the row is applied with its own probability;
    a transformation that is absent or not triggered keeps its identity
    value.

    Args:
        policies: array indexed as ``policies[policy, sub_policy]``.
        policy: index of the policy to sample from.
        num_sub_policies: number of sub-policies per policy.

    Returns:
        Tuple ``(degrees, shear, scale)``; identity values are ``0, 0, 1``.
    """
    chosen = policies[policy, np.random.randint(0, num_sub_policies)]

    # position = transformation code, initial value = identity transformation
    sampled = [0, 0, 1]

    # Visit codes in the order rotate, shear, scale so that random numbers
    # are consumed in that order.
    for code in range(3):
        for slot in (0, 1):
            if chosen[slot] != code:
                continue
            if np.random.uniform() < chosen[slot + 2]:
                sampled[code] = chosen[slot + 4]
            break  # the first matching slot wins; never look at the second

    degrees, shear, scale = sampled
    return degrees, shear, scale
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` ```
"""Sample policy, open and apply above transformations""" """Sample policy, open and apply above transformations"""
def run_UCB1(policies, batch_size, toy_size, max_epochs, early_stop_num, iterations):
    """Search over augmentation policies with the UCB1 bandit rule.

    Every policy is tried once, in order. After that initial sweep the
    policy with the largest upper confidence bound
    ``q_value + sqrt(2 * ln(total_count) / count)`` is chosen each step.
    For the chosen policy a sub-policy is sampled, MNIST is loaded with the
    resulting affine transformation, a fresh LeNet child network is trained
    on a toy subset, and its best accuracy is used as the reward.

    Args:
        policies: array of shape (num_policies, num_sub_policies, 6) as
            consumed by ``sample_sub_policy``.
        batch_size: batch size handed to ``create_toy``.
        toy_size: fraction of the data handed to ``create_toy``.
        max_epochs: maximum number of epochs for each child network.
        early_stop_num: accepted for interface compatibility.
            NOTE(review): not forwarded anywhere -- confirm whether
            ``train_child_network`` takes an early-stopping argument.
        iterations: total number of bandit steps; should be at least the
            number of policies so that every policy is tried once.

    Returns:
        List with the running mean reward (accuracy) of every policy.
    """
    # get number of policies and sub-policies
    num_policies = len(policies)
    num_sub_policies = len(policies[0])

    # initialise value estimates, pull counts and upper confidence bounds
    q_values = [0] * num_policies
    cnts = [0] * num_policies
    q_plus_cnt = [0] * num_policies
    total_count = 0

    # The test set is never augmented, so it is loaded once up front.
    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True,
                                  transform=torchvision.transforms.ToTensor())

    for step in range(iterations):
        # initial sweep: try every policy in order; afterwards: best bound
        if step >= num_policies:
            this_policy = np.argmax(q_plus_cnt)
        else:
            this_policy = step

        # sample the transformation parameters for this policy
        degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)

        # degenerate (x, x) ranges make RandomAffine apply exactly these values
        transform = torchvision.transforms.Compose([
            torchvision.transforms.RandomAffine(degrees=(degrees, degrees),
                                                shear=(shear, shear),
                                                scale=(scale, scale)),
            torchvision.transforms.ToTensor(),
        ])

        # open the training data with the sampled transformation applied
        train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True,
                                       transform=transform)

        # create toy dataset from above uploaded data
        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)

        # create and train a fresh child network
        child_network = child_networks.lenet()
        sgd = optim.SGD(child_network.parameters(), lr=1e-1)
        cost = nn.CrossEntropyLoss()
        # Honour the caller's epoch budget (previously hard-coded to 100).
        best_acc = train_child_network(child_network, train_loader, test_loader, sgd, cost,
                                       max_epochs=max_epochs)

        # update the running mean reward of the chosen policy
        if step < num_policies:
            q_values[this_policy] += best_acc
        else:
            q_values[this_policy] = ((q_values[this_policy] * cnts[this_policy] + best_acc)
                                     / (cnts[this_policy] + 1))
        print(q_values)

        # update counts
        cnts[this_policy] += 1
        total_count += 1

        # Refresh the bounds once every policy has been tried at least once;
        # before that some counts are still zero and the bound is undefined.
        if step >= num_policies - 1:
            for i in range(num_policies):
                q_plus_cnt[i] = q_values[i] + np.sqrt(2 * np.log(total_count) / cnts[i])

    return q_values
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` ```
%%time
# Notebook driver cell: set the hyperparameters, generate random policies and
# run the UCB1 search over them.
batch_size = 32 # size of batch inner NN is trained with
toy_size = 0.0002 # total proportion of training and test set we use
max_epochs = 100 # max number of epochs that is run if early stopping is not hit
early_stop_num = 10 # max number of worse validation scores before early stopping
iterations = 20 # total iterations, should be more than the number of policies
# generate policies and sub-policies
num_policies = 10
num_sub_policies = 5
policies = generate_policies(num_policies, num_sub_policies)
q_values = run_UCB1(policies, batch_size, toy_size, max_epochs, early_stop_num, iterations)
#print(q_values)
``` ```
%% Output %% Output
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0] [0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0.5] [0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0.5]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0, 0, 0, 0, 0.25, 0, 0, 0, 0, 0.5] [0, 0, 0, 0, 0.25, 0, 0, 0, 0, 0.5]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0, 0, 0, 0, 0.25, 0, 0, 0, 0, 0.25] [0, 0, 0, 0, 0.25, 0, 0, 0, 0, 0.25]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0.0, 0, 0, 0, 0.25, 0, 0, 0, 0, 0.25] [0.0, 0, 0, 0, 0.25, 0, 0, 0, 0, 0.25]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0.0, 0.0, 0, 0, 0.25, 0, 0, 0, 0, 0.25] [0.0, 0.0, 0, 0, 0.25, 0, 0, 0, 0, 0.25]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0.0, 0.0, 0.0, 0, 0.25, 0, 0, 0, 0, 0.25] [0.0, 0.0, 0.0, 0, 0.25, 0, 0, 0, 0, 0.25]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0.0, 0.0, 0.0, 0.0, 0.25, 0, 0, 0, 0, 0.25] [0.0, 0.0, 0.0, 0.0, 0.25, 0, 0, 0, 0, 0.25]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0, 0, 0, 0.25] [0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0, 0, 0, 0.25]
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
main.train_child_network best accuracy: 0.5 main.train_child_network best accuracy: 0.5
[0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.25, 0, 0, 0.25] [0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.25, 0, 0, 0.25]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.25, 0.0, 0, 0.25] [0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.25, 0.0, 0, 0.25]
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
main.train_child_network best accuracy: 0 main.train_child_network best accuracy: 0
[0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.25, 0.0, 0.0, 0.25] [0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.25, 0.0, 0.0, 0.25]
Wall time: 3.92 s Wall time: 3.92 s
......
No preview for this file type
...@@ -82,7 +82,7 @@ class AA_Learner: ...@@ -82,7 +82,7 @@ class AA_Learner:
def __init__(self, controller): def __init__(self, controller):
self.controller = controller self.controller = controller
def learn(self, dataset, child_network, toy_flag): def learn(self, train_dataset, test_dataset, child_network, toy_flag):
''' '''
Does what is seen in Figure 1 in the AutoAugment paper. Does what is seen in Figure 1 in the AutoAugment paper.
...@@ -94,9 +94,10 @@ class AA_Learner: ...@@ -94,9 +94,10 @@ class AA_Learner:
while not good_policy_found: while not good_policy_found:
policy = self.controller.pop_policy() policy = self.controller.pop_policy()
train_loader, test_loader = prepare_dataset(dataset, policy, toy_flag) train_loader, test_loader = create_toy(train_dataset, test_dataset,
batch_size=32, n_samples=0.005)
reward = train_model(child_network, train_loader, test_loader, sgd, cost, epoch) reward = train_child_network(child_network, train_loader, test_loader, sgd, cost, epoch)
self.controller.update(reward, policy) self.controller.update(reward, policy)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment