From 5f05e0d721a488977a940b16869c289b53cc2487 Mon Sep 17 00:00:00 2001
From: Max Ramsay King <maxramsayking@gmail.com>
Date: Mon, 4 Apr 2022 10:31:35 -0700
Subject: [PATCH] Updated the ES learner to be more in line with the random
 search functions

---
 MetaAugment/CP2_Max.py | 146 +++++++++++++++++++++++++----------------
 1 file changed, 89 insertions(+), 57 deletions(-)

diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py
index 61cccbec..c85e9fa9 100644
--- a/MetaAugment/CP2_Max.py
+++ b/MetaAugment/CP2_Max.py
@@ -14,36 +14,56 @@ import pygad.torchga as torchga
 import random
 import copy
 
-from MetaAugment.main import *
-
-# import MetaAugment.child_networks as child_networks
 # from MetaAugment.main import *
+# import MetaAugment.child_networks as child_networks
 
 
 np.random.seed(0)
 random.seed(0)
 
 
-# augmentation_space = [
-#             # (function_name, do_we_need_to_specify_magnitude)
-#             ("ShearX", True),
-#             ("ShearY", True),
-#             ("TranslateX", True),
-#             ("TranslateY", True),
-#             ("Rotate", True),
-#             ("Brightness", True),
-#             ("Color", True),
-#             ("Contrast", True),
-#             ("Sharpness", True),
-#             ("Posterize", True),
-#             ("Solarize", True),
-#             ("AutoContrast", False),
-#             ("Equalize", False),
-#             ("Invert", False),
-#         ]
+augmentation_space = [
+            # (function_name, do_we_need_to_specify_magnitude)
+            ("ShearX", True),
+            ("ShearY", True),
+            ("TranslateX", True),
+            ("TranslateY", True),
+            ("Rotate", True),
+            ("Brightness", True),
+            ("Color", True),
+            ("Contrast", True),
+            ("Sharpness", True),
+            ("Posterize", True),
+            ("Solarize", True),
+            ("AutoContrast", False),
+            ("Equalize", False),
+            ("Invert", False),
+        ]
 
 class Learner(nn.Module):
-    def __init__(self, num_transforms = 3):
+    def __init__(self, fun_num=14, p_bins=11, m_bins=10):
+        self.fun_num = fun_num
+        self.p_bins = p_bins 
+        self.m_bins = m_bins 
+
+        self.augmentation_space = [
+            # (function_name, do_we_need_to_specify_magnitude)
+            ("ShearX", True),
+            ("ShearY", True),
+            ("TranslateX", True),
+            ("TranslateY", True),
+            ("Rotate", True),
+            ("Brightness", True),
+            ("Color", True),
+            ("Contrast", True),
+            ("Sharpness", True),
+            ("Posterize", True),
+            ("Solarize", True),
+            ("AutoContrast", False),
+            ("Equalize", False),
+            ("Invert", False),
+        ]
+
         super().__init__()
         self.conv1 = nn.Conv2d(1, 6, 5)
         self.relu1 = nn.ReLU()
@@ -55,11 +75,9 @@ class Learner(nn.Module):
         self.relu3 = nn.ReLU()
         self.fc2 = nn.Linear(120, 84)
         self.relu4 = nn.ReLU()
-        self.fc3 = nn.Linear(84, num_transforms + 21)
-        # self.sig = nn.Sigmoid()
-# Currently using discrete outputs for the probabilities 
-
+        self.fc3 = nn.Linear(84, 5 * 2 * (self.fun_num + self.p_bins + self.m_bins))
 
+# Currently using discrete outputs for the probabilities 
 
     def forward(self, x):
         y = self.conv1(x)
@@ -78,10 +96,22 @@ class Learner(nn.Module):
         return y
 
     def get_idx(self, x):
+        section = self.fun_num + self.p_bins + self.m_bins
         y = self.forward(x)
-        idx_ret = torch.argmax(y[:, 0:3].mean(dim = 0))
-        p_ret = 0.1 * torch.argmax(y[:, 3:].mean(dim = 0))
-        return (idx_ret, p_ret)
+        full_policy = []
+        for pol in range(5 * 2):
+            int_pol = []
+            idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
+
+            trans, need_mag = self.augmentation_space[idx_ret]
+
+            p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
+            mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0
+            int_pol.append((trans, p_ret, mag))
+            if pol % 2 != 0:
+                full_policy.append(tuple(int_pol))
+
+        return full_policy
 
 
 class LeNet(nn.Module):
@@ -118,44 +148,27 @@ class LeNet(nn.Module):
 
 
 # code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py
-def train_model(transform_idx, p, child_network):
+def train_model(full_policy, child_network):
     """
     Takes in the specific transformation index and probability 
     """
 
-    if transform_idx == 0:
-        transform_train = torchvision.transforms.Compose(
-           [
-            torchvision.transforms.RandomVerticalFlip(p),
-            torchvision.transforms.ToTensor(),
-            ]
-               )
-    elif transform_idx == 1:
-        transform_train = torchvision.transforms.Compose(
-           [
-            torchvision.transforms.RandomHorizontalFlip(p),
-            torchvision.transforms.ToTensor(),
-            ]
-               )
-    else:
-        transform_train = torchvision.transforms.Compose(
-           [
-            torchvision.transforms.RandomGrayscale(p),
-            torchvision.transforms.ToTensor(),
-            ]
-               )
+    # transformation = generate_policy(5, ps, mags)
+
+    train_transform = transforms.Compose([
+                                            full_policy,
+                                            transforms.ToTensor()
+                                        ])
 
     batch_size = 32
     n_samples = 0.005
 
-    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=transform_train)
+    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=train_transform)
     test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
 
     train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
 
 
-    # child_network = child_networks.lenet()
-
     sgd = optim.SGD(child_network.parameters(), lr=1e-1)
     cost = nn.CrossEntropyLoss()
     epoch = 20
@@ -191,20 +204,37 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600
 
 class Evolutionary_learner():
 
-    def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None):
-        self.meta_rl_agent = network
+    def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14):
+        self.meta_rl_agent = Learner(fun_num, p_bins=11, m_bins=10)
         self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
         self.num_generations = num_generations
         self.num_parents_mating = num_parents_mating
         self.initial_population = self.torch_ga.population_weights
         self.train_loader = train_loader
         self.sec_model = sec_model
+        self.p_bins = p_bins 
+        self.mag_bins = mag_bins
+        self.fun_num = fun_num
 
         assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
         self.set_up_instance()
     
 
+    def generate_policy(self, sp_num, ps, mags):
+        policies = []
+        for subpol in range(sp_num):
+            sub = []
+            for idx in range(2):
+                transformation = augmentation_space[(2*subpol) + idx]
+                p = ps[(2*subpol) + idx]
+                mag = mags[(2*subpol) + idx]
+                sub.append((transformation, p, mag))
+            policies.append(tuple(sub))
+        
+        return policies
+
+
     def run_instance(self, return_weights = False):
         self.ga_instance.run()
         solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
@@ -213,12 +243,14 @@ class Evolutionary_learner():
         else:
             return solution, solution_fitness, solution_idx
 
+
     def new_model(self):
         copy_model = copy.deepcopy(self.sec_model)
         return copy_model
 
 
     def set_up_instance(self):
+
         def fitness_func(solution, sol_idx):
             """
             Defines fitness function (accuracy of the model)
@@ -227,9 +259,9 @@ class Evolutionary_learner():
                                                             weights_vector=solution)
             self.meta_rl_agent.load_state_dict(model_weights_dict)
             for idx, (test_x, label_x) in enumerate(train_loader):
-                trans_idx, p = self.meta_rl_agent.get_idx(test_x)
+                full_policy = self.meta_rl_agent.get_idx(test_x)
             cop_mod = self.new_model()
-            fit_val = train_model(trans_idx, p, cop_mod)
+            fit_val = train_model(full_policy, cop_mod)
             cop_mod = 0
             return fit_val
 
-- 
GitLab