Skip to content
Snippets Groups Projects
Commit e1828d43 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Merge branch 'master' of gitlab.doc.ic.ac.uk:yw21218/metarl

parents 7bffb088 c361532f
No related branches found
No related tags found
No related merge requests found
Pipeline #272640 passed
......@@ -171,38 +171,36 @@ class Genetic_learner(AaLearner):
def generate_children(self):
parent_acc = sorted(self.history, key = lambda x: x[1], reverse=True)[:self.sp_num]
parent_acc = sorted(self.history, key = lambda x: x[1], reverse=True)
parents = [x[0] for x in parent_acc]
parents_weights = [x[1] for x in parent_acc]
new_pols = []
for _ in range(self.num_offspring):
parent1, parent2 = self.choose_parents(parents, parents_weights)
cross_over = random.randrange(1, len(parent2), 1)
cross_over = random.randrange(1, int(len(parent2)/2), 1)
cross_over2 = random.randrange(int(len(parent2)/2), int(len(parent2)), 1)
child = parent1[:cross_over]
child += parent2[cross_over:]
child += parent2[cross_over:int(len(parent2)/2)]
child += parent1[int(len(parent2)/2):int(len(parent2)/2)+cross_over2]
child += parent2[int(len(parent2)/2)+cross_over2:]
new_pols.append(child)
return new_pols
def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 10):
def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 100):
for idx in range(iterations):
print("iteration: ", idx)
if len(self.history) < self.sp_num:
print("ITERATION: ", idx)
if len(self.history) < self.num_offspring:
policy = [self.gen_random_subpol()]
else:
policy = self.bin_to_subpol(random.choice(self.generate_children()))
print("policyu: ", policy)
reward = self._test_autoaugment_policy(policy,
child_network_architecture,
train_dataset,
test_dataset)
print("reward: ", reward)
print("new len hsitory: ", len(self.history))
print("hsitory: ", self.history)
print("reward: ", reward)
......
......@@ -6,7 +6,7 @@ import torchvision
import autoaug.child_networks as cn
from autoaug.autoaugment_learners.AaLearner import AaLearner
from autoaug.autoaugment_learners.gen_learner import Genetic_learner
from autoaug.autoaugment_learners.GenLearner import Genetic_learner
import random
......@@ -29,14 +29,15 @@ agent = Genetic_learner(
learning_rate=0.05,
max_epochs=float('inf'),
early_stop_num=10,
num_offspring=10
)
agent.learn(train_dataset,
test_dataset,
child_network_architecture=child_network_architecture,
iterations=10)
iterations=100)
# with open('randomsearch_logs.pkl', 'wb') as file:
# pickle.dump(self.history, file)
print(agent.history)
\ No newline at end of file
# with open('genetic_logs.pkl', 'wb') as file:
# pickle.dump(agent.history, file)
print(sorted(agent.history, key = lambda x: x[1]))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment