Skip to content
Snippets Groups Projects
Commit 563fec38 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

write ucb_learner.get_mega_policy() and its test

parent 53d0d751
No related branches found
No related tags found
No related merge requests found
Pipeline #272327 passed
......@@ -147,7 +147,33 @@ class ucb_learner(randomsearch_learner):
print(self.cnts)
def get_mega_policy(self, number_policies):
    """
    Produces a mega policy, based on the n best subpolicies (evo learner)/policies
    (other learners)

    Policies are ranked by their average accuracy, highest first. A policy whose
    average accuracy is None is ranked as if it had scored 0. Ties keep the
    order in which the policies appear in self.policies (sorted() is stable).

    Args:
        number_policies -> int: Maximum number of (sub)policies to be included
            in the mega policy. Values larger than the number of recorded
            policies are clamped to that number; negative values are
            treated as 0.

    Returns:
        megapolicy -> [subpolicy, subpolicy, ...]: the selected policies
            concatenated into one flat list, best policy first.
    """
    # None means "no accuracy recorded"; rank those policies as 0 so that
    # sorting never has to compare None against a number.
    scores = [acc if acc is not None else 0 for acc in self.avg_accs]
    history = list(zip(self.policies, scores))

    # Clamp the request into [0, len(history)]. The previous max() always
    # selected every policy; min() makes the argument actually limit the
    # selection, and the outer max() stops a negative value from slicing
    # from the end of the ranking.
    number_policies = max(0, min(number_policies, len(history)))

    best = sorted(history, key=lambda entry: entry[1], reverse=True)[:number_policies]

    megapol = []
    for policy, _ in best:
        megapol.extend(policy)
    return megapol
......
......@@ -52,5 +52,8 @@ def test_ucb_learner():
iterations=7
)
print(learner.get_mega_policy(number_policies=50))
print(learner.get_mega_policy(number_policies=3))
# Run the UCB learner test directly when this file is executed as a script.
if __name__=="__main__":
test_ucb_learner()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment