diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py index 04c8ffcfe871f3fdc5c570c522e92e81e52228aa..255a203f477a9fb7549f5a4e0ba7ff0bbaa65f42 100644 --- a/MetaAugment/CP2_Max.py +++ b/MetaAugment/CP2_Max.py @@ -215,6 +215,7 @@ class Evolutionary_learner(): for _ in range(2): idx_ret = torch.argmax(y[:, (pol * section):(pol*section) + self.fun_num].mean(dim = 0)) + trans, need_mag = self.augmentation_space[idx_ret] p_ret = 0.1 * torch.argmax(y[:, (pol * section)+self.fun_num:(pol*section)+self.fun_num+self.p_bins].mean(dim = 0)) @@ -256,7 +257,7 @@ class Evolutionary_learner(): weights_vector=solution) self.meta_rl_agent.load_state_dict(model_weights_dict) for idx, (test_x, label_x) in enumerate(train_loader): - full_policy = self.meta_rl_agent.get_full_policy(test_x) + full_policy = self.get_full_policy(test_x) cop_mod = self.new_model() fit_val = train_model(full_policy, cop_mod) cop_mod = 0 @@ -279,7 +280,7 @@ class Evolutionary_learner(): meta_rl_agent = Learner() -ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet()) +ev_learner = Evolutionary_learner(meta_rl_agent, train_loader=train_loader, sec_model=LeNet(), augmentation_space=augmentation_space) ev_learner.run_instance()