diff --git a/MetaAugment/autoaugment_learners/evo_learner.py b/MetaAugment/autoaugment_learners/evo_learner.py index 9c249c1b5a599699a23983e9b4f8b72c0931c1ca..62ad82b86a9f2b30f0003d8a198370bb1df82354 100644 --- a/MetaAugment/autoaugment_learners/evo_learner.py +++ b/MetaAugment/autoaugment_learners/evo_learner.py @@ -17,7 +17,6 @@ class Evolutionary_learner(): def __init__(self, sp_num=1, num_solutions = 10, - num_generations = 5, num_parents_mating = 5, learning_rate = 1e-1, max_epochs=float('inf'), @@ -50,7 +49,6 @@ class Evolutionary_learner(): self.auto_aug_agent = Evo_learner(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sub_num_pol) self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions) - self.num_generations = num_generations self.num_parents_mating = num_parents_mating self.initial_population = self.torch_ga.population_weights self.train_loader = train_loader @@ -65,7 +63,6 @@ class Evolutionary_learner(): assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!' - self.set_up_instance() def get_full_policy(self, x): @@ -178,7 +175,7 @@ class Evolutionary_learner(): return [(self.augmentation_space[max_idx[0]][0], prob1, mag1), (self.augmentation_space[max_idx[1]][0], prob2, mag2)] - def run_instance(self, return_weights = False): + def learn(self, iterations = 15, return_weights = False): """ Runs the GA instance and returns the model weights as a dictionary @@ -199,6 +196,9 @@ class Evolutionary_learner(): Solution_idx -> Int """ + self.num_generations = iterations + self.set_up_instance() + self.ga_instance.run() solution, solution_fitness, solution_idx = self.ga_instance.best_solution() if return_weights: