Skip to content
Snippets Groups Projects
Commit ec024880 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Merge branch 'master' of gitlab.doc.ic.ac.uk:yw21218/metarl

parents e0be4165 3f0557f2
No related branches found
No related tags found
No related merge requests found
File added
<meta HTTP-EQUIV="REFRESH" content="0; url=http://www.cs.toronto.edu/~kriz/cifar.html">
File added
<meta HTTP-EQUIV="REFRESH" content="0; url=http://www.cs.toronto.edu/~kriz/cifar.html">
This diff is collapsed.
......@@ -29,7 +29,6 @@ class evo_learner():
batch_size=8,
toy_flag=False,
toy_size=0.1,
sub_num_pol=5,
fun_num = 14,
exclude_method=[],
):
......@@ -46,15 +45,15 @@ class evo_learner():
max_epochs=max_epochs,
early_stop_num=early_stop_num,)
self.auto_aug_agent = Evo_learner(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sub_num_pol)
self.num_solutions = num_solutions
self.auto_aug_agent = Evo_learner(fun_num=fun_num, p_bins=p_bins, m_bins=m_bins, sub_num_pol=sp_num)
self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions)
self.num_parents_mating = num_parents_mating
self.initial_population = self.torch_ga.population_weights
self.train_loader = train_loader
self.child_network = child_network
self.p_bins = p_bins
self.sub_num_pol = sub_num_pol
self.sub_num_pol = sp_num
self.m_bins = m_bins
self.fun_num = fun_num
self.augmentation_space = [x for x in augmentation_space if x[0] not in exclude_method]
......@@ -121,15 +120,15 @@ class evo_learner():
"""
section = self.auto_aug_agent.fun_num + self.auto_aug_agent.p_bins + self.auto_aug_agent.m_bins
y = self.auto_aug_agent.forward(x) # 1000 x 32
y = self.auto_aug_agent.forward(x)
y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1) # 1000 x 14
y_1 = torch.softmax(y[:,:self.auto_aug_agent.fun_num], dim = 1)
y[:,:self.auto_aug_agent.fun_num] = y_1
y_2 = torch.softmax(y[:,section:section+self.auto_aug_agent.fun_num], dim = 1)
y[:,section:section+self.auto_aug_agent.fun_num] = y_2
concat = torch.cat((y_1, y_2), dim = 1)
cov_mat = torch.cov(concat.T)#[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
cov_mat = torch.cov(concat.T)
cov_mat = cov_mat[:self.auto_aug_agent.fun_num, self.auto_aug_agent.fun_num:]
shape_store = cov_mat.shape
......@@ -197,9 +196,16 @@ class evo_learner():
Solution_idx -> Int
"""
self.num_generations = iterations
self.history_best = [0 for i in range(iterations)]
self.history_avg = [0 for i in range(iterations)]
self.gen_count = 0
self.best_model = 0
self.set_up_instance()
self.ga_instance.run()
self.history_avg = self.history_avg / self.num_solutions
solution, solution_fitness, solution_idx = self.ga_instance.best_solution()
if return_weights:
return torchga.model_weights_as_dict(model=self.auto_aug_agent, weights_vector=solution)
......@@ -207,14 +213,6 @@ class evo_learner():
return solution, solution_fitness, solution_idx
def new_model(self):
"""
Simple function to create a copy of the secondary model (used for classification)
"""
copy_model = copy.deepcopy(self.child_network)
return copy_model
def set_up_instance(self, train_dataset, test_dataset):
"""
Initialises GA instance, as well as fitness and on_generation functions
......@@ -249,9 +247,16 @@ class evo_learner():
full_policy = self.get_full_policy(test_x)
fit_val = ((self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0])/
fit_val = ((self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) /
+ self.test_autoaugment_policy(full_policy, train_dataset, test_dataset)[0]) / 2
if fit_val > self.history_best[self.gen_count]:
self.history_best[self.gen_count] = fit_val
self.best_model = model_weights_dict
self.history_avg[self.gen_count] += fit_val
return fit_val
def on_generation(ga_instance):
......@@ -267,6 +272,7 @@ class evo_learner():
None
"""
print("Generation = {generation}".format(generation=ga_instance.generations_completed))
self.gen_count += 1
print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1]))
return
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment