Skip to content
Snippets Groups Projects
Commit 71673f0c authored by Max Ramsay King's avatar Max Ramsay King
Browse files

doc strings

parent 1ca017c3
No related branches found
No related tags found
No related merge requests found
...@@ -76,7 +76,7 @@ def run_benchmark( ...@@ -76,7 +76,7 @@ def run_benchmark(
train_dataset=train_dataset, train_dataset=train_dataset,
test_dataset=test_dataset, test_dataset=test_dataset,
child_network_architecture=child_network_architecture, child_network_architecture=child_network_architecture,
iterations=1 iterations=5
) )
# save agent every iteration # save agent every iteration
with open(save_file, 'wb+') as f: with open(save_file, 'wb+') as f:
......
...@@ -83,6 +83,13 @@ class Genetic_learner(AaLearner): ...@@ -83,6 +83,13 @@ class Genetic_learner(AaLearner):
def gen_random_subpol(self): def gen_random_subpol(self):
"""
Generates a random subpolicy using the reduced augmentation_space
Returns
--------
subpolicy -> ((transformation, probability, magnitude), (trans., prob., mag.))
"""
choose_items = [x[0] for x in self.augmentation_space] choose_items = [x[0] for x in self.augmentation_space]
trans1 = str(random.choice(choose_items)) trans1 = str(random.choice(choose_items))
trans2 = str(random.choice(choose_items)) trans2 = str(random.choice(choose_items))
...@@ -104,30 +111,50 @@ class Genetic_learner(AaLearner): ...@@ -104,30 +111,50 @@ class Genetic_learner(AaLearner):
def gen_random_policy(self): def gen_random_policy(self):
"""
Generates a random policy, consisting of sp_num subpolicies
Returns
------------
policy -> [subpolicy, subpolicy, ...]
"""
pol = [] pol = []
for _ in range(self.sp_num): for _ in range(self.sp_num):
pol.append(self.gen_random_subpol()) pol.append(self.gen_random_subpol())
return pol return pol
def bin_to_subpol(self, subpol): def bin_to_subpol(self, subpol_bin):
"""
Converts a binary string representation of a subpolicy to a subpolicy
Parameters
------------
subpol_bin -> str
Binary representation of a subpolicy
Returns
-----------
policy -> [(subpolicy)]
"""
pol = [] pol = []
for idx in range(2): for idx in range(2):
if subpol[idx*12:(idx*12)+4] in self.bin_to_aug: if subpol_bin[idx*12:(idx*12)+4] in self.bin_to_aug:
trans = self.bin_to_aug[subpol[idx*12:(idx*12)+4]] trans = self.bin_to_aug[subpol_bin[idx*12:(idx*12)+4]]
else: else:
trans = random.choice(self.just_augs) trans = random.choice(self.just_augs)
mag_is_none = not self.aug_space_dict[trans] mag_is_none = not self.aug_space_dict[trans]
if subpol[(idx*12)+4: (idx*12)+8] in self.bin_to_prob: if subpol_bin[(idx*12)+4: (idx*12)+8] in self.bin_to_prob:
prob = float(self.bin_to_prob[subpol[(idx*12)+4: (idx*12)+8]]) prob = float(self.bin_to_prob[subpol_bin[(idx*12)+4: (idx*12)+8]])
else: else:
prob = float(random.randrange(0, 11, 1) / 10) prob = float(random.randrange(0, 11, 1) / 10)
if subpol[(idx*12)+8:(idx*12)+12] in self.bin_to_mag: if subpol_bin[(idx*12)+8:(idx*12)+12] in self.bin_to_mag:
mag = int(self.bin_to_mag[subpol[(idx*12)+8:(idx*12)+12]]) mag = int(self.bin_to_mag[subpol_bin[(idx*12)+8:(idx*12)+12]])
else: else:
mag = int(random.randrange(0, 10, 1)) mag = int(random.randrange(0, 10, 1))
...@@ -139,28 +166,54 @@ class Genetic_learner(AaLearner): ...@@ -139,28 +166,54 @@ class Genetic_learner(AaLearner):
def subpol_to_bin(self, subpol): def subpol_to_bin(self, subpol):
pol = '' """
Converts a subpolicy to its binary representation
Parameters
------------
subpol -> ((transformation, probability, magnitude), (trans., prob., mag.))
Returns
------------
bin_pol -> str
Binary representation of the subpolicy
"""
bin_pol = ''
trans1, prob1, mag1 = subpol[0] trans1, prob1, mag1 = subpol[0]
trans2, prob2, mag2 = subpol[1] trans2, prob2, mag2 = subpol[1]
pol += self.aug_to_bin[trans1] + self.prob_to_bin[str(prob1)] bin_pol += self.aug_to_bin[trans1] + self.prob_to_bin[str(prob1)]
if mag1 == None: if mag1 == None:
pol += '1111' bin_pol += '1111'
else: else:
pol += self.mag_to_bin[str(mag1)] bin_pol += self.mag_to_bin[str(mag1)]
pol += self.aug_to_bin[trans2] + self.prob_to_bin[str(prob2)] bin_pol += self.aug_to_bin[trans2] + self.prob_to_bin[str(prob2)]
if mag2 == None: if mag2 == None:
pol += '1111' bin_pol += '1111'
else: else:
pol += self.mag_to_bin[str(mag2)] bin_pol += self.mag_to_bin[str(mag2)]
return pol return bin_pol
def choose_parents(self, parents, parents_weights): def choose_parents(self, parents, parents_weights):
"""
Chooses the parents from which the next policy will be generated
Parameters
------------
parents -> [policy, policy, ...]
parents_weights -> [float, float, ...]
Returns
------------
(parent1, parent2) -> (policy, policy)
"""
parent1 = random.choices(parents, parents_weights, k=1)[0][0] parent1 = random.choices(parents, parents_weights, k=1)[0][0]
parent2 = random.choices(parents, parents_weights, k=1)[0][0] parent2 = random.choices(parents, parents_weights, k=1)[0][0]
while parent2 == parent1: while parent2 == parent1:
...@@ -171,6 +224,13 @@ class Genetic_learner(AaLearner): ...@@ -171,6 +224,13 @@ class Genetic_learner(AaLearner):
def generate_children(self): def generate_children(self):
"""
Generates children via the random crossover method
Returns
------------
new_pols -> [child_policy, child_policy, ...]
"""
parent_acc = sorted(self.history, key = lambda x: x[1], reverse=True) parent_acc = sorted(self.history, key = lambda x: x[1], reverse=True)
parents = [x[0] for x in parent_acc] parents = [x[0] for x in parent_acc]
parents_weights = [x[1] for x in parent_acc] parents_weights = [x[1] for x in parent_acc]
...@@ -188,6 +248,20 @@ class Genetic_learner(AaLearner): ...@@ -188,6 +248,20 @@ class Genetic_learner(AaLearner):
def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 100): def learn(self, train_dataset, test_dataset, child_network_architecture, iterations = 100):
"""
Generates policies through a genetic algorithm.
Parameters
------------
train_dataset -> torchvision.dataset
test_dataset -> torchvision.dataset
child_network_architecture ->
iterations -> int
number of iterations to run the instance for
"""
for idx in range(iterations): for idx in range(iterations):
print("ITERATION: ", idx) print("ITERATION: ", idx)
......
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment