From 2a450edb0f5f2ec6cb44fb134acffb6b0c6f8851 Mon Sep 17 00:00:00 2001
From: Max Ramsay King <maxramsayking@gmail.com>
Date: Tue, 5 Apr 2022 10:59:19 -0700
Subject: [PATCH] Using test_autoaugment_policy to calcualte the accuracy of
 subpolicies. ES learner now outputs similar list to randomsearcher

---
 MetaAugment/CP2_Max.py | 68 ++++++++++++++++++++++++++----------------
 1 file changed, 43 insertions(+), 25 deletions(-)

diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py
index 255a203f..af6be7e4 100644
--- a/MetaAugment/CP2_Max.py
+++ b/MetaAugment/CP2_Max.py
@@ -13,6 +13,10 @@ import pygad
 import pygad.torchga as torchga
 import random
 import copy
+from torchvision.transforms import functional as F, InterpolationMode
+from typing import List, Tuple, Optional, Dict
+
+
 
 # from MetaAugment.main import *
 # import MetaAugment.child_networks as child_networks
@@ -112,36 +116,36 @@ class LeNet(nn.Module):
 
 
 # code from https://github.com/ChawDoe/LeNet5-MNIST-PyTorch/blob/master/train.py
-def train_model(full_policy, child_network):
-    """
-    Takes in the specific transformation index and probability 
-    """
+# def train_model(full_policy, child_network):
+#     """
+#     Takes in the specific transformation index and probability 
+#     """
 
-    # transformation = generate_policy(5, ps, mags)
+#     # transformation = generate_policy(5, ps, mags)
 
-    train_transform = transforms.Compose([
-                                            full_policy,
-                                            transforms.ToTensor()
-                                        ])
+#     train_transform = transforms.Compose([
+#                                             full_policy,
+#                                             transforms.ToTensor()
+#                                         ])
 
-    batch_size = 32
-    n_samples = 0.005
+#     batch_size = 32
+#     n_samples = 0.005
 
-    train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=train_transform)
-    test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
+#     train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, transform=train_transform)
+#     test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False, transform=torchvision.transforms.ToTensor())
 
-    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
+#     train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
 
 
-    sgd = optim.SGD(child_network.parameters(), lr=1e-1)
-    cost = nn.CrossEntropyLoss()
-    epoch = 20
+#     sgd = optim.SGD(child_network.parameters(), lr=1e-1)
+#     cost = nn.CrossEntropyLoss()
+#     epoch = 20
 
 
-    best_acc = train_child_network(child_network, train_loader, test_loader,
-                                     sgd, cost, max_epochs=100, print_every_epoch=False)
+#     best_acc = train_child_network(child_network, train_loader, test_loader,
+#                                      sgd, cost, max_epochs=100, print_every_epoch=False)
 
-    return best_acc
+#     return best_acc
 
 
 
@@ -168,7 +172,7 @@ train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=600
 
 class Evolutionary_learner():
 
-    def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14, augmentation_space = None):
+    def __init__(self, network, num_solutions = 30, num_generations = 10, num_parents_mating = 15, train_loader = None, sec_model = None, p_bins = 11, mag_bins = 10, fun_num = 14, augmentation_space = None, train_dataset = None, test_dataset = None):
         self.meta_rl_agent = Learner(fun_num, p_bins=11, m_bins=10)
         self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions)
         self.num_generations = num_generations
@@ -180,6 +184,8 @@ class Evolutionary_learner():
         self.mag_bins = mag_bins
         self.fun_num = fun_num
         self.augmentation_space = augmentation_space
+        self.train_dataset = train_dataset
+        self.test_dataset = test_dataset
 
         assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!'
 
@@ -219,7 +225,7 @@ class Evolutionary_learner():
                 trans, need_mag = self.augmentation_space[idx_ret]
 
                 p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
-                mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else 0
+                mag = torch.argmax(y[:, (pol * section)+self.fun_num+self.p_bins:((pol+1)*section)].mean(dim = 0)) if need_mag else None
                 int_pol.append((trans, p_ret, mag))
 
             full_policy.append(tuple(int_pol))
@@ -253,22 +259,29 @@ class Evolutionary_learner():
             """
             Defines fitness function (accuracy of the model)
             """
+
             model_weights_dict = torchga.model_weights_as_dict(model=self.meta_rl_agent,
                                                             weights_vector=solution)
+
             self.meta_rl_agent.load_state_dict(model_weights_dict)
+
             for idx, (test_x, label_x) in enumerate(train_loader):
                 full_policy = self.get_full_policy(test_x)
+
+
             cop_mod = self.new_model()
-            fit_val = train_model(full_policy, cop_mod)
+
+            fit_val = test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]
             cop_mod = 0
+
             return fit_val
 
         def on_generation(ga_instance):
             """
             Just prints stuff while running
             """
-            print("Generation = {generation}".format(generation=self.ga_instance.generations_completed))
-            print("Fitness    = {fitness}".format(fitness=self.ga_instance.best_solution()[1]))
+            print("Generation = {generation}".format(generation=ga_instance.generations_completed))
+            print("Fitness    = {fitness}".format(fitness=ga_instance.best_solution()[1]))
             return
 
 
@@ -279,6 +292,11 @@ class Evolutionary_learner():
             on_generation = on_generation)
 
 
+
+
+
+
+
 meta_rl_agent = Learner()
 ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet(), augmentation_space=augmentation_space)
 ev_learner.run_instance()
-- 
GitLab