diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py index fe7a60e5de409cf04d48528ec0a5b3e2ec33da27..92091124a83dc08baa68c08a594a68e04dc2e47e 100644 --- a/MetaAugment/autoaugment_learners/evo_learner.py +++ b/MetaAugment/autoaugment_learners/evo_learner.py @@ -221,22 +221,21 @@ class evo_learner(aa_learner): self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size) for idx, (test_x, label_x) in enumerate(self.train_loader): - full_policy = self._get_single_policy_cov(test_x) + sub_pol = self._get_single_policy_cov(test_x) - while self._in_pol_dict(full_policy): - full_policy = self._get_single_policy_cov(test_x)[0] + while self._in_pol_dict(sub_pol): + sub_pol = self._get_single_policy_cov(test_x)[0] - fit_val = self._test_autoaugment_policy(full_policy,child_network_architecture,train_dataset,test_dataset) + fit_val = self._test_autoaugment_policy(sub_pol,child_network_architecture,train_dataset,test_dataset) - self.history.append((full_policy, fit_val)) - self.running_policy.append((full_policy, fit_val)) + self.history.append((sub_pol, fit_val)) + self.running_policy.append((sub_pol, fit_val)) if len(self.running_policy) > self.sp_num: self.running_policy = sorted(self.running_policy, key=lambda x: x[1], reverse=True) self.running_policy = self.running_policy[:self.sp_num] - print("appended policy: ", self.running_policy) if len(self.history_best) == 0: