diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py index b4c2e4be596c9ce9f9b1dec0e7772f58cce72168..92347c767098d033eb0afeecd2f640a1f016e142 100644 --- a/MetaAugment/autoaugment_learners/evo_learner.py +++ b/MetaAugment/autoaugment_learners/evo_learner.py @@ -196,15 +196,15 @@ class evo_learner(): Solution_idx -> Int """ self.num_generations = iterations - self.running_best = [0 for i in range(iterations)] - self.running_avg = [0 for i in range(iterations)] + self.history_best = [0 for i in range(iterations)] + self.history_avg = [0 for i in range(iterations)] self.gen_count = 0 self.best_model = 0 self.set_up_instance() self.ga_instance.run() - self.running_avg = self.running_avg / self.num_solutions + self.history_avg = self.history_avg / self.num_solutions solution, solution_fitness, solution_idx = self.ga_instance.best_solution() if return_weights: @@ -250,10 +250,11 @@ class evo_learner(): fit_val = ((self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / + self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / 2 - if fit_val > self.running_best[self.gen_count]: - self.running_best[self.gen_count] = fit_val + if fit_val > self.history_best[self.gen_count]: + self.history_best[self.gen_count] = fit_val + self.best_model = model_weights_dict - self.running_avg[self.gen_count] += fit_val + self.history_avg[self.gen_count] += fit_val return fit_val