From c6136f11b5871ebaf274795bceeaa85ba4649166 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Mon, 25 Apr 2022 17:23:37 +0100
Subject: [PATCH] finish removing fun_num from aa_learners

also update tests accordingly
---
 MetaAugment/autoaugment_learners/evo_learner.py          | 1 -
 MetaAugment/autoaugment_learners/gru_learner.py          | 9 ++++-----
 MetaAugment/autoaugment_learners/randomsearch_learner.py | 8 +++-----
 MetaAugment/autoaugment_learners/ucb_learner.py          | 2 --
 test/MetaAugment/test_aa_learner.py                      | 9 +++------
 test/MetaAugment/test_gru_learner.py                     | 2 --
 test/MetaAugment/test_randomsearch_learner.py            | 3 +--
 7 files changed, 11 insertions(+), 23 deletions(-)

diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
index 8e1d5bc1..e9a65865 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -6,7 +6,6 @@ import pygad
 import pygad.torchga as torchga
 import copy
 import torch
-from MetaAugment.controller_networks.evo_controller import Evo_learner
 
 from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space
 import MetaAugment.child_networks as cn
diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py
index c06edec3..db8205d5 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/gru_learner.py
@@ -47,7 +47,6 @@ class gru_learner(aa_learner):
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=False,
@@ -78,10 +77,10 @@ class gru_learner(aa_learner):
             print('Warning: Incompatible discrete_p_m=True input into gru_learner. \
                 discrete_p_m=False will be used')
         
-        super().__init__(sp_num, 
-                fun_num, 
-                p_bins, 
-                m_bins, 
+        super().__init__(
+                sp_num=sp_num, 
+                p_bins=p_bins, 
+                m_bins=m_bins, 
                 discrete_p_m=True, 
                 batch_size=batch_size, 
                 toy_flag=toy_flag, 
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py
index 6541cd3f..09f6626f 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py
@@ -38,7 +38,6 @@ class randomsearch_learner(aa_learner):
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=True,
@@ -51,10 +50,9 @@ class randomsearch_learner(aa_learner):
                 early_stop_num=30,
                 ):
         
-        super().__init__(sp_num, 
-                fun_num, 
-                p_bins, 
-                m_bins, 
+        super().__init__(sp_num=sp_num, 
+                p_bins=p_bins, 
+                m_bins=m_bins, 
                 discrete_p_m=discrete_p_m,
                 batch_size=batch_size,
                 toy_flag=toy_flag,
diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py
index e22f32ff..dc82c2ee 100644
--- a/MetaAugment/autoaugment_learners/ucb_learner.py
+++ b/MetaAugment/autoaugment_learners/ucb_learner.py
@@ -20,7 +20,6 @@ class ucb_learner(randomsearch_learner):
     def __init__(self,
                 # parameters that define the search space
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=True,
@@ -36,7 +35,6 @@ class ucb_learner(randomsearch_learner):
                 ):
         
         super().__init__(sp_num=sp_num, 
-                        fun_num=14,
                         p_bins=p_bins, 
                         m_bins=m_bins, 
                         discrete_p_m=discrete_p_m,
diff --git a/test/MetaAugment/test_aa_learner.py b/test/MetaAugment/test_aa_learner.py
index 3e280870..29af4f6d 100644
--- a/test/MetaAugment/test_aa_learner.py
+++ b/test/MetaAugment/test_aa_learner.py
@@ -25,13 +25,12 @@ def test_translate_operation_tensor():
 
         softmax = torch.nn.Softmax(dim=0)
 
-        fun_num = random.randint(1, 14)
+        fun_num=14
         p_bins = random.randint(2, 15)
         m_bins = random.randint(2, 15)
-
+        
         agent = aal.aa_learner(
                 sp_num=5,
-                fun_num=fun_num,
                 p_bins=p_bins,
                 m_bins=m_bins,
                 discrete_p_m=True
@@ -54,13 +53,12 @@ def test_translate_operation_tensor():
     for i in range(2000):
         
 
-        fun_num = random.randint(1, 14)
+        fun_num = 14
         p_bins = random.randint(1, 15)
         m_bins = random.randint(1, 15)
 
         agent = aal.aa_learner(
                 sp_num=5,
-                fun_num=fun_num,
                 p_bins=p_bins,
                 m_bins=m_bins,
                 discrete_p_m=False
@@ -81,7 +79,6 @@ def test_translate_operation_tensor():
 def test_test_autoaugment_policy():
     agent = aal.aa_learner(
                 sp_num=5,
-                fun_num=14,
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=True,
diff --git a/test/MetaAugment/test_gru_learner.py b/test/MetaAugment/test_gru_learner.py
index 6ad8204f..b5c695cf 100644
--- a/test/MetaAugment/test_gru_learner.py
+++ b/test/MetaAugment/test_gru_learner.py
@@ -14,13 +14,11 @@ def test_generate_new_policy():
     """
     for _ in range(40):
         sp_num = random.randint(1,20)
-        fun_num = random.randint(1, 14)
         p_bins = random.randint(2, 15)
         m_bins = random.randint(2, 15)
 
         agent = aal.gru_learner(
             sp_num=sp_num,
-            fun_num=fun_num,
             p_bins=p_bins,
             m_bins=m_bins,
             cont_mb_size=2
diff --git a/test/MetaAugment/test_randomsearch_learner.py b/test/MetaAugment/test_randomsearch_learner.py
index 5b67d98e..29cd812b 100644
--- a/test/MetaAugment/test_randomsearch_learner.py
+++ b/test/MetaAugment/test_randomsearch_learner.py
@@ -16,13 +16,12 @@ def test_generate_new_policy():
     def my_test(discrete_p_m):
         for _ in range(40):
             sp_num = random.randint(1,20)
-            fun_num = random.randint(1, 14)
+            
             p_bins = random.randint(2, 15)
             m_bins = random.randint(2, 15)
 
             agent = aal.randomsearch_learner(
                 sp_num=sp_num,
-                fun_num=fun_num,
                 p_bins=p_bins,
                 m_bins=m_bins,
                 discrete_p_m=discrete_p_m
-- 
GitLab