From 839fb15f31e70e2bc3f38332a47b8f5e6991e576 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Mon, 25 Apr 2022 17:09:24 +0100
Subject: [PATCH] FINISH REFACTORING UCB_LEARNER

---
 .../autoaugment_learners/aa_learner.py  |   5 +-
 .../autoaugment_learners/ucb_learner.py | 101 ++++++++++--------
 temp_util/wapp_util.py                  |   3 +
 test/MetaAugment/test_ucb_learner.py    |  45 ++++++--
 4 files changed, 104 insertions(+), 50 deletions(-)

diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 48d4f051..0eb38d59 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -309,7 +309,8 @@ class aa_learner:
                                 child_network_architecture,
                                 train_dataset,
                                 test_dataset,
-                                logging=False):
+                                logging=False,
+                                print_every_epoch=True):
         """
         Given a policy (using AutoAugment paper terminology), we train a child network
         using the policy and return the accuracy (how good the policy is for the dataset and
@@ -384,7 +385,7 @@ class aa_learner:
                                 max_epochs = self.max_epochs,
                                 early_stop_num = self.early_stop_num,
                                 logging = logging,
-                                print_every_epoch=True)
+                                print_every_epoch=print_every_epoch)
 
         # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
         return accuracy
diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py
index 1a4ddf3a..e22f32ff 100644
--- a/MetaAugment/autoaugment_learners/ucb_learner.py
+++ b/MetaAugment/autoaugment_learners/ucb_learner.py
@@ -1,9 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# In[1]:
-
-
 import numpy as np
 import torch
 import torch.nn as nn
@@ -53,23 +47,24 @@ class ucb_learner(randomsearch_learner):
                         max_epochs=max_epochs,
                         early_stop_num=early_stop_num,)
 
-        self.num_policies = num_policies
-
-        # When this learner is initialized we generate `num_policies` number
-        # of random policies.
-        # generate_new_policy is inherited from the randomsearch_learner class
-        self.policies = []
-        self.make_more_policies()
+        # attributes used in the UCB1 algorithm
-        self.q_values = [0]*self.num_policies
-        self.best_q_values = []
+        self.num_policies = num_policies
+
+        self.policies = [self.generate_new_policy() for _ in range(num_policies)]
+
+        self.avg_accs = [None]*self.num_policies
+        self.best_avg_accs = []
+
         self.cnts = [0]*self.num_policies
         self.q_plus_cnt = [0]*self.num_policies
         self.total_count = 0
+
 
     def make_more_policies(self, n):
         """generates n more random policies and adds it to self.policies
@@ -78,50 +73,71 @@ class ucb_learner(randomsearch_learner):
         and add to our list of policies
         """
 
-        self.policies.append([self.generate_new_policy() for _ in n])
+        self.policies += [self.generate_new_policy() for _ in range(n)]
+
+        # all the below need to be lengthened to store information for the
+        # new policies
+        self.avg_accs += [None for _ in range(n)]
+        self.cnts += [0 for _ in range(n)]
+        self.q_plus_cnt += [None for _ in range(n)]
+        self.num_policies += n
+
 
     def learn(self,
             train_dataset,
             test_dataset,
             child_network_architecture,
-            iterations=15):
+            iterations=15,
+            print_every_epoch=False):
+        """Continue the UCB1 algorithm for `iterations` more turns.
+        """
 
         for this_iter in trange(iterations):
 
-            # get the action to try (either initially in order or using best q_plus_cnt value)
-            # TODO: change this if statemetn
-            if this_iter >= self.num_policies:
-                this_policy_idx = np.argmax(self.q_plus_cnt)
+            # choose which policy we want to test
+            if None in self.avg_accs:
+                # if there is a policy we haven't tested yet, we
+                # test that one
+                this_policy_idx = self.avg_accs.index(None)
                 this_policy = self.policies[this_policy_idx]
-            else:
-                this_policy = this_iter
-
-
-            best_acc = self.test_autoaugment_policy(
+                acc = self.test_autoaugment_policy(
                                 this_policy,
                                 child_network_architecture,
                                 train_dataset,
                                 test_dataset,
-                                logging=False
+                                logging=False,
+                                print_every_epoch=print_every_epoch
                                 )
-
-            # update q_values
-            # TODO: change this if statemetn
-            if this_iter < self.num_policies:
-                self.q_values[this_policy_idx] += best_acc
+                # update q_values (average accuracy)
+                self.avg_accs[this_policy_idx] = acc
             else:
-                self.q_values[this_policy_idx] = (self.q_values[this_policy_idx]*self.cnts[this_policy_idx] + best_acc) / (self.cnts[this_policy_idx] + 1)
-
-            best_q_value = max(self.q_values)
-            self.best_q_values.append(best_q_value)
-
+                # if we have tested all policies before, we test the
+                # one with the best q_plus_cnt value
+                this_policy_idx = np.argmax(self.q_plus_cnt)
+                this_policy = self.policies[this_policy_idx]
+                acc = self.test_autoaugment_policy(
+                                this_policy,
+                                child_network_architecture,
+                                train_dataset,
+                                test_dataset,
+                                logging=False,
+                                print_every_epoch=print_every_epoch
+                                )
+                # update q_values (average accuracy)
+                self.avg_accs[this_policy_idx] = (self.avg_accs[this_policy_idx]*self.cnts[this_policy_idx] + acc) / (self.cnts[this_policy_idx] + 1)
+
+            # logging the best avg acc up to now
+            best_avg_acc = max([x for x in self.avg_accs if x is not None])
+            self.best_avg_accs.append(best_avg_acc)
+
+            # print progress for user
             if (this_iter+1) % 5 == 0:
                 print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
                         this_iter+1,
-                        list(np.around(np.array(self.q_values),2)),
-                        max(list(np.around(np.array(self.q_values),2)))
+                        list(np.around(np.array(self.avg_accs),2)),
+                        max(list(np.around(np.array(self.avg_accs),2)))
                         )
                     )
@@ -130,10 +146,11 @@
             self.total_count += 1
 
             # update q_plus_cnt values every turn after the initial sweep through
-            # TODO: change this if statemetn
-            if this_iter >= self.num_policies - 1:
-                for i in range(self.num_policies):
-                    self.q_plus_cnt[i] = self.q_values[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
+            for i in range(self.num_policies):
+                if self.avg_accs[i] is not None:
+                    self.q_plus_cnt[i] = self.avg_accs[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
+
+            print(self.cnts)
diff --git a/temp_util/wapp_util.py b/temp_util/wapp_util.py
index 78be118a..e48d1c31 100644
--- a/temp_util/wapp_util.py
+++ b/temp_util/wapp_util.py
@@ -17,13 +17,16 @@ from MetaAugment.main import create_toy
 import pickle
 
 def parse_users_learner_spec(
+            # aalearner type
             auto_aug_learner,
+            # search space settings
             ds,
             ds_name,
             exclude_method,
             num_funcs,
             num_policies,
             num_sub_policies,
+            # child network settings
             toy_size,
             IsLeNet,
             batch_size,
diff --git a/test/MetaAugment/test_ucb_learner.py b/test/MetaAugment/test_ucb_learner.py
index 564ac80d..7c6635ff 100644
--- a/test/MetaAugment/test_ucb_learner.py
+++ b/test/MetaAugment/test_ucb_learner.py
@@ -1,7 +1,18 @@
 import MetaAugment.autoaugment_learners as aal
-
+import MetaAugment.child_networks as cn
+import torchvision
+import torchvision.datasets as datasets
+from pprint import pprint
 
 def test_ucb_learner():
+    child_network_architecture = cn.SimpleNet
+    train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
+                            train=True, download=True, transform=None)
+    test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test',
+                            train=False, download=True,
+                            transform=torchvision.transforms.ToTensor())
+
+
     learner = aal.ucb_learner(
         # parameters that define the search space
         sp_num=5,
@@ -10,15 +21,37 @@
         discrete_p_m=True,
         # hyperparameters for when training the child_network
         batch_size=8,
-        toy_flag=False,
-        toy_size=0.1,
+        toy_flag=True,
+        toy_size=0.001,
         learning_rate=1e-1,
         max_epochs=float('inf'),
         early_stop_num=30,
         # ucb_learner specific hyperparameter
-        num_policies=100
+        num_policies=3
         )
-    print(learner.policies)
+    pprint(learner.policies)
+    assert len(learner.policies)==len(learner.avg_accs), \
+                (len(learner.policies), (len(learner.avg_accs)))
+
+    # learn on the 3 policies we generated
+    learner.learn(
+            train_dataset=train_dataset,
+            test_dataset=test_dataset,
+            child_network_architecture=child_network_architecture,
+            iterations=5
+            )
+
+    # let's say we want to explore more policies:
+    # we generate more new policies
+    learner.make_more_policies(n=4)
+
+    # and let's explore how good those are as well
+    learner.learn(
+            train_dataset=train_dataset,
+            test_dataset=test_dataset,
+            child_network_architecture=child_network_architecture,
+            iterations=7
+            )
 
 if __name__=="__main__":
-    test_ucb_learner()
\ No newline at end of file
+    test_ucb_learner()
-- 
GitLab
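
For reference, the UCB1 rule this patch converges on can be read straight off the
learn() hunks: each policy i keeps a running mean accuracy avg_accs[i] and a visit
count cnts[i], and after every test the learner recomputes

    q_plus_cnt[i] = avg_accs[i] + sqrt(2*ln(total_count)/cnts[i])

Untested policies (avg_accs[i] is None) are always tried first; once every policy has
been tried, the one with the highest q_plus_cnt is tested next. The snippet below is a
minimal standalone sketch of that bookkeeping, for illustration only and not part of
the patch: ucb1_choose, ucb1_update, true_means and the noisy accuracy draw are
hypothetical stand-ins for the class attributes and for test_autoaugment_policy(),
and the cnts/total_count increments are assumed to happen once per tested policy,
matching the loop above.

import numpy as np

def ucb1_choose(avg_accs, q_plus_cnt):
    # initial sweep: any policy that has never been tested goes first
    if None in avg_accs:
        return avg_accs.index(None)
    # afterwards: greedy on mean accuracy plus exploration bonus
    return int(np.argmax(q_plus_cnt))

def ucb1_update(idx, acc, avg_accs, cnts, total_count, q_plus_cnt):
    # running mean of the accuracies observed for this policy
    if avg_accs[idx] is None:
        avg_accs[idx] = acc
    else:
        avg_accs[idx] = (avg_accs[idx]*cnts[idx] + acc) / (cnts[idx] + 1)
    cnts[idx] += 1
    total_count += 1
    # recompute the UCB1 score of every policy tested so far
    for i, a in enumerate(avg_accs):
        if a is not None:
            q_plus_cnt[i] = a + np.sqrt(2*np.log(total_count)/cnts[i])
    return total_count

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    true_means = [0.60, 0.70, 0.65]   # hypothetical per-policy accuracies
    avg_accs, cnts = [None]*3, [0]*3
    q_plus_cnt, total_count = [None]*3, 0
    for _ in range(20):
        idx = ucb1_choose(avg_accs, q_plus_cnt)
        # hypothetical stand-in for test_autoaugment_policy(): a noisy draw
        acc = true_means[idx] + rng.normal(0.0, 0.02)
        total_count = ucb1_update(idx, acc, avg_accs, cnts,
                                  total_count, q_plus_cnt)
    print(cnts)   # visits should concentrate on the best policy

Run as-is, the sketch behaves the way the print(self.cnts) line at the end of learn()
lets you eyeball during a real search: the policy with the best mean accuracy
accumulates most of the visits while every policy keeps receiving occasional
exploratory tests.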