diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 48d4f051410ce67e1593167c239284061e48953b..e4460cbfdca799022cd2f1d1ff950cd780355fa1 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -46,7 +46,6 @@ class aa_learner:
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=False,
@@ -57,6 +56,7 @@ class aa_learner:
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=20,
+                exclude_method = [],
                 ):
         """
         Args:
@@ -84,11 +84,9 @@ class aa_learner:
         """
         # related to defining the search space
         self.sp_num = sp_num
-        self.fun_num = fun_num
         self.p_bins = p_bins
         self.m_bins = m_bins
         self.discrete_p_m = discrete_p_m
-        self.op_tensor_length = fun_num+p_bins+m_bins if discrete_p_m else fun_num+2

         # related to training of the child_network
         self.batch_size = batch_size
@@ -101,6 +99,9 @@ class aa_learner:

         # TODO: We should probably use a different way to store results than self.history
         self.history = []
+        self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
+        self.fun_num = len(self.augmentation_space)
+        self.op_tensor_length = self.fun_num + p_bins + m_bins if discrete_p_m else self.fun_num + 2


     def translate_operation_tensor(self, operation_tensor, return_log_prob=False, argmax=False):
@@ -309,7 +310,8 @@ class aa_learner:
                                 child_network_architecture,
                                 train_dataset,
                                 test_dataset,
-                                logging=False):
+                                logging=False,
+                                print_every_epoch=True):
         """
         Given a policy (using AutoAugment paper terminology), we train a child network using
         the policy and return the accuracy (how good the policy is for the dataset and
@@ -384,7 +386,7 @@ class aa_learner:
                                     max_epochs = self.max_epochs,
                                     early_stop_num = self.early_stop_num,
                                     logging = logging,
-                                    print_every_epoch=True)
+                                    print_every_epoch=print_every_epoch)

         # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
         return accuracy
diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py
index 8e10c74547f7230a0eeecf11356804413721f7c1..5a8ecbcf6f0b8c6212a8c034a70d61476f4870f6 100644
--- a/MetaAugment/autoaugment_learners/autoaugment.py
+++ b/MetaAugment/autoaugment_learners/autoaugment.py
@@ -238,6 +238,8 @@ class AutoAugment(torch.nn.Module):
             if probs[i] <= p:
                 op_meta = self._augmentation_space(10, F.get_image_size(img))
                 magnitudes, signed = op_meta[op_name]
+                print("magnitude_id: ", magnitude_id)
+                print("magnitudes[magnitude_id]: ", magnitudes[magnitude_id])
                 magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                 if signed and signs[i] == 0:
                     magnitude *= -1.0
diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
index 18ecf751e614585c7db86902eb3cce927dd696f5..e9a65865c46b786005c01b4ff9d19d418baaa988 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -6,34 +6,31 @@ import pygad
 import pygad.torchga as torchga
 import copy
 import torch
-from MetaAugment.controller_networks.evo_controller import evo_controller
+
+from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space
 import MetaAugment.child_networks as cn
-from .aa_learner import aa_learner, augmentation_space


 class evo_learner(aa_learner):

     def __init__(self,
                  sp_num=1,
-                 num_solutions = 10,
-                 num_parents_mating = 5,
+                 num_solutions = 5,
+                 num_parents_mating = 3,
                  learning_rate = 1e-1,
                  max_epochs=float('inf'),
                  early_stop_num=20,
-                 train_loader = None,
-                 child_network = None,
                  p_bins = 1,
                  m_bins = 1,
                  discrete_p_m=False,
                  batch_size=8,
                  toy_flag=False,
                  toy_size=0.1,
-                 fun_num = 14,
                  exclude_method=[],
+                 controller = None
                  ):

         super().__init__(sp_num,
-                         fun_num,
                          p_bins,
                          m_bins,
                          discrete_p_m=discrete_p_m,
@@ -42,27 +39,24 @@ class evo_learner(aa_learner):
                          toy_size=toy_size,
                          learning_rate=learning_rate,
                          max_epochs=max_epochs,
-                         early_stop_num=early_stop_num,)
+                         early_stop_num=early_stop_num,
+                         exclude_method=exclude_method)

         self.num_solutions = num_solutions
-        self.auto_aug_agent = evo_controller(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sp_num)
-        self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions)
+        self.controller = controller
+        self.torch_ga = torchga.TorchGA(model=self.controller, num_solutions=num_solutions)
         self.num_parents_mating = num_parents_mating
         self.initial_population = self.torch_ga.population_weights
-        self.train_loader = train_loader
-        self.child_network = child_network
         self.p_bins = p_bins
         self.sub_num_pol = sp_num
         self.m_bins = m_bins
-        self.fun_num = fun_num
-        self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
-
+        self.policy_dict = {}
+        self.policy_result = []

         assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'

-
     def get_full_policy(self, x):
         """
         Generates the full policy (self.num_sub_pol subpolicies). Network architecture requires
@@ -79,8 +73,8 @@ class evo_learner(aa_learner):
             Full policy consisting of tuples of subpolicies. Each subpolicy consisting of two
             transformations, with a probability and magnitude float for each
         """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
-        y = self.auto_aug_agent.forward(x)
+        section = self.fun_num + self.p_bins + self.m_bins
+        y = self.controller.forward(x)
         full_policy = []
         for pol in range(self.sub_num_pol):
             int_pol = []
@@ -89,8 +83,22 @@ class evo_learner(aa_learner):

                 trans, need_mag = self.augmentation_space[idx_ret]

-                p_ret = (1/(self.p_bins-1)) * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-                mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None
+                if self.p_bins == 1:
+                    p_ret = min(1, max(0, (y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0).item())))
+                    # p_ret = torch.sigmoid(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
+                else:
+                    p_ret = torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)).item() * 0.1
+
+
+                if need_mag:
+                    # print("original mag", y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0))
+                    if self.m_bins == 1:
+                        mag = min(9, max(0, (y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0).item())))
+                    else:
+                        mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)).item()
+                    mag = int(mag)
+                else:
+                    mag = None
                 int_pol.append((trans, p_ret, mag))

             full_policy.append(tuple(int_pol))
@@ -117,18 +125,18 @@ class evo_learner(aa_learner):
             Subpolicy consisting of two tuples of policies, each with a string associated
             to a transformation, a float for a probability, and a float for a magnittude
         """
-        section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
+        section = self.fun_num + self.p_bins + self.m_bins

-        y = self.auto_aug_agent.forward(x)
+        y = self.controller.forward(x)

-        y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1)
-        y[:,:self.auto_aug_agent.fun_num] = y_1
-        y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
-        y[:,section:section+self.auto_aug_agent.fun_num] = y_2
+        y_1 = torch.softmax(y[:,:self.fun_num], dim = 1)
+        y[:,:self.fun_num] = y_1
+        y_2 = torch.softmax(y[:,section:section+self.fun_num], dim = 1)
+        y[:,section:section+self.fun_num] = y_2

         concat = torch.cat((y_1, y_2), dim = 1)

         cov_mat = torch.cov(concat.T)
-        cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
+        cov_mat = cov_mat[:self.fun_num, self.fun_num:]
         shape_store = cov_mat.shape

         counter, prob1, prob2, mag1, mag2 = (0, 0, 0, 0, 0)
@@ -154,26 +162,29 @@ class evo_learner(aa_learner):

         for idx in range(y.shape[0]):
             if (torch.argmax(y_1[idx]) == max_idx[0]) and (torch.argmax(y_2[idx]) == max_idx[1]):
-                prob1 += torch.sigmoid(y[idx, self.auto_aug_agent.fun_num]).item()
-                prob2 += torch.sigmoid(y[idx, section+self.auto_aug_agent.fun_num]).item()
+                prob1 += torch.sigmoid(y[idx, self.fun_num]).item()
+                prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
                 if mag1 is not None:
-                    mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+                    # mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
+                    mag1 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
                 if mag2 is not None:
-                    mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+                    # mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
+                    mag2 += 10 * torch.sigmoid(y[idx, section+self.fun_num+1]).item()
+
                 counter += 1

-        prob1 = prob1/counter if counter != 0 else 0
-        prob2 = prob2/counter if counter != 0 else 0
+        prob1 = round(prob1/counter, 1) if counter != 0 else 0
+        prob2 = round(prob2/counter, 1) if counter != 0 else 0
         if mag1 is not None:
-            mag1 = mag1/counter
+            mag1 = int(mag1/counter)
         if mag2 is not None:
-            mag2 = mag2/counter
+            mag2 = int(mag2/counter)

-        return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)]
+        return [((self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2))]


-    def learn(self, iterations = 15, return_weights = False):
+    def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 15, return_weights = False):
         """
         Runs the GA instance and returns the model weights as a dictionary
@@ -195,24 +206,52 @@ class evo_learner(aa_learner):
             Solution_idx -> Int
         """
         self.num_generations = iterations
-        self.history_best = [0 for i in range(iterations)]
-        self.history_avg = [0 for i in range(iterations)]
+        self.history_best = [0 for i in range(iterations+1)]
+        print("iterations: ", iterations)
+
+        self.history_avg = [0 for i in range(iterations+1)]
         self.gen_count = 0
         self.best_model = 0

-        self.set_up_instance()
+        self.set_up_instance(train_dataset, test_dataset, child_network_architecture)
+        print("train_dataset: ", train_dataset)

         self.ga_instance.run()
-        self.history_avg = self.history_avg / self.num_solutions
+        self.history_avg = [x / self.num_solutions for x in self.history_avg]
+        print("-----------------------------------------------------------------------------------------------------")

         solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
         if return_weights:
-            return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution)
+            return torchga.model_weights_as_dict(model=self.controller, weights_vector=solution)
         else:
             return solution, solution_fitness, solution_idx


-    def set_up_instance(self, train_dataset, test_dataset):
+    def in_pol_dict(self, new_policy):
+        new_policy = new_policy[0]
+        trans1, trans2 = new_policy[0][0], new_policy[1][0]
+        new_set = {new_policy[0][1], new_policy[0][2], new_policy[1][1], new_policy[1][2]}
+        if trans1 in self.policy_dict:
+            if trans2 in self.policy_dict[trans1]:
+                for test_pol in self.policy_dict[trans1][trans2]:
+                    if new_set == test_pol:
+                        return True
+                self.policy_dict[trans1][trans2].append(new_set)
+                return False
+            else:
+                self.policy_dict[trans1][trans2] = [new_set]
+        if trans2 in self.policy_dict:
+            if trans1 in self.policy_dict[trans2]:
+                for test_pol in self.policy_dict[trans2][trans1]:
+                    if new_set == test_pol:
+                        return True
+                self.policy_dict[trans2][trans1].append(new_set)
+                return False
+            else:
+                self.policy_dict[trans2][trans1] = [new_set]
+
+
+    def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
         """
         Initialises GA instance, as well as fitness and on_generation functions
@@ -233,24 +272,36 @@ class evo_learner(aa_learner):
                 fit_val -> float
             """

-            model_weights_dict = torchga.model_weights_as_dict(model=self.auto_aug_agent,
-                                                               weights_vector=solution)
+            model_weights_dict = torchga.model_weights_as_dict(model=self.controller,
+                                                               weights_vector=solution)

-            self.auto_aug_agent.load_state_dict(model_weights_dict)
+            self.controller.load_state_dict(model_weights_dict)
             self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size)

             for idx, (test_x, label_x) in enumerate(self.train_loader):
-                if self.sp_num == 1:
-                    full_policy = self.get_single_policy_cov(test_x)
-                else:
-                    full_policy = self.get_full_policy(test_x)
+                # if self.sp_num == 1:
+                full_policy = self.get_single_policy_cov(test_x)
+
+                # else:
+                #     full_policy = self.get_full_policy(test_x)
+                while self.in_pol_dict(full_policy):
+                    full_policy = self.get_single_policy_cov(test_x)

-# Checkpoint -> save learner as a pickle
-                fit_val = ((self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) /
-                            + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / 2
+                fit_val = self.test_autoaugment_policy(full_policy,child_network_architecture,train_dataset,test_dataset) #) /
+                            # + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)) / 2
+
+                self.policy_result.append([full_policy, fit_val])
+
+                if len(self.policy_result) > self.sp_num:
+                    self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
+                    self.policy_result = self.policy_result[:self.sp_num]
+                    print("Appended policy: ", self.policy_result)
+
                 if fit_val > self.history_best[self.gen_count]:
+                    print("Best policy: ", full_policy)
                     self.history_best[self.gen_count] = fit_val
                     self.best_model = model_weights_dict
@@ -284,6 +335,3 @@ class evo_learner(aa_learner):
                                fitness_func=fitness_func,
                                on_generation = on_generation)

-
-
-
diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
index c06edec316eed6982272abc685d6e02735e92adf..db8205d5f335f056f82b0e40557a73031ad72b1a 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -47,7 +47,6 @@ class gru_learner(aa_learner):
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=False,
@@ -78,10 +77,10 @@ class gru_learner(aa_learner):
             print('Warning: Incompatible discrete_p_m=True input into gru_learner. \
                     discrete_p_m=False will be used')

-        super().__init__(sp_num,
-                        fun_num,
-                        p_bins,
-                        m_bins,
+        super().__init__(
+                        sp_num=sp_num,
+                        p_bins=p_bins,
+                        m_bins=m_bins,
                         discrete_p_m=True,
                         batch_size=batch_size,
                         toy_flag=toy_flag,
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py
index 6541cd3f54980254d0001c969bf2eb90d57b0ad2..09f6626f8a42a35e5006c79188fef3d2947c6418 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py
@@ -38,7 +38,6 @@ class randomsearch_learner(aa_learner):
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=True,
@@ -51,10 +50,9 @@ class randomsearch_learner(aa_learner):
                 early_stop_num=30,
                 ):

-        super().__init__(sp_num,
-                        fun_num,
-                        p_bins,
-                        m_bins,
+        super().__init__(sp_num=sp_num,
+                        p_bins=p_bins,
+                        m_bins=m_bins,
                         discrete_p_m=discrete_p_m,
                         batch_size=batch_size,
                         toy_flag=toy_flag,
diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py
index 1a4ddf3a0d7d218ac768645d64154f90cd07d134..dc82c2ee75d22dd503f46212dd7251c79bb271db 100644
--- a/MetaAugment/autoaugment_learners/ucb_learner.py
+++ b/MetaAugment/autoaugment_learners/ucb_learner.py
@@ -1,9 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# In[1]:
-
-
 import numpy as np
 import torch
 import torch.nn as nn
@@ -26,7 +20,6 @@ class ucb_learner(randomsearch_learner):
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=True,
@@ -42,7 +35,6 @@ class ucb_learner(randomsearch_learner):
                 ):

         super().__init__(sp_num=sp_num,
-                        fun_num=14,
                         p_bins=p_bins,
                         m_bins=m_bins,
                         discrete_p_m=discrete_p_m,
@@ -53,23 +45,24 @@ class ucb_learner(randomsearch_learner):
                         max_epochs=max_epochs,
                         early_stop_num=early_stop_num,)

-        self.num_policies = num_policies

-        # When this learner is initialized we generate `num_policies` number
-        # of random policies.
-        # generate_new_policy is inherited from the randomsearch_learner class
-        self.policies = []
-        self.make_more_policies()

+        # attributes used in the UCB1 algorithm
-        self.q_values = [0]*self.num_policies
-        self.best_q_values = []
+        self.num_policies = num_policies
+
+        self.policies = [self.generate_new_policy() for _ in range(num_policies)]
+
+        self.avg_accs = [None]*self.num_policies
+        self.best_avg_accs = []
+
         self.cnts = [0]*self.num_policies
         self.q_plus_cnt = [0]*self.num_policies
         self.total_count = 0

+

     def make_more_policies(self, n):
         """generates n more random policies and adds it to self.policies
@@ -78,50 +71,71 @@ class ucb_learner(randomsearch_learner):
            and add to our list of policies
         """

-        self.policies.append([self.generate_new_policy() for _ in n])
+        self.policies += [self.generate_new_policy() for _ in range(n)]
+
+        # all the below need to be lengthened to store information for the
+        # new policies
+        self.avg_accs += [None for _ in range(n)]
+        self.cnts += [0 for _ in range(n)]
+        self.q_plus_cnt += [None for _ in range(n)]
+        self.num_policies += n
+

     def learn(self,
             train_dataset,
             test_dataset,
            child_network_architecture,
-            iterations=15):
+            iterations=15,
+            print_every_epoch=False):
+        """continue the UCB algorithm for `iterations` number of turns
+        """

         for this_iter in trange(iterations):

-            # get the action to try (either initially in order or using best q_plus_cnt value)
-            # TODO: change this if statemetn
-            if this_iter >= self.num_policies:
-                this_policy_idx = np.argmax(self.q_plus_cnt)
+            # choose which policy we want to test
+            if None in self.avg_accs:
+                # if there is a policy we haven't tested yet, we
+                # test that one
+                this_policy_idx = self.avg_accs.index(None)
                 this_policy = self.policies[this_policy_idx]
-            else:
-                this_policy = this_iter
-
-
-            best_acc = self.test_autoaugment_policy(
+                acc = self.test_autoaugment_policy(
                                 this_policy,
                                 child_network_architecture,
                                 train_dataset,
                                 test_dataset,
-                                logging=False
+                                logging=False,
+                                print_every_epoch=print_every_epoch
                                 )
-
-            # update q_values
-            # TODO: change this if statemetn
-            if this_iter < self.num_policies:
-                self.q_values[this_policy_idx] += best_acc
+                # update q_values (average accuracy)
+                self.avg_accs[this_policy_idx] = acc
             else:
-                self.q_values[this_policy_idx] = (self.q_values[this_policy_idx]*self.cnts[this_policy_idx] + best_acc) / (self.cnts[this_policy_idx] + 1)
-
-            best_q_value = max(self.q_values)
-            self.best_q_values.append(best_q_value)
-
+                # if we have tested all policies before, we test the
+                # one with the best q_plus_cnt value
+                this_policy_idx = np.argmax(self.q_plus_cnt)
+                this_policy = self.policies[this_policy_idx]
+                acc = self.test_autoaugment_policy(
+                                this_policy,
+                                child_network_architecture,
+                                train_dataset,
+                                test_dataset,
+                                logging=False,
+                                print_every_epoch=print_every_epoch
+                                )
+                # update q_values (average accuracy)
+                self.avg_accs[this_policy_idx] = (self.avg_accs[this_policy_idx]*self.cnts[this_policy_idx] + acc) / (self.cnts[this_policy_idx] + 1)
+
+            # logging the best avg acc up to now
+            best_avg_acc = max([x for x in self.avg_accs if x is not None])
+            self.best_avg_accs.append(best_avg_acc)
+
+            # print progress for user
             if (this_iter+1) % 5 == 0:
                 print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
                                 this_iter+1,
-                                list(np.around(np.array(self.q_values),2)),
-                                max(list(np.around(np.array(self.q_values),2)))
+                                list(np.around(np.array(self.avg_accs),2)),
+                                max(list(np.around(np.array(self.avg_accs),2)))
                                 )
                     )
@@ -130,10 +144,11 @@
             self.total_count += 1

             # update q_plus_cnt values every turn after the initial sweep through
-            # TODO: change this if statemetn
-            if this_iter >= self.num_policies - 1:
-                for i in range(self.num_policies):
-                    self.q_plus_cnt[i] = self.q_values[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
+            for i in range(self.num_policies):
+                if self.avg_accs[i] is not None:
+                    self.q_plus_cnt[i] = self.avg_accs[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
+
+            print(self.cnts)
diff --git a/temp_util/wapp_util.py b/temp_util/wapp_util.py
index 78be118ae9f3143d907cb8b0940bc6283a3e82ac..e48d1c31c44d2e5d6af548eeb15b957094abac17 100644
--- a/temp_util/wapp_util.py
+++ b/temp_util/wapp_util.py
@@ -17,13 +17,16 @@ from MetaAugment.main import create_toy
 import pickle

 def parse_users_learner_spec(
+            # aalearner type
             auto_aug_learner,
+            # search space settings
             ds,
             ds_name,
             exclude_method,
             num_funcs,
             num_policies,
             num_sub_policies,
+            # child network settings
             toy_size,
             IsLeNet,
             batch_size,
diff --git a/test/MetaAugment/test_aa_learner.py b/test/MetaAugment/test_aa_learner.py
index 3e2808702a04746e625acd5b463cfe01f56687bd..29af4f6da149a9619bafe30ba03cabe6b77064a7 100644
--- a/test/MetaAugment/test_aa_learner.py
+++ b/test/MetaAugment/test_aa_learner.py
@@ -25,13 +25,12 @@ def test_translate_operation_tensor():

     softmax = torch.nn.Softmax(dim=0)

-    fun_num = random.randint(1, 14)
+    fun_num=14
     p_bins = random.randint(2, 15)
     m_bins = random.randint(2, 15)
-
+
     agent = aal.aa_learner(
         sp_num=5,
-        fun_num=fun_num,
         p_bins=p_bins,
         m_bins=m_bins,
         discrete_p_m=True
@@ -54,13 +53,12 @@ def test_translate_operation_tensor():

     for i in range(2000):

-        fun_num = random.randint(1, 14)
+        fun_num = 14
         p_bins = random.randint(1, 15)
         m_bins = random.randint(1, 15)

         agent = aal.aa_learner(
             sp_num=5,
-            fun_num=fun_num,
             p_bins=p_bins,
             m_bins=m_bins,
             discrete_p_m=False
@@ -81,7 +79,6 @@ def test_translate_operation_tensor():
 def test_test_autoaugment_policy():
     agent = aal.aa_learner(
         sp_num=5,
-        fun_num=14,
         p_bins=11,
         m_bins=10,
         discrete_p_m=True,
diff --git a/test/MetaAugment/test_gru_learner.py b/test/MetaAugment/test_gru_learner.py
index 6ad8204f9b8473482f00d5c5d6a9d1e391cf9e0b..b5c695cfdf2d988408d70d1379af4fbf7738ae15 100644
--- a/test/MetaAugment/test_gru_learner.py
+++ b/test/MetaAugment/test_gru_learner.py
@@ -14,13 +14,11 @@ def test_generate_new_policy():
     """
     for _ in range(40):
         sp_num = random.randint(1,20)
-        fun_num = random.randint(1, 14)
         p_bins = random.randint(2, 15)
         m_bins = random.randint(2, 15)

         agent = aal.gru_learner(
             sp_num=sp_num,
-            fun_num=fun_num,
             p_bins=p_bins,
             m_bins=m_bins,
             cont_mb_size=2
diff --git a/test/MetaAugment/test_randomsearch_learner.py b/test/MetaAugment/test_randomsearch_learner.py
index 5b67d98e1f2e40d56b3aac2445f041f1372bbe9f..29cd812b1d428441d405556b53db3e65e2ab7bc6 100644
--- a/test/MetaAugment/test_randomsearch_learner.py
+++ b/test/MetaAugment/test_randomsearch_learner.py
@@ -16,13 +16,12 @@ def test_generate_new_policy():
     def my_test(discrete_p_m):
         for _ in range(40):
             sp_num = random.randint(1,20)
-            fun_num = random.randint(1, 14)
+
             p_bins = random.randint(2, 15)
             m_bins = random.randint(2, 15)

             agent = aal.randomsearch_learner(
                 sp_num=sp_num,
-                fun_num=fun_num,
                 p_bins=p_bins,
                 m_bins=m_bins,
                 discrete_p_m=discrete_p_m
diff --git a/test/MetaAugment/test_ucb_learner.py b/test/MetaAugment/test_ucb_learner.py
index 564ac80dff999467f6bf91fbc4c55019e0b86d98..7c6635ffe467e9a6cd4beb3b596380c76446b750 100644
--- a/test/MetaAugment/test_ucb_learner.py
+++ b/test/MetaAugment/test_ucb_learner.py
@@ -1,7 +1,18 @@
 import MetaAugment.autoaugment_learners as aal
-
+import MetaAugment.child_networks as cn
+import torchvision
+import torchvision.datasets as datasets
+from pprint import pprint

 def test_ucb_learner():
+    child_network_architecture = cn.SimpleNet
+    train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
+                            train=True, download=True, transform=None)
+    test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
+                            train=False, download=True,
+                            transform=torchvision.transforms.ToTensor())
+
+
     learner = aal.ucb_learner(
         # parameters that define the search space
         sp_num=5,
@@ -10,15 +21,37 @@ def test_ucb_learner():
         discrete_p_m=True,
         # hyperparameters for when training the child_network
         batch_size=8,
-        toy_flag=False,
-        toy_size=0.1,
+        toy_flag=True,
+        toy_size=0.001,
         learning_rate=1e-1,
         max_epochs=float('inf'),
         early_stop_num=30,
         # ucb_learner specific hyperparameter
-        num_policies=100
+        num_policies=3
         )
-    print(learner.policies)
+    pprint(learner.policies)
+    assert len(learner.policies)==len(learner.avg_accs), \
+            (len(learner.policies), (len(learner.avg_accs)))
+
+    # learn on the 3 policies we generated
+    learner.learn(
+            train_dataset=train_dataset,
+            test_dataset=test_dataset,
+            child_network_architecture=child_network_architecture,
+            iterations=5
+            )
+
+    # let's say we want to explore more policies:
+    # we generate more new policies
+    learner.make_more_policies(n=4)
+
+    # and let's explore how good those are as well
+    learner.learn(
+            train_dataset=train_dataset,
+            test_dataset=test_dataset,
+            child_network_architecture=child_network_architecture,
+            iterations=7
+            )

 if __name__=="__main__":
-    test_ucb_learner()
\ No newline at end of file
+    test_ucb_learner()
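
Usage note (illustrative, not applied by this patch): a minimal sketch of how the refactored evo_learner might be driven after this change, now that the controller is injected and fun_num is derived from the augmentation space minus exclude_method. The evo_controller import path and keyword arguments are taken from the line removed above, the FashionMNIST/SimpleNet setup mirrors the updated test_ucb_learner.py, and anything not shown in the diff should be treated as an assumption.

# Hypothetical usage sketch -- assumes the evo_controller module still lives at its
# old import path and keeps the keyword arguments shown in the removed constructor call.
import torchvision
import torchvision.datasets as datasets

import MetaAugment.child_networks as cn
from MetaAugment.autoaugment_learners.evo_learner import evo_learner
from MetaAugment.controller_networks.evo_controller import evo_controller

train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
                                      train=True, download=True, transform=None)
test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
                                     train=False, download=True,
                                     transform=torchvision.transforms.ToTensor())

# The controller is built by the caller and passed in; its fun_num should match the
# augmentation space left after exclusions (14 operations minus 1 excluded here).
controller = evo_controller(fun_num=13, p_bins=1, m_bins=1, sub_num_pol=1)

learner = evo_learner(sp_num=1,
                      p_bins=1,
                      m_bins=1,
                      toy_flag=True,
                      toy_size=0.01,
                      exclude_method=['Invert'],
                      controller=controller)

# learn() now takes the datasets and child network architecture directly,
# instead of them being stored on the learner at construction time.
learner.learn(train_dataset,
              test_dataset,
              child_network_architecture=cn.SimpleNet,
              iterations=5)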