From 5c3220051e377196f5a81f1b47aba96c4ba46edc Mon Sep 17 00:00:00 2001
From: Max Ramsay King <maxramsayking@gmail.com>
Date: Tue, 26 Apr 2022 21:24:34 +0100
Subject: [PATCH] Fix evo_learner constructor, magnitude clamping and policy_dict bookkeeping

Rework the evo_learner constructor to expose num_solutions,
num_parents_mating and fun_num directly and to filter exclude_method
transformations out of the augmentation space inside the learner. Clamp
generated magnitudes to the top valid bin, make the imports match the
MetaAugment package name, and create policy_dict entries for
transformation pairs that have not been seen before.
---
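Notes (this section is ignored when the patch is applied): a minimal
usage sketch of the reworked constructor. The MNIST loaders and the
cn.LeNet child network are illustrative assumptions, not part of this
patch; set_up_instance's signature is taken from the code below.

    import torchvision.datasets as datasets
    import torchvision.transforms as T
    import MetaAugment.child_networks as cn
    from MetaAugment.autoaugment_learners.evo_learner import evo_learner

    train = datasets.MNIST('./data', train=True, download=True,
                           transform=T.ToTensor())
    test = datasets.MNIST('./data', train=False, download=True,
                          transform=T.ToTensor())

    learner = evo_learner(sp_num=1, num_solutions=5, num_parents_mating=3,
                          p_bins=1, m_bins=1, fun_num=14, toy_size=0.1)
    # cn.LeNet is a hypothetical child network architecture
    learner.set_up_instance(train, test, cn.LeNet)
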
 .../autoaugment_learners/evo_learner.py       | 64 ++++++++++---------
 1 file changed, 34 insertions(+), 30 deletions(-)

diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py
index 34cc2d44..a1f21a1a 100644
--- a/MetaAugment/autoaugment_learners/evo_learner.py
+++ b/MetaAugment/autoaugment_learners/evo_learner.py
@@ -5,43 +5,40 @@ import pygad
 import pygad.torchga as torchga
 import copy
 import torch
+from MetaAugment.controller_networks.evo_controller import evo_controller
 
 from MetaAugment.autoaugment_learners.aa_learner import aa_learner
+import MetaAugment.child_networks as cn
 
 
 class evo_learner(aa_learner):
 
     def __init__(self, 
-                # search space settings
-                sp_num=5,
-                p_bins=10, 
-                m_bins=10, 
-                discrete_p_m=False,
-                exclude_method=[],
-                # child network settings
-                learning_rate=1e-1, 
+                sp_num=1,
+                num_solutions=5,
+                num_parents_mating=3,
+                learning_rate=1e-1,
                 max_epochs=float('inf'),
                 early_stop_num=20,
+                p_bins=1,
+                m_bins=1,
                 batch_size=8,
-                toy_size=1,
-                # evolutionary learner specific settings
-                num_solutions=5,
-                num_parents_mating=3,
-                controller=None
+                toy_size=0.1,
+                fun_num=14,
+                exclude_method=[],
+                controller=None
                 ):
 
-        super().__init__(
-                    sp_num=sp_num, 
-                    p_bins=p_bins, 
-                    m_bins=m_bins, 
-                    discrete_p_m=discrete_p_m, 
-                    batch_size=batch_size, 
-                    toy_size=toy_size, 
-                    learning_rate=learning_rate,
-                    max_epochs=max_epochs,
-                    early_stop_num=early_stop_num,
-                    exclude_method=exclude_method
-                    )
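+        # NOTE: positional arguments assume aa_learner's signature begins (sp_num, fun_num, p_bins, m_bins, ...)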
+        super().__init__(sp_num,
+            fun_num,
+            p_bins,
+            m_bins,
+            batch_size=batch_size,
+            toy_size=toy_size,
+            learning_rate=learning_rate,
+            max_epochs=max_epochs,
+            early_stop_num=early_stop_num)
 
         self.num_solutions = num_solutions
         self.controller = controller
@@ -51,6 +48,9 @@ class evo_learner(aa_learner):
         self.p_bins = p_bins 
         self.sub_num_pol = sp_num
         self.m_bins = m_bins
+        self.fun_num = fun_num
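+        # keep only the transformations whose names are not listed in exclude_method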
+        self.augmentation_space = [x for x in self.augmentation_space if x[0] not in exclude_method]
         self.policy_dict = {}
         self.policy_result = []
 
@@ -167,10 +167,11 @@
                 prob2 += torch.sigmoid(y[idx, section+self.fun_num]).item()
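+                # scale each sigmoid output to (0, 10) and cap it at 9, the largest valid magnitude bin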
                 if mag1 is not None:
                     # mag1 += min(max(0, (y[idx, self.auto_aug_agent.fun_num+1]).item()), 8)
-                    mag1 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
+                    mag1 += min(9, 10 * torch.sigmoid(y[idx, self.fun_num+1]).item())
                 if mag2 is not None:
                     # mag2 += min(max(0, y[idx, section+self.auto_aug_agent.fun_num+1].item()), 8)
-                    mag2 += 10 * torch.sigmoid(y[idx, self.fun_num+1]).item()
+                    mag2 += min(9, 10 * torch.sigmoid(y[idx, section+self.fun_num+1]).item())
 
                 counter += 1
 
@@ -240,7 +241,9 @@
                 self.policy_dict[trans1][trans2].append(new_set)
                 return False 
             else:
                 self.policy_dict[trans1][trans2] = [new_set]
+        else:
+            self.policy_dict[trans1] = {trans2: [new_set]}
         if trans2 in self.policy_dict:
             if trans1 in self.policy_dict[trans2]:
                 for test_pol in self.policy_dict[trans2][trans1]:
@@ -249,7 +252,9 @@
                 self.policy_dict[trans2][trans1].append(new_set)
                 return False 
             else:
                 self.policy_dict[trans2][trans1] = [new_set]
+        else:
+            self.policy_dict[trans2] = {trans1: [new_set]}
 
 
     def set_up_instance(self, train_dataset, test_dataset, child_network_architecture):
@@ -298,11 +303,11 @@
             if len(self.policy_result) > self.sp_num:
                 self.policy_result = sorted(self.policy_result, key=lambda x: x[1], reverse=True)
                 self.policy_result = self.policy_result[:self.sp_num]
-                print("Appended policy: ", self.policy_result)
+                print("appended policy: ", self.policy_result)
 
 
             if fit_val > self.history_best[self.gen_count]:
-                print("Best policy: ", full_policy)
+                print("best policy: ", full_policy)
                 self.history_best[self.gen_count] = fit_val 
                 self.best_model = model_weights_dict
             
@@ -335,4 +340,3 @@ class evo_learner(aa_learner):
             mutation_percent_genes = 0.1,
             fitness_func=fitness_func,
             on_generation = on_generation)
-
-- 
GitLab