From 563fec380efa544bf4e4c7cc9e039479e243f595 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Wed, 27 Apr 2022 16:05:29 +0100
Subject: [PATCH] write ucb_learner.get_mega_policy() and its test

---
 .../autoaugment_learners/ucb_learner.py       | 26 +++++++++++++++++++
 test/MetaAugment/test_ucb_learner.py          |  3 +++
 2 files changed, 29 insertions(+)

diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py
index 6ed010fc..053492b8 100644
--- a/MetaAugment/autoaugment_learners/ucb_learner.py
+++ b/MetaAugment/autoaugment_learners/ucb_learner.py
@@ -147,7 +147,33 @@ class ucb_learner(randomsearch_learner):
             print(self.cnts)
 
             
+    def get_mega_policy(self, number_policies):
+        """
+        Produces a mega policy, based on the n best subpolicies (evo learner)/policies
+        (other learners)
+
+        Policies whose average accuracy is still None are ranked as if they scored 0.
+
+        Args: 
+            number_policies -> int: Number of (sub)policies to be included in the mega
+            policy; capped at the number of policies recorded so far
+
+        Returns:
+            megapolicy -> [subpolicy, subpolicy, ...]
+        """
+
+        temp_avg_accs = [x if x is not None else 0 for x in self.avg_accs]
+        temp_history = list(zip(self.policies, temp_avg_accs))
+
+        # min, not max: max() always yields at least len(temp_history), which made the
+        # slice below return every policy and silently ignore the requested count.
+        number_policies = min(number_policies, len(temp_history))
+
+        inter_pol = sorted(temp_history, key=lambda x: x[1], reverse = True)[:number_policies]
+
+        megapol = []
+        for pol in inter_pol:
+            megapol += pol[0]
 
+        return megapol
 
        
 
diff --git a/test/MetaAugment/test_ucb_learner.py b/test/MetaAugment/test_ucb_learner.py
index 3f37f3e5..fc2807aa 100644
--- a/test/MetaAugment/test_ucb_learner.py
+++ b/test/MetaAugment/test_ucb_learner.py
@@ -52,5 +52,8 @@ def test_ucb_learner():
         iterations=7
         )
 
+    print(learner.get_mega_policy(number_policies=50))
+    print(learner.get_mega_policy(number_policies=3))
+
 if __name__=="__main__":
     test_ucb_learner()
-- 
GitLab