From 839fb15f31e70e2bc3f38332a47b8f5e6991e576 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Mon, 25 Apr 2022 17:09:24 +0100
Subject: [PATCH] FINISH REFACTORING UCB_LEARNER

---
 .../autoaugment_learners/aa_learner.py        |   5 +-
 .../autoaugment_learners/ucb_learner.py       | 101 ++++++++++--------
 temp_util/wapp_util.py                        |   3 +
 test/MetaAugment/test_ucb_learner.py          |  45 ++++++--
 4 files changed, 104 insertions(+), 50 deletions(-)

diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 48d4f051..0eb38d59 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -309,7 +309,8 @@ class aa_learner:
                                 child_network_architecture,
                                 train_dataset,
                                 test_dataset,
-                                logging=False):
+                                logging=False,
+                                print_every_epoch=True):
         """
         Given a policy (using AutoAugment paper terminology), we train a child network
         using the policy and return the accuracy (how good the policy is for the dataset and 
@@ -384,7 +385,7 @@ class aa_learner:
                                     max_epochs = self.max_epochs, 
                                     early_stop_num = self.early_stop_num, 
                                     logging = logging,
-                                    print_every_epoch=True)
+                                    print_every_epoch=print_every_epoch)
         
         # if logging is true, 'accuracy' is actually a tuple: (accuracy, accuracy_log)
         return accuracy
diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py
index 1a4ddf3a..e22f32ff 100644
--- a/MetaAugment/autoaugment_learners/ucb_learner.py
+++ b/MetaAugment/autoaugment_learners/ucb_learner.py
@@ -1,9 +1,3 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# In[1]:
-
-
 import numpy as np
 import torch
 import torch.nn as nn
@@ -53,23 +47,24 @@ class ucb_learner(randomsearch_learner):
                         max_epochs=max_epochs,
                         early_stop_num=early_stop_num,)
         
-        self.num_policies = num_policies
 
-        # When this learner is initialized we generate `num_policies` number
-        # of random policies. 
-        # generate_new_policy is inherited from the randomsearch_learner class
-        self.policies = []
-        self.make_more_policies()
 
         # attributes used in the UCB1 algorithm
-        self.q_values = [0]*self.num_policies
-        self.best_q_values = []
+        self.num_policies = num_policies
+
+        self.policies = [self.generate_new_policy() for _ in range(num_policies)]
+
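+        # bookkeeping for UCB1, index-aligned with self.policies:
+        #   avg_accs[i]    - average accuracy of policy i so far (None = not yet tested)
+        #   best_avg_accs  - best average accuracy seen after each iteration of learn()
+        #   cnts[i]        - how many times policy i has been evaluated
+        #   q_plus_cnt[i]  - avg_accs[i] plus the UCB1 exploration bonus
+        #   total_count    - total number of policy evaluations so far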
+        self.avg_accs = [None]*self.num_policies
+        self.best_avg_accs = []
+
         self.cnts = [0]*self.num_policies
         self.q_plus_cnt = [0]*self.num_policies
         self.total_count = 0
 
 
 
     def make_more_policies(self, n):
         """generates n more random policies and adds it to self.policies
 
@@ -78,50 +73,71 @@ class ucb_learner(randomsearch_learner):
                     and add to our list of policies
         """
 
-        self.policies.append([self.generate_new_policy() for _ in n])
+        self.policies += [self.generate_new_policy() for _ in range(n)]
+
+        # all the bookkeeping lists below need to be extended so they keep
+        # one entry per new policy
+        self.avg_accs += [None for _ in range(n)]
+        self.cnts += [0 for _ in range(n)]
+        self.q_plus_cnt += [None for _ in range(n)]
+        self.num_policies += n
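+        # note: the new entries stay index-aligned with self.policies; an
+        # avg_accs entry of None marks a policy as untested, so learn() will
+        # evaluate it before falling back to the UCB1 selection rule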
+
 
 
     def learn(self, 
             train_dataset, 
             test_dataset, 
             child_network_architecture, 
-            iterations=15):
+            iterations=15,
+            print_every_epoch=False):
+        """continue the UCB algorithm for `iterations` number of turns
 
+        """
 
         for this_iter in trange(iterations):
 
-            # get the action to try (either initially in order or using best q_plus_cnt value)
-            # TODO: change this if statemetn
-            if this_iter >= self.num_policies:
-                this_policy_idx = np.argmax(self.q_plus_cnt)
+            # choose which policy we want to test
+            if None in self.avg_accs:
+                # if there is a policy we haven't tested yet, we 
+                # test that one
+                this_policy_idx = self.avg_accs.index(None)
                 this_policy = self.policies[this_policy_idx]
-            else:
-                this_policy = this_iter
-
-
-            best_acc = self.test_autoaugment_policy(
+                acc = self.test_autoaugment_policy(
                                 this_policy,
                                 child_network_architecture,
                                 train_dataset,
                                 test_dataset,
-                                logging=False
+                                logging=False,
+                                print_every_epoch=print_every_epoch
                                 )
-
-            # update q_values
-            # TODO: change this if statemetn
-            if this_iter < self.num_policies:
-                self.q_values[this_policy_idx] += best_acc
+                # record its first average accuracy (the q-value in UCB1 terms)
+                self.avg_accs[this_policy_idx] = acc
             else:
-                self.q_values[this_policy_idx] = (self.q_values[this_policy_idx]*self.cnts[this_policy_idx] + best_acc) / (self.cnts[this_policy_idx] + 1)
-
-            best_q_value = max(self.q_values)
-            self.best_q_values.append(best_q_value)
-
+                # if we have tested all policies before, we test the
+                # one with the best q_plus_cnt value
+                this_policy_idx = np.argmax(self.q_plus_cnt)
+                this_policy = self.policies[this_policy_idx]
+                acc = self.test_autoaugment_policy(
+                                this_policy,
+                                child_network_architecture,
+                                train_dataset,
+                                test_dataset,
+                                logging=False,
+                                print_every_epoch=print_every_epoch
+                                )
+                # update its running average accuracy (the q-value in UCB1 terms)
+                self.avg_accs[this_policy_idx] = (self.avg_accs[this_policy_idx]*self.cnts[this_policy_idx] + acc) / (self.cnts[this_policy_idx] + 1)
+    
+            # logging the best avg acc up to now
+            best_avg_acc = max([x for x in self.avg_accs if x is not None])
+            self.best_avg_accs.append(best_avg_acc)
+
+            # print progress for user
             if (this_iter+1) % 5 == 0:
                 print("Iteration: {},\tQ-Values: {}, Best this_iter: {}".format(
                                 this_iter+1, 
-                                list(np.around(np.array(self.q_values),2)), 
-                                max(list(np.around(np.array(self.q_values),2)))
+                                list(np.around(np.array(self.avg_accs),2)), 
+                                max(list(np.around(np.array(self.avg_accs),2)))
                                 )
                     )
 
@@ -130,10 +146,11 @@ class ucb_learner(randomsearch_learner):
             self.total_count += 1
 
             # update q_plus_cnt values every turn after the initial sweep through
-            # TODO: change this if statemetn
-            if this_iter >= self.num_policies - 1:
-                for i in range(self.num_policies):
-                    self.q_plus_cnt[i] = self.q_values[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
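+            # UCB1 upper confidence bound: mean accuracy plus an exploration
+            # bonus of sqrt(2*ln(total evaluations) / evaluations of policy i),
+            # so rarely-tried policies eventually get revisited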
+            for i in range(self.num_policies):
+                if self.avg_accs[i] is not None:
+                    self.q_plus_cnt[i] = self.avg_accs[i] + np.sqrt(2*np.log(self.total_count)/self.cnts[i])
+            
+            print("Policy evaluation counts so far:", self.cnts)
 
             
 
diff --git a/temp_util/wapp_util.py b/temp_util/wapp_util.py
index 78be118a..e48d1c31 100644
--- a/temp_util/wapp_util.py
+++ b/temp_util/wapp_util.py
@@ -17,13 +17,16 @@ from MetaAugment.main import create_toy
 import pickle
 
 def parse_users_learner_spec(
+            # aa_learner type
             auto_aug_learner, 
+            # search space settings
             ds, 
             ds_name, 
             exclude_method, 
             num_funcs, 
             num_policies, 
             num_sub_policies, 
+            # child network settings
             toy_size, 
             IsLeNet, 
             batch_size, 
diff --git a/test/MetaAugment/test_ucb_learner.py b/test/MetaAugment/test_ucb_learner.py
index 564ac80d..7c6635ff 100644
--- a/test/MetaAugment/test_ucb_learner.py
+++ b/test/MetaAugment/test_ucb_learner.py
@@ -1,7 +1,18 @@
 import MetaAugment.autoaugment_learners as aal
-
+import MetaAugment.child_networks as cn
+import torchvision
+import torchvision.datasets as datasets
+from pprint import pprint
 
 def test_ucb_learner():
+    child_network_architecture = cn.SimpleNet
+    train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
+                            train=True, download=True, transform=None)
+    test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', 
+                            train=False, download=True,
+                            transform=torchvision.transforms.ToTensor())
+
+
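+    # toy_flag/toy_size below should shrink the datasets so this test runs quickly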
     learner = aal.ucb_learner(
         # parameters that define the search space
                 sp_num=5,
@@ -10,15 +21,37 @@ def test_ucb_learner():
                 discrete_p_m=True,
                 # hyperparameters for when training the child_network
                 batch_size=8,
-                toy_flag=False,
-                toy_size=0.1,
+                toy_flag=True,
+                toy_size=0.001,
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=30,
                 # ucb_learner specific hyperparameter
-                num_policies=100
+                num_policies=3
     )
-    print(learner.policies)
+    pprint(learner.policies)
+    assert len(learner.policies) == len(learner.avg_accs), \
+        (len(learner.policies), len(learner.avg_accs))
+
+    # learn on the 3 policies we generated
+    learner.learn(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        child_network_architecture=child_network_architecture,
+        iterations=5
+        )
+    
+    # suppose we want to explore more policies: generate 4 new ones
+    learner.make_more_policies(n=4)
+
+    # and let's explore how good those are as well
+    learner.learn(
+        train_dataset=train_dataset,
+        test_dataset=test_dataset,
+        child_network_architecture=child_network_architecture,
+        iterations=7
+        )
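+
+    # sanity check: make_more_policies(n=4) should have grown every
+    # bookkeeping list in step with self.policies (3 initial + 4 new)
+    assert len(learner.policies) == len(learner.avg_accs) == len(learner.cnts) == 7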
 
 if __name__=="__main__":
-    test_ucb_learner()
\ No newline at end of file
+    test_ucb_learner()
-- 
GitLab