From f5a10e7ecc2e4ff39de96cc171f94f81ad670a93 Mon Sep 17 00:00:00 2001 From: Max Ramsay King <maxramsayking@gmail.com> Date: Mon, 4 Apr 2022 13:01:12 -0700 Subject: [PATCH] fixed bug --- MetaAugment/CP2_Max.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py index 04c8ffcf..255a203f 100644 --- a/MetaAugment/CP2_Max.py +++ b/MetaAugment/CP2_Max.py @@ -215,6 +215,7 @@ class Evolutionary_learner(): for _ in range(2): idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0)) + trans, need_mag = self.augmentation_space[idx_ret] p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)) @@ -256,7 +257,7 @@ class Evolutionary_learner(): weights_vector=solution) self.meta_rl_agent.load_state_dict(model_weights_dict) for idx, (test_x, label_x) in enumerate(train_loader): - full_policy = self.meta_rl_agent.get_full_policy(test_x) + full_policy = self.get_full_policy(test_x) cop_mod = self.new_model() fit_val = train_model(full_policy, cop_mod) cop_mod = 0 @@ -279,7 +280,7 @@ class Evolutionary_learner(): meta_rl_agent = Learner() -ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet()) +ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet(), augmentation_space=augmentation_space) ev_learner.run_instance() -- GitLab