Skip to content
Snippets Groups Projects
Commit 0db87838 authored by Mia Wang's avatar Mia Wang
Browse files

Merge branch 'master' of gitlab.doc.ic.ac.uk:yw21218/metarl

parents b6beb405 5d89c89e
No related branches found
No related tags found
No related merge requests found
Pipeline #272853 failed
...@@ -76,7 +76,7 @@ def run_benchmark( ...@@ -76,7 +76,7 @@ def run_benchmark(
train_dataset=train_dataset, train_dataset=train_dataset,
test_dataset=test_dataset, test_dataset=test_dataset,
child_network_architecture=child_network_architecture, child_network_architecture=child_network_architecture,
iterations=1 iterations=5
) )
# save agent every iteration # save agent every iteration
with open(save_file, 'wb+') as f: with open(save_file, 'wb+') as f:
......
...@@ -83,6 +83,13 @@ class Genetic_learner(AaLearner): ...@@ -83,6 +83,13 @@ class Genetic_learner(AaLearner):
def gen_random_subpol(self): def gen_random_subpol(self):
"""
Generates a random subpolicy using the reduced augmentation_space
Returns
--------
subpolicy -> ((transformation, probability, magnitude), (trans., prob., mag.))
"""
choose_items = [x[0] for x in self.augmentation_space] choose_items = [x[0] for x in self.augmentation_space]
trans1 = str(random.choice(choose_items)) trans1 = str(random.choice(choose_items))
trans2 = str(random.choice(choose_items)) trans2 = str(random.choice(choose_items))
...@@ -104,30 +111,50 @@ class Genetic_learner(AaLearner): ...@@ -104,30 +111,50 @@ class Genetic_learner(AaLearner):
def gen_random_policy(self): def gen_random_policy(self):
"""
Generates a random policy, consisting of sp_num subpolicies
Returns
------------
policy -> [subpolicy, subpolicy, ...]
"""
pol = [] pol = []
for _ in range(self.sp_num): for _ in range(self.sp_num):
pol.append(self.gen_random_subpol()) pol.append(self.gen_random_subpol())
return pol return pol
def bin_to_subpol(self, subpol): def bin_to_subpol(self, subpol_bin):
"""
Converts a binary string representation of a subpolicy to a subpolicy
Parameters
------------
subpol_bin -> str
Binary representation of a subpolicy
Returns
-----------
policy -> [(subpolicy)]
"""
pol = [] pol = []
for idx in range(2): for idx in range(2):
if subpol[idx*12:(idx*12)+4] in self.bin_to_aug: if subpol_bin[idx*12:(idx*12)+4] in self.bin_to_aug:
trans = self.bin_to_aug[subpol[idx*12:(idx*12)+4]] trans = self.bin_to_aug[subpol_bin[idx*12:(idx*12)+4]]
else: else:
trans = random.choice(self.just_augs) trans = random.choice(self.just_augs)
mag_is_none = not self.aug_space_dict[trans] mag_is_none = not self.aug_space_dict[trans]
if subpol[(idx*12)+4: (idx*12)+8] in self.bin_to_prob: if subpol_bin[(idx*12)+4: (idx*12)+8] in self.bin_to_prob:
prob = float(self.bin_to_prob[subpol[(idx*12)+4: (idx*12)+8]]) prob = float(self.bin_to_prob[subpol_bin[(idx*12)+4: (idx*12)+8]])
else: else:
prob = float(random.randrange(0, 11, 1) / 10) prob = float(random.randrange(0, 11, 1) / 10)
if subpol[(idx*12)+8:(idx*12)+12] in self.bin_to_mag: if subpol_bin[(idx*12)+8:(idx*12)+12] in self.bin_to_mag:
mag = int(self.bin_to_mag[subpol[(idx*12)+8:(idx*12)+12]]) mag = int(self.bin_to_mag[subpol_bin[(idx*12)+8:(idx*12)+12]])
else: else:
mag = int(random.randrange(0, 10, 1)) mag = int(random.randrange(0, 10, 1))
...@@ -139,28 +166,54 @@ class Genetic_learner(AaLearner): ...@@ -139,28 +166,54 @@ class Genetic_learner(AaLearner):
def subpol_to_bin(self, subpol): def subpol_to_bin(self, subpol):
pol = '' """
Converts a subpolicy to its binary representation
Parameters
------------
subpol -> ((transforamtion, probability, magnitude), (trans., prob., mag.))
Returns
------------
bin_pol -> str
Binary representation of the subpolicy
"""
bin_pol = ''
trans1, prob1, mag1 = subpol[0] trans1, prob1, mag1 = subpol[0]
trans2, prob2, mag2 = subpol[1] trans2, prob2, mag2 = subpol[1]
pol += self.aug_to_bin[trans1] + self.prob_to_bin[str(prob1)] bin_pol += self.aug_to_bin[trans1] + self.prob_to_bin[str(prob1)]
if mag1 == None: if mag1 == None:
pol += '1111' bin_pol += '1111'
else: else:
pol += self.mag_to_bin[str(mag1)] bin_pol += self.mag_to_bin[str(mag1)]
pol += self.aug_to_bin[trans2] + self.prob_to_bin[str(prob2)] bin_pol += self.aug_to_bin[trans2] + self.prob_to_bin[str(prob2)]
if mag2 == None: if mag2 == None:
pol += '1111' bin_pol += '1111'
else: else:
pol += self.mag_to_bin[str(mag2)] bin_pol += self.mag_to_bin[str(mag2)]
return pol return bin_pol
def choose_parents(self, parents, parents_weights): def choose_parents(self, parents, parents_weights):
"""
Chooses parents from which the next policy will be generated from
Parameters
------------
parents -> [policy, policy, ...]
parents_weights -> [float, float, ...]
Returns
------------
(parent1, parent2) -> (policy, policy)
"""
parent1 = random.choices(parents, parents_weights, k=1)[0][0] parent1 = random.choices(parents, parents_weights, k=1)[0][0]
parent2 = random.choices(parents, parents_weights, k=1)[0][0] parent2 = random.choices(parents, parents_weights, k=1)[0][0]
while parent2 == parent1: while parent2 == parent1:
...@@ -171,6 +224,13 @@ class Genetic_learner(AaLearner): ...@@ -171,6 +224,13 @@ class Genetic_learner(AaLearner):
def generate_children(self): def generate_children(self):
"""
Generates children via the random crossover method
Returns
------------
new_pols -> [child_policy, child_policy, ...]
"""
parent_acc = sorted(self.history, key = lambda x: x[1], reverse=True) parent_acc = sorted(self.history, key = lambda x: x[1], reverse=True)
parents = [x[0] for x in parent_acc] parents = [x[0] for x in parent_acc]
parents_weights = [x[1] for x in parent_acc] parents_weights = [x[1] for x in parent_acc]
...@@ -188,6 +248,20 @@ class Genetic_learner(AaLearner): ...@@ -188,6 +248,20 @@ class Genetic_learner(AaLearner):
def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 100): def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 100):
"""
Generates policies through a genetic algorithm.
Parameters
------------
train_dataset -> torchvision.dataset
test_dataset -> torchvision.dataset
child_network_architecture ->
iterations -> int
number of iterations to run the instance for
"""
for idx in range(iterations): for idx in range(iterations):
print("ITERATION: ", idx) print("ITERATION: ", idx)
......
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment