Commit 93995f95 authored by Max Ramsay King

stuff gen again

parent 0ff5e107
@@ -171,38 +171,36 @@ class Genetic_learner(AaLearner):
     def generate_children(self):
-        parent_acc = sorted(self.history, key = lambda x: x[1], reverse=True)[:self.sp_num]
+        parent_acc = sorted(self.history, key = lambda x: x[1], reverse=True)
         parents = [x[0] for x in parent_acc]
         parents_weights = [x[1] for x in parent_acc]
         new_pols = []
         for _ in range(self.num_offspring):
             parent1, parent2 = self.choose_parents(parents, parents_weights)
-            cross_over = random.randrange(1, len(parent2), 1)
+            cross_over = random.randrange(1, int(len(parent2)/2), 1)
+            cross_over2 = random.randrange(int(len(parent2)/2), int(len(parent2)), 1)
             child = parent1[:cross_over]
-            child += parent2[cross_over:]
+            child += parent2[cross_over:int(len(parent2)/2)]
+            child += parent1[int(len(parent2)/2):int(len(parent2)/2)+cross_over2]
+            child += parent2[int(len(parent2)/2)+cross_over2:]
             new_pols.append(child)
         return new_pols
 
-    def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 10):
+    def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 100):
         for idx in range(iterations):
-            print("iteration: ", idx)
-            if len(self.history) < self.sp_num:
+            print("ITERATION: ", idx)
+            if len(self.history) < self.num_offspring:
                 policy = [self.gen_random_subpol()]
             else:
                 policy = self.bin_to_subpol(random.choice(self.generate_children()))
-            print("policyu: ", policy)
             reward = self._test_autoaugment_policy(policy,
                                                    child_network_architecture,
                                                    train_dataset,
                                                    test_dataset)
-            print("reward: ", reward)
-            print("new len hsitory: ", len(self.history))
-            print("hsitory: ", self.history)
+            print("reward: ", reward)
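For readers skimming the hunk above: the new generate_children draws two cut points, one from each half of the binary sub-policy encoding, and splices the parents together around them, i.e. a two-point crossover instead of the previous single-point one. The sketch below illustrates that recombination step in isolation. It is a simplified stand-in, not the repository's code: the list-of-bits encoding and the name two_point_crossover are assumptions, the second draw is treated as an index rather than an offset, and the real learner additionally decodes children with bin_to_subpol and picks parents with choose_parents.

import random

def two_point_crossover(parent1, parent2):
    # Mirror the spirit of generate_children: one cut in the first half of
    # the encoding, one cut in the second half, then alternate parent segments.
    half = len(parent2) // 2
    cut1 = random.randrange(1, half)             # cut point inside the first half
    cut2 = random.randrange(half, len(parent2))  # cut point inside the second half
    child = parent1[:cut1] + parent2[cut1:half] + parent1[half:cut2] + parent2[cut2:]
    return child

# Toy demonstration with two fixed-length "sub-policy" bit strings.
random.seed(0)
p1, p2 = [0] * 12, [1] * 12
child = two_point_crossover(p1, p2)
print(child, len(child))  # the child keeps the parents' length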
@@ -6,7 +6,7 @@ import torchvision
 import MetaAugment.child_networks as cn
 from MetaAugment.autoaugment_learners.AaLearner import AaLearner
-from MetaAugment.autoaugment_learners.gen_learner import Genetic_learner
+from MetaAugment.autoaugment_learners.GenLearner import Genetic_learner
 import random
@@ -29,14 +29,15 @@ agent = Genetic_learner(
                         learning_rate=0.05,
                         max_epochs=float('inf'),
                         early_stop_num=10,
+                        num_offspring=10
                         )
 agent.learn(train_dataset,
             test_dataset,
             child_network_architecture=child_network_architecture,
-            iterations=10)
+            iterations=100)
-# with open('randomsearch_logs.pkl', 'wb') as file:
-#     pickle.dump(self.history, file)
-print(agent.history)
\ No newline at end of file
+# with open('genetic_logs.pkl', 'wb') as file:
+#     pickle.dump(agent.history, file)
+print(sorted(agent.history, key = lambda x: x[1]))
\ No newline at end of file
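The commented-out lines in this hunk hint at persisting agent.history with pickle under a genetic_logs.pkl filename. A minimal sketch of that logging step follows, under the assumption (inferred from the sort key x[1] in the final print) that history holds (sub-policy, accuracy) tuples; the placeholder entries here are illustrative, not real results.

import pickle

# Stand-in for agent.history; the (sub-policy, accuracy) tuple layout is an
# assumption inferred from the sort key used in the script, not real output.
history = [("subpolicy_A", 0.61), ("subpolicy_B", 0.74), ("subpolicy_C", 0.69)]

# Save the log the way the commented-out lines suggest.
with open('genetic_logs.pkl', 'wb') as file:
    pickle.dump(history, file)

# Reload and print the best entries, mirroring the script's final print.
with open('genetic_logs.pkl', 'rb') as file:
    reloaded = pickle.load(file)
print(sorted(reloaded, key=lambda x: x[1], reverse=True))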