From f5a10e7ecc2e4ff39de96cc171f94f81ad670a93 Mon Sep 17 00:00:00 2001
From: Max Ramsay King <maxramsayking@gmail.com>
Date: Mon, 4 Apr 2022 13:01:12 -0700
Subject: [PATCH] fixed bug

---
 MetaAugment/CP2_Max.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py
index 04c8ffcf..255a203f 100644
--- a/MetaAugment/CP2_Max.py
+++ b/MetaAugment/CP2_Max.py
@@ -215,6 +215,7 @@ class Evolutionary_learner():
             for _ in range(2):
                 idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0))
 
+
                 trans, need_mag = self.augmentation_space[idx_ret]
 
                 p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0))
@@ -256,7 +257,7 @@ class Evolutionary_learner():
                                                             weights_vector=solution)
             self.meta_rl_agent.load_state_dict(model_weights_dict)
             for idx, (test_x, label_x) in enumerate(train_loader):
-                full_policy = self.meta_rl_agent.get_full_policy(test_x)
+                full_policy = self.get_full_policy(test_x)
             cop_mod = self.new_model()
             fit_val = train_model(full_policy, cop_mod)
             cop_mod = 0
@@ -279,7 +280,7 @@ class Evolutionary_learner():
 
 
 meta_rl_agent = Learner()
-ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet())
+ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet(), augmentation_space=augmentation_space)
 ev_learner.run_instance()
 
 
-- 
GitLab