From 563fec380efa544bf4e4c7cc9e039479e243f595 Mon Sep 17 00:00:00 2001 From: Sun Jin Kim <sk2521@ic.ac.uk> Date: Wed, 27 Apr 2022 16:05:29 +0100 Subject: [PATCH] write ucb_learner.get_mega_policy() and its test --- .../autoaugment_learners/ucb_learner.py | 26 +++++++++++++++++++ test/MetaAugment/test_ucb_learner.py | 3 +++ 2 files changed, 29 insertions(+) diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py index 6ed010fc..053492b8 100644 --- a/MetaAugment/autoaugment_learners/ucb_learner.py +++ b/MetaAugment/autoaugment_learners/ucb_learner.py @@ -147,7 +147,33 @@ class ucb_learner(randomsearch_learner): print(self.cnts) + def get_mega_policy(self, number_policies): + """ + Produces a mega policy, based on the n best subpolicies (evo learner)/policies + (other learners) + + + Args: + number_policies -> int: Number of (sub)policies to be included in the mega + policy + + Returns: + megapolicy -> [subpolicy, subpolicy, ...] + """ + + temp_avg_accs = [x if x is not None else 0 for x in self.avg_accs] + + temp_history = list(zip(self.policies, temp_avg_accs)) + + number_policies = max(number_policies, len(temp_history)) + + inter_pol = sorted(temp_history, key=lambda x: x[1], reverse = True)[:number_policies] + + megapol = [] + for pol in inter_pol: + megapol += pol[0] + return megapol diff --git a/test/MetaAugment/test_ucb_learner.py b/test/MetaAugment/test_ucb_learner.py index 3f37f3e5..fc2807aa 100644 --- a/test/MetaAugment/test_ucb_learner.py +++ b/test/MetaAugment/test_ucb_learner.py @@ -52,5 +52,8 @@ def test_ucb_learner(): iterations=7 ) + print(learner.get_mega_policy(number_policies=50)) + print(learner.get_mega_policy(number_policies=3)) + if __name__=="__main__": test_ucb_learner() -- GitLab