From 082b72b3127798bf95a02a8c3eceeee076a93d5b Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Wed, 27 Apr 2022 23:16:23 +0100
Subject: [PATCH] capitalize aalearners and controllers

---
 .../{aa_learner.py => AaLearner.py}            | 10 +++++-----
 .../{evo_learner.py => EvoLearner.py}          | 10 +++++-----
 .../{gru_learner.py => GruLearner.py}          | 12 ++++++------
 MetaAugment/autoaugment_learners/README.md     |  2 +-
 .../{randomsearch_learner.py => RsLearner.py}  |  8 ++++----
 .../{ucb_learner.py => UcbLearner.py}          |  8 ++++----
 MetaAugment/autoaugment_learners/__init__.py   | 10 +++++-----
 .../rand_augment_learner.py                    |  4 ++--
 .../{evo_controller.py => EvoController.py}    |  2 +-
 .../{rnn_controller.py => RnnController.py}    |  0
 MetaAugment/controller_networks/__init__.py    |  2 +-
 benchmark/scripts/04_22_ci_gru.py              |  4 ++--
 benchmark/scripts/04_22_ci_rs.py               |  4 ++--
 benchmark/scripts/04_22_fm_gru.py              |  4 ++--
 benchmark/scripts/04_22_fm_rs.py               |  4 ++--
 benchmark/scripts/util_04_22.py                |  4 ++--
 .../autoaugment_learners/aa_learners.rst       | 10 +++++-----
 ...Augment.autoaugment_learners.aa_learner.rst |  6 +++---
 ...ugment.autoaugment_learners.evo_learner.rst | 14 +++++++-------
 ...ugment.autoaugment_learners.gru_learner.rst |  6 +++---
 ...toaugment_learners.randomsearch_learner.rst |  6 +++---
 ...ugment.autoaugment_learners.ucb_learner.rst |  8 ++++----
 docs/source/usage/autoaugment_helperclass.rst  |  2 +-
 docs/source/usage/tutorial_for_team.rst        | 18 +++++++++---------
 temp_util/wapp_util.py                         | 14 +++++++-------
 test/MetaAugment/test_aa_learner.py            | 14 +++++++-------
 test/MetaAugment/test_evo_learner.py           |  2 +-
 test/MetaAugment/test_gru_learner.py           |  8 ++++----
 test/MetaAugment/test_randomsearch_learner.py  |  8 ++++----
 test/MetaAugment/test_ucb_learner.py           |  4 ++--
 30 files changed, 104 insertions(+), 104 deletions(-)
 rename MetaAugment/autoaugment_learners/{aa_learner.py => AaLearner.py} (99%)
 rename MetaAugment/autoaugment_learners/{evo_learner.py => EvoLearner.py} (97%)
 rename MetaAugment/autoaugment_learners/{gru_learner.py => GruLearner.py} (96%)
 rename MetaAugment/autoaugment_learners/{randomsearch_learner.py => RsLearner.py} (96%)
 rename MetaAugment/autoaugment_learners/{ucb_learner.py => UcbLearner.py} (97%)
 rename MetaAugment/controller_networks/{evo_controller.py => EvoController.py} (97%)
 rename MetaAugment/controller_networks/{rnn_controller.py => RnnController.py} (100%)

diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/AaLearner.py
similarity index 99%
rename from MetaAugment/autoaugment_learners/aa_learner.py
rename to MetaAugment/autoaugment_learners/AaLearner.py
index f67e99c7..30fb1820 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/AaLearner.py
@@ -12,9 +12,9 @@ import types
 
 
 
-class aa_learner:
+class AaLearner:
     """
-    The parent class for all aa_learner's
+    The parent class for all AaLearner's
     
     Attributes:
         op_tensor_length (int): what is the dimension of the tensor that represents
@@ -125,7 +125,7 @@ class aa_learner:
             return_log_prob (boolesn): 
                                 When this is on, we return which indices (of fun, prob, mag) were
                                 chosen (either randomly or deterministically, depending on argmax).
-                                This is used, for example, in the gru_learner to calculate the
+                                This is used, for example, in the GruLearner to calculate the
                                 probability of the actions were chosen, which is then logged, then
                                 differentiated.
 
@@ -139,7 +139,7 @@ class aa_learner:
                                 AutoAugment object.
             log_prob (float):
                             Used in reinforcement learning updates, such as proximal policy update
-                            in the gru_learner.
+                            in the GruLearner.
                             Can only be used when self.discrete_p_m.
                             We add the logged values of the indices of the image_function,
                             probability, and magnitude chosen.
@@ -247,7 +247,7 @@ class aa_learner:
                         by calling: AutoAugment.subpolicies = policy
         """
 
-        raise NotImplementedError('_generate_new_policy not implemented in aa_learner')
+        raise NotImplementedError('_generate_new_policy not implemented in AaLearner')
 
 
     def learn(self, train_dataset, test_dataset, child_network_architecture, iterations=15):
diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/EvoLearner.py
similarity index 97%
rename from MetaAugment/autoaugment_learners/evo_learner.py
rename to MetaAugment/autoaugment_learners/EvoLearner.py
index f0547b91..26b4fd86 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/EvoLearner.py
@@ -5,11 +5,11 @@ import pygad.torchga as torchga
 import torchvision
 import torch
 
-from MetaAugment.autoaugment_learners.aa_learner import aa_learner
+from MetaAugment.autoaugment_learners.AaLearner import AaLearner
 import MetaAugment.controller_networks as cont_n
 
 
-class evo_learner(aa_learner):
+class EvoLearner(AaLearner):
 
     def __init__(self, 
                 # search space settings
@@ -27,7 +27,7 @@ class evo_learner(aa_learner):
                 # evolutionary learner specific settings
                 num_solutions=5,
                 num_parents_mating=3,
-                controller=cont_n.evo_controller
+                controller=cont_n.EvoController
                 ):
 
         super().__init__(
@@ -273,5 +273,5 @@ class evo_learner(aa_learner):
             num_parents_mating=self.num_parents_mating, 
             initial_population=self.initial_population,
             mutation_percent_genes = 0.1,
-            _fitness_func=_fitness_func,
-            _on_generation = _on_generation)
+            fitness_func=_fitness_func,
+            on_generation = _on_generation)
diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/GruLearner.py
similarity index 96%
rename from MetaAugment/autoaugment_learners/gru_learner.py
rename to MetaAugment/autoaugment_learners/GruLearner.py
index d71dbeb9..5b47da52 100644
--- a/MetaAugment/autoaugment_learners/gru_learner.py
+++ b/MetaAugment/autoaugment_learners/GruLearner.py
@@ -1,8 +1,8 @@
 import torch
 
 import MetaAugment.child_networks as cn
-from MetaAugment.autoaugment_learners.aa_learner import aa_learner
-from MetaAugment.controller_networks.rnn_controller import RNNModel
+from MetaAugment.autoaugment_learners.AaLearner import AaLearner
+from MetaAugment.controller_networks.RnnController import RNNModel
 
 from pprint import pprint
 import pickle
@@ -11,7 +11,7 @@ import pickle
 
 
 
-class gru_learner(aa_learner):
+class GruLearner(AaLearner):
     """
     An AutoAugment learner with a GRU controller 
 
@@ -55,7 +55,7 @@ class gru_learner(aa_learner):
                     the controller. Defaults to 
         """
         if discrete_p_m==True:
-            print('Warning: Incompatible discrete_p_m=True input into gru_learner. \
+            print('Warning: Incompatible discrete_p_m=True input into GruLearner. \
                 discrete_p_m=False will be used')
         
         super().__init__(
@@ -71,7 +71,7 @@ class gru_learner(aa_learner):
                 exclude_method=exclude_method,
                 )
 
-        # GRU-specific attributes that aren't in general aa_learner's
+        # GRU-specific attributes that aren't in general AaLearner's
         self.alpha = alpha
         self.cont_mb_size = cont_mb_size
         self.b = 0.5 # b is the running exponential mean of the rewards, used for training stability
@@ -222,7 +222,7 @@ if __name__=='__main__':
     child_network_architecture = cn.lenet
     # child_network_architecture = cn.lenet()
 
-    agent = gru_learner(
+    agent = GruLearner(
                         sp_num=7,
                         toy_size=0.01,
                         batch_size=32,
diff --git a/MetaAugment/autoaugment_learners/README.md b/MetaAugment/autoaugment_learners/README.md
index 4650f6d6..a9e5a631 100644
--- a/MetaAugment/autoaugment_learners/README.md
+++ b/MetaAugment/autoaugment_learners/README.md
@@ -1,3 +1,3 @@
 write `import MetaAugment.autoaugment_learners as aa`
-and `aa_learner = aa.randomsearch_learner()`
+and `AaLearner = aa.RsLearner()`
 to use
\ No newline at end of file
diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/RsLearner.py
similarity index 96%
rename from MetaAugment/autoaugment_learners/randomsearch_learner.py
rename to MetaAugment/autoaugment_learners/RsLearner.py
index c3c43ec0..fac9e229 100644
--- a/MetaAugment/autoaugment_learners/randomsearch_learner.py
+++ b/MetaAugment/autoaugment_learners/RsLearner.py
@@ -2,7 +2,7 @@ import torch
 import numpy as np
 
 import MetaAugment.child_networks as cn
-from MetaAugment.autoaugment_learners.aa_learner import aa_learner
+from MetaAugment.autoaugment_learners.AaLearner import AaLearner
 
 from pprint import pprint
 import matplotlib.pyplot as plt
@@ -12,10 +12,10 @@ import pickle
 
 
 
-class randomsearch_learner(aa_learner):
+class RsLearner(AaLearner):
     """
     Tests randomly sampled policies from the search space specified by the AutoAugment
-    paper. Acts as a baseline for other aa_learner's.
+    paper. Acts as a baseline for other AaLearner's.
     """
     def __init__(self,
                 # parameters that define the search space
@@ -165,7 +165,7 @@ if __name__=='__main__':
     child_network_architecture = cn.lenet
     # child_network_architecture = cn.lenet()
 
-    agent = randomsearch_learner(
+    agent = RsLearner(
                                 sp_num=7,
                                 toy_size=0.01,
                                 batch_size=4,
diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/UcbLearner.py
similarity index 97%
rename from MetaAugment/autoaugment_learners/ucb_learner.py
rename to MetaAugment/autoaugment_learners/UcbLearner.py
index 5d9a32e7..0c84e664 100644
--- a/MetaAugment/autoaugment_learners/ucb_learner.py
+++ b/MetaAugment/autoaugment_learners/UcbLearner.py
@@ -3,13 +3,13 @@ import numpy as np
 from tqdm import trange
 
 from ..child_networks import *
-from .randomsearch_learner import randomsearch_learner
+from .RsLearner import RsLearner
 
 
-class ucb_learner(randomsearch_learner):
+class UcbLearner(RsLearner):
     """
     Tests randomly sampled policies from the search space specified by the AutoAugment
-    paper. Acts as a baseline for other aa_learner's.
+    paper. Acts as a baseline for other AaLearner's.
     """
     def __init__(self,
                 # parameters that define the search space
@@ -24,7 +24,7 @@ class ucb_learner(randomsearch_learner):
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=30,
-                # ucb_learner specific hyperparameter
+                # UcbLearner specific hyperparameter
                 num_policies=100
                 ):
         
diff --git a/MetaAugment/autoaugment_learners/__init__.py b/MetaAugment/autoaugment_learners/__init__.py
index 700f7359..1bd747a8 100644
--- a/MetaAugment/autoaugment_learners/__init__.py
+++ b/MetaAugment/autoaugment_learners/__init__.py
@@ -1,5 +1,5 @@
-from .aa_learner import *
-from .randomsearch_learner import *
-from .gru_learner import *
-from .evo_learner import *
-from .ucb_learner import *
\ No newline at end of file
+from .AaLearner import *
+from .RsLearner import *
+from .GruLearner import *
+from .EvoLearner import *
+from .UcbLearner import *
\ No newline at end of file
diff --git a/MetaAugment/autoaugment_learners/rand_augment_learner.py b/MetaAugment/autoaugment_learners/rand_augment_learner.py
index b6974bef..ab5380d0 100644
--- a/MetaAugment/autoaugment_learners/rand_augment_learner.py
+++ b/MetaAugment/autoaugment_learners/rand_augment_learner.py
@@ -1,8 +1,8 @@
 import torch 
 import numpy as np
-from MetaAugment.autoaugment_learners.randomsearch_learner import randomsearch_learner
+from MetaAugment.autoaugment_learners.RsLearner import RsLearner
 
-class rand_augment_learner(randomsearch_learner):
+class rand_augment_learner(RsLearner):
 
     def __init__(self):
         pass
\ No newline at end of file
diff --git a/MetaAugment/controller_networks/evo_controller.py b/MetaAugment/controller_networks/EvoController.py
similarity index 97%
rename from MetaAugment/controller_networks/evo_controller.py
rename to MetaAugment/controller_networks/EvoController.py
index 55dafc05..33a9a606 100644
--- a/MetaAugment/controller_networks/evo_controller.py
+++ b/MetaAugment/controller_networks/EvoController.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import math
 
-class evo_controller(nn.Module):
+class EvoController(nn.Module):
     def __init__(self, fun_num=14, p_bins=11, m_bins=10, sub_num_pol=5):
         self.fun_num = fun_num
         self.p_bins = p_bins 
diff --git a/MetaAugment/controller_networks/rnn_controller.py b/MetaAugment/controller_networks/RnnController.py
similarity index 100%
rename from MetaAugment/controller_networks/rnn_controller.py
rename to MetaAugment/controller_networks/RnnController.py
diff --git a/MetaAugment/controller_networks/__init__.py b/MetaAugment/controller_networks/__init__.py
index 6182b736..02b56826 100644
--- a/MetaAugment/controller_networks/__init__.py
+++ b/MetaAugment/controller_networks/__init__.py
@@ -1 +1 @@
-from .evo_controller import evo_controller
\ No newline at end of file
+from .EvoController import EvoController
\ No newline at end of file
diff --git a/benchmark/scripts/04_22_ci_gru.py b/benchmark/scripts/04_22_ci_gru.py
index 155b7c92..e09a4796 100644
--- a/benchmark/scripts/04_22_ci_gru.py
+++ b/benchmark/scripts/04_22_ci_gru.py
@@ -8,7 +8,7 @@ import MetaAugment.autoaugment_learners as aal
 from .util_04_22 import *
 
 
-# aa_learner config
+# AaLearner config
 config = {
         'sp_num' : 3,
         'learning_rate' : 1e-1,
@@ -41,7 +41,7 @@ run_benchmark(
     train_dataset=train_dataset,
     test_dataset=test_dataset,
     child_network_architecture=child_network_architecture,
-    agent_arch=aal.gru_learner,
+    agent_arch=aal.GruLearner,
     config=config,
     )
 
diff --git a/benchmark/scripts/04_22_ci_rs.py b/benchmark/scripts/04_22_ci_rs.py
index e1279b1e..b3feb10e 100644
--- a/benchmark/scripts/04_22_ci_rs.py
+++ b/benchmark/scripts/04_22_ci_rs.py
@@ -8,7 +8,7 @@ import MetaAugment.autoaugment_learners as aal
 from .util_04_22 import *
 
 
-# aa_learner config
+# AaLearner config
 config = {
         'sp_num' : 3,
         'learning_rate' : 1e-1,
@@ -40,7 +40,7 @@ run_benchmark(
     train_dataset=train_dataset,
     test_dataset=test_dataset,
     child_network_architecture=child_network_architecture,
-    agent_arch=aal.randomsearch_learner,
+    agent_arch=aal.RsLearner,
     config=config,
     )
 
diff --git a/benchmark/scripts/04_22_fm_gru.py b/benchmark/scripts/04_22_fm_gru.py
index 807d0177..d3a5c038 100644
--- a/benchmark/scripts/04_22_fm_gru.py
+++ b/benchmark/scripts/04_22_fm_gru.py
@@ -8,7 +8,7 @@ import MetaAugment.autoaugment_learners as aal
 from .util_04_22 import *
 
 
-# aa_learner config
+# AaLearner config
 config = {
         'sp_num' : 3,
         'learning_rate' : 1e-1,
@@ -36,7 +36,7 @@ run_benchmark(
     train_dataset=train_dataset,
     test_dataset=test_dataset,
     child_network_architecture=child_network_architecture,
-    agent_arch=aal.gru_learner,
+    agent_arch=aal.GruLearner,
     config=config,
     total_iter=144
     )
diff --git a/benchmark/scripts/04_22_fm_rs.py b/benchmark/scripts/04_22_fm_rs.py
index dfe71958..3b88e333 100644
--- a/benchmark/scripts/04_22_fm_rs.py
+++ b/benchmark/scripts/04_22_fm_rs.py
@@ -8,7 +8,7 @@ import MetaAugment.autoaugment_learners as aal
 from .util_04_22 import *
 
 
-# aa_learner config
+# AaLearner config
 config = {
         'sp_num' : 3,
         'learning_rate' : 1e-1,
@@ -36,7 +36,7 @@ run_benchmark(
     train_dataset=train_dataset,
     test_dataset=test_dataset,
     child_network_architecture=child_network_architecture,
-    agent_arch=aal.randomsearch_learner,
+    agent_arch=aal.RsLearner,
     config=config,
     )
 
diff --git a/benchmark/scripts/util_04_22.py b/benchmark/scripts/util_04_22.py
index 8d7aa6b1..0c572a1d 100644
--- a/benchmark/scripts/util_04_22.py
+++ b/benchmark/scripts/util_04_22.py
@@ -5,7 +5,7 @@ import MetaAugment.autoaugment_learners as aal
 import pprint
 
 """
-testing gru_learner and randomsearch_learner on
+testing GruLearner and RsLearner on
 
   fashionmnist with simple net
 
@@ -109,7 +109,7 @@ def rerun_best_policy(
     accs=[]
     for _ in range(repeat_num):
         print(f'{_}/{repeat_num}')
-        temp_agent = aal.aa_learner(**config)
+        temp_agent = aal.AaLearner(**config)
         accs.append(
                 temp_agent._test_autoaugment_policy(megapol,
                                     child_network_architecture,
diff --git a/docs/source/MetaAugment_library/autoaugment_learners/aa_learners.rst b/docs/source/MetaAugment_library/autoaugment_learners/aa_learners.rst
index db9d7b9f..7c504484 100644
--- a/docs/source/MetaAugment_library/autoaugment_learners/aa_learners.rst
+++ b/docs/source/MetaAugment_library/autoaugment_learners/aa_learners.rst
@@ -5,8 +5,8 @@ AutoAugment learners
 .. autosummary::
    :toctree: generated
 
-   MetaAugment.autoaugment_learners.aa_learner
-   MetaAugment.autoaugment_learners.evo_learner
-   MetaAugment.autoaugment_learners.gru_learner
-   MetaAugment.autoaugment_learners.randomsearch_learner
-   MetaAugment.autoaugment_learners.ucb_learner
\ No newline at end of file
+   MetaAugment.autoaugment_learners.AaLearner
+   MetaAugment.autoaugment_learners.EvoLearner
+   MetaAugment.autoaugment_learners.GruLearner
+   MetaAugment.autoaugment_learners.RsLearner
+   MetaAugment.autoaugment_learners.UcbLearner
\ No newline at end of file
diff --git a/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.aa_learner.rst b/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.aa_learner.rst
index 85be0d01..6bf6d010 100644
--- a/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.aa_learner.rst
+++ b/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.aa_learner.rst
@@ -3,7 +3,7 @@
 
 .. currentmodule:: MetaAugment.autoaugment_learners
 
-.. autoclass:: aa_learner
+.. autoclass:: AaLearner
 
    
    .. automethod:: __init__
@@ -13,8 +13,8 @@
 
    .. autosummary::
    
-      ~aa_learner.__init__
-      ~aa_learner.learn
+      ~AaLearner.__init__
+      ~AaLearner.learn
    
    
 
diff --git a/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.evo_learner.rst b/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.evo_learner.rst
index 37f06b00..2f77ae8b 100644
--- a/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.evo_learner.rst
+++ b/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.evo_learner.rst
@@ -3,7 +3,7 @@
 
 .. currentmodule:: MetaAugment.autoaugment_learners
 
-.. autoclass:: evo_learner
+.. autoclass:: EvoLearner
 
    
    .. automethod:: __init__
@@ -13,12 +13,12 @@
 
    .. autosummary::
    
-      ~evo_learner.__init__
-      ~evo_learner.get_full_policy
-      ~evo_learner.get_single_policy_cov
-      ~evo_learner.in_pol_dict
-      ~evo_learner.learn
-      ~evo_learner.set_up_instance
+      ~EvoLearner.__init__
+      ~EvoLearner.get_full_policy
+      ~EvoLearner.get_single_policy_cov
+      ~EvoLearner.in_pol_dict
+      ~EvoLearner.learn
+      ~EvoLearner.set_up_instance
    
    
 
diff --git a/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.gru_learner.rst b/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.gru_learner.rst
index 23eb306c..ab5fda8a 100644
--- a/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.gru_learner.rst
+++ b/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.gru_learner.rst
@@ -3,7 +3,7 @@
 
 .. currentmodule:: MetaAugment.autoaugment_learners
 
-.. autoclass:: gru_learner
+.. autoclass:: GruLearner
 
    
    .. automethod:: __init__
@@ -13,8 +13,8 @@
 
    .. autosummary::
    
-      ~gru_learner.__init__
-      ~gru_learner.learn
+      ~GruLearner.__init__
+      ~GruLearner.learn
    
    
 
diff --git a/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.randomsearch_learner.rst b/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.randomsearch_learner.rst
index 72903e47..cce666fa 100644
--- a/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.randomsearch_learner.rst
+++ b/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.randomsearch_learner.rst
@@ -3,7 +3,7 @@
 
 .. currentmodule:: MetaAugment.autoaugment_learners
 
-.. autoclass:: randomsearch_learner
+.. autoclass:: RsLearner
 
    
    .. automethod:: __init__
@@ -13,8 +13,8 @@
 
    .. autosummary::
    
-      ~randomsearch_learner.__init__
-      ~randomsearch_learner.learn
+      ~RsLearner.__init__
+      ~RsLearner.learn
    
    
 
diff --git a/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.ucb_learner.rst b/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.ucb_learner.rst
index 83f80f48..718ea1de 100644
--- a/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.ucb_learner.rst
+++ b/docs/source/MetaAugment_library/autoaugment_learners/generated/MetaAugment.autoaugment_learners.ucb_learner.rst
@@ -3,7 +3,7 @@
 
 .. currentmodule:: MetaAugment.autoaugment_learners
 
-.. autoclass:: ucb_learner
+.. autoclass:: UcbLearner
 
    
    .. automethod:: __init__
@@ -13,9 +13,9 @@
 
    .. autosummary::
    
-      ~ucb_learner.__init__
-      ~ucb_learner.learn
-      ~ucb_learner.make_more_policies
+      ~UcbLearner.__init__
+      ~UcbLearner.learn
+      ~UcbLearner.make_more_policies
    
    
 
diff --git a/docs/source/usage/autoaugment_helperclass.rst b/docs/source/usage/autoaugment_helperclass.rst
index cc361da4..0dca40b3 100644
--- a/docs/source/usage/autoaugment_helperclass.rst
+++ b/docs/source/usage/autoaugment_helperclass.rst
@@ -11,7 +11,7 @@ we use as a helper class to help us apply AutoAugment policies to datasets.
 This is a tutorial (in the sense describe in https://documentation.divio.com/structure/).
 
 For an example of how the material is used in our library, see the source code of
-:meth:`aa_learner._test_autoaugment_policy <MetaAugment.autoaugment_learners.aa_learner>`.
+:meth:`AaLearner._test_autoaugment_policy <MetaAugment.autoaugment_learners.AaLearner>`.
 
 Let's say we have a policy within the search space specified by the original 
 AutoAugment paper:
diff --git a/docs/source/usage/tutorial_for_team.rst b/docs/source/usage/tutorial_for_team.rst
index 1c81cd7c..dde32e15 100644
--- a/docs/source/usage/tutorial_for_team.rst
+++ b/docs/source/usage/tutorial_for_team.rst
@@ -1,12 +1,12 @@
-aa_learner object and its children
+AaLearner object and its children
 ------------------------------------------------------------------------------------------------
 
-This is a page dedicated to demonstrating functionalities of :class:`aa_learner`.
+This is a page dedicated to demonstrating functionalities of :class:`AaLearner`.
 
 This is a how-to guide (in the sense describe in https://documentation.divio.com/structure/).
 
 ######################################################################################################
-How to use the ``aa_learner`` class to find an optimal policy for a dataset-child_network pair
+How to use the ``AaLearner`` class to find an optimal policy for a dataset-child_network pair
 ######################################################################################################
 
 This section can also be read as a ``.py`` file in ``./tutorials/how_use_aalearner.py``.
@@ -37,7 +37,7 @@ Defining the problem setting:
     
     In earlier versions, we had to write ``child_network_architecture=cn.LeNet`` 
     and not ``child_network_architecture=cn.LeNet()``. But now we can do both. 
-    Both types of objects can be input into ``aa_learner.learn()``.
+    Both types of objects can be input into ``AaLearner.learn()``.
 
     More precisely, the ``child_network_architecture`` parameter has to be either
     as ``nn.Module``, a ``function`` which returns a ``nn.Module``, or a ``type`` 
@@ -51,11 +51,11 @@ can use any other learner in place of random search learner as well)
 
 .. code-block::
 
-    # aa_agent = aal.gru_learner()
-    # aa_agent = aal.evo_learner()
-    # aa_agent = aal.ucb_learner()
+    # aa_agent = aal.GruLearner()
+    # aa_agent = aal.EvoLearner()
+    # aa_agent = aal.UcbLearner()
     # aa_agent = aal.ac_learner()
-    aa_agent = aal.randomsearch_learner(
+    aa_agent = aal.RsLearner(
                                     sp_num=7,
                                     toy_size=0.01,
                                     batch_size=4,
@@ -68,7 +68,7 @@ can use any other learner in place of random search learner as well)
                 child_network_architecture=child_network_architecture,
                 iterations=15000)
 
-You can set further hyperparameters when defining a aa_learner. 
+You can set further hyperparameters when defining a AaLearner. 
 
 Also, depending on what learner you are using, there might be unique hyperparameters.
 For example, in the GRU learner you can tune the exploration parameter ``alpha``.
diff --git a/temp_util/wapp_util.py b/temp_util/wapp_util.py
index bb101133..14675ed4 100644
--- a/temp_util/wapp_util.py
+++ b/temp_util/wapp_util.py
@@ -43,12 +43,12 @@ def parse_users_learner_spec(
                                                     IsLeNet
                                                     )
     """
-    The website receives user inputs on what they want the aa_learner
-    to be. We take those hyperparameters and return an aa_learner
+    The website receives user inputs on what they want the AaLearner
+    to be. We take those hyperparameters and return an AaLearner
 
     """
     if auto_aug_learner == 'UCB':
-        learner = aal.ucb_learner(
+        learner = aal.UcbLearner(
                         # parameters that define the search space
                         sp_num=num_sub_policies,
                         p_bins=11,
@@ -61,11 +61,11 @@ def parse_users_learner_spec(
                         learning_rate=learning_rate,
                         max_epochs=max_epochs,
                         early_stop_num=early_stop_num,
-                        # ucb_learner specific hyperparameter
+                        # UcbLearner specific hyperparameter
                         num_policies=num_policies
                         )
     elif auto_aug_learner == 'Evolutionary Learner':
-        learner = aal.evo_learner(
+        learner = aal.EvoLearner(
                         # parameters that define the search space
                         sp_num=num_sub_policies,
                         p_bins=11,
@@ -81,7 +81,7 @@ def parse_users_learner_spec(
                         )
         learner.run_instance()
     elif auto_aug_learner == 'Random Searcher':
-        agent = aal.randomsearch_learner(
+        agent = aal.RsLearner(
                         # parameters that define the search space
                         sp_num=num_sub_policies,
                         p_bins=11,
@@ -96,7 +96,7 @@ def parse_users_learner_spec(
                         early_stop_num=early_stop_num,
                         )
     elif auto_aug_learner == 'GRU Learner':
-        agent = aal.gru_learner(
+        agent = aal.GruLearner(
                         # parameters that define the search space
                         sp_num=num_sub_policies,
                         p_bins=11,
diff --git a/test/MetaAugment/test_aa_learner.py b/test/MetaAugment/test_aa_learner.py
index cbb8d952..b86730d7 100644
--- a/test/MetaAugment/test_aa_learner.py
+++ b/test/MetaAugment/test_aa_learner.py
@@ -9,7 +9,7 @@ import random
 
 def test__translate_operation_tensor():
     """
-    See if aa_learner class's _translate_operation_tensor works
+    See if AaLearner class's _translate_operation_tensor works
     by feeding many (valid) inputs in it.
 
     We make a lot of (fun_num+p_bins_m_bins,) size tensors, softmax 
@@ -29,7 +29,7 @@ def test__translate_operation_tensor():
         p_bins = random.randint(2, 15)
         m_bins = random.randint(2, 15)
         
-        agent = aal.aa_learner(
+        agent = aal.AaLearner(
                 sp_num=5,
                 p_bins=p_bins,
                 m_bins=m_bins,
@@ -57,7 +57,7 @@ def test__translate_operation_tensor():
         p_bins = random.randint(1, 15)
         m_bins = random.randint(1, 15)
 
-        agent = aal.aa_learner(
+        agent = aal.AaLearner(
                 sp_num=5,
                 p_bins=p_bins,
                 m_bins=m_bins,
@@ -77,7 +77,7 @@ def test__translate_operation_tensor():
 
 
 def test__test_autoaugment_policy():
-    agent = aal.aa_learner(
+    agent = aal.AaLearner(
                 sp_num=5,
                 p_bins=11,
                 m_bins=10,
@@ -130,7 +130,7 @@ def test_exclude_method():
                     'Brightness', 
                     'Contrast'
                     ]
-    agent = aal.gru_learner(
+    agent = aal.GruLearner(
         exclude_method=exclude_method
     )
     for _ in range(200):
@@ -142,7 +142,7 @@ def test_exclude_method():
             assert image_function_1 not in exclude_method
             assert image_function_2 not in exclude_method
     
-    agent = aal.randomsearch_learner(
+    agent = aal.RsLearner(
         exclude_method=exclude_method
     )
     for _ in range(200):
@@ -157,7 +157,7 @@ def test_exclude_method():
 
 def test_get_mega_policy():
 
-    agent = aal.randomsearch_learner(
+    agent = aal.RsLearner(
                 sp_num=5,
                 p_bins=11,
                 m_bins=10,
diff --git a/test/MetaAugment/test_evo_learner.py b/test/MetaAugment/test_evo_learner.py
index b917fb39..4ec17ca5 100644
--- a/test/MetaAugment/test_evo_learner.py
+++ b/test/MetaAugment/test_evo_learner.py
@@ -13,7 +13,7 @@ def test_evo_learner():
                             transform=torchvision.transforms.ToTensor())
 
 
-    learner = aal.evo_learner(
+    learner = aal.EvoLearner(
         # parameters that define the search space
                 sp_num=5,
                 p_bins=11,
diff --git a/test/MetaAugment/test_gru_learner.py b/test/MetaAugment/test_gru_learner.py
index b2ea8930..236a17d3 100644
--- a/test/MetaAugment/test_gru_learner.py
+++ b/test/MetaAugment/test_gru_learner.py
@@ -8,7 +8,7 @@ import random
 
 def test__generate_new_policy():
     """
-    make sure gru_learner._generate_new_policy() is robust
+    make sure GruLearner._generate_new_policy() is robust
     with respect to different values of sp_num, fun_num, 
     p_bins, and m_bins
     """
@@ -17,7 +17,7 @@ def test__generate_new_policy():
         p_bins = random.randint(2, 15)
         m_bins = random.randint(2, 15)
 
-        agent = aal.gru_learner(
+        agent = aal.GruLearner(
             sp_num=sp_num,
             p_bins=p_bins,
             m_bins=m_bins,
@@ -30,7 +30,7 @@ def test__generate_new_policy():
 
 def test_learn():
     """
-    tests the gru_learner.learn() method
+    tests the GruLearner.learn() method
     """
     train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
                             train=True, download=True, transform=None)
@@ -40,7 +40,7 @@ def test_learn():
     child_network_architecture = cn.lenet
     # child_network_architecture = cn.lenet()
 
-    agent = aal.gru_learner(
+    agent = aal.GruLearner(
                         sp_num=7,
                         toy_size=0.001,
                         batch_size=32,
diff --git a/test/MetaAugment/test_randomsearch_learner.py b/test/MetaAugment/test_randomsearch_learner.py
index 6c5a9350..6bfd5441 100644
--- a/test/MetaAugment/test_randomsearch_learner.py
+++ b/test/MetaAugment/test_randomsearch_learner.py
@@ -8,7 +8,7 @@ import random
 
 def test__generate_new_policy():
     """
-    make sure randomsearch_learner._generate_new_policy() is robust
+    make sure RsLearner._generate_new_policy() is robust
     with respect to different values of sp_num, fun_num, 
     p_bins, and m_bins
     """
@@ -20,7 +20,7 @@ def test__generate_new_policy():
             p_bins = random.randint(2, 15)
             m_bins = random.randint(2, 15)
 
-            agent = aal.randomsearch_learner(
+            agent = aal.RsLearner(
                 sp_num=sp_num,
                 p_bins=p_bins,
                 m_bins=m_bins,
@@ -39,7 +39,7 @@ def test__generate_new_policy():
 
 def test_learn():
     """
-    tests the randomsearch_learner.learn() method
+    tests the RsLearner.learn() method
     """
     train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
                             train=True, download=True, transform=None)
@@ -49,7 +49,7 @@ def test_learn():
     child_network_architecture = cn.lenet
     # child_network_architecture = cn.lenet()
 
-    agent = aal.randomsearch_learner(
+    agent = aal.RsLearner(
                         sp_num=7,
                         toy_size=0.001,
                         batch_size=32,
diff --git a/test/MetaAugment/test_ucb_learner.py b/test/MetaAugment/test_ucb_learner.py
index fc2807aa..6a1cc0fe 100644
--- a/test/MetaAugment/test_ucb_learner.py
+++ b/test/MetaAugment/test_ucb_learner.py
@@ -13,7 +13,7 @@ def test_ucb_learner():
                             transform=torchvision.transforms.ToTensor())
 
 
-    learner = aal.ucb_learner(
+    learner = aal.UcbLearner(
         # parameters that define the search space
                 sp_num=5,
                 p_bins=11,
@@ -25,7 +25,7 @@ def test_ucb_learner():
                 learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=30,
-                # ucb_learner specific hyperparameter
+                # UcbLearner specific hyperparameter
                 num_policies=3
     )
     pprint(learner.policies)
-- 
GitLab