diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py
index 950b7e11b8223b22208414594176b7a2911501dd..f67e99c78c6e9c963aadae9ba4b5e4ccf8961ea6 100644
--- a/MetaAugment/autoaugment_learners/aa_learner.py
+++ b/MetaAugment/autoaugment_learners/aa_learner.py
@@ -406,7 +406,7 @@ class aa_learner:
         return accuracy
     
 
-    def get_mega_policy(self, number_policies):
+    def get_mega_policy(self, number_policies=5):
         """
         Produces a mega policy, based on the n best subpolicies (evo learner)/policies
         (other learners)
@@ -419,6 +419,9 @@ class aa_learner:
         Returns:
             megapolicy -> [subpolicy, subpolicy, ...]
         """
+
+        number_policies = max(number_policies, len(self.history))
+
         inter_pol = sorted(self.history, key=lambda x: x[1], reverse = True)[:number_policies]
 
         megapol = []
diff --git a/MetaAugment/autoaugment_learners/ucb_learner.py b/MetaAugment/autoaugment_learners/ucb_learner.py
index 053492b87626b13efe723e17b7cf63e3655c8967..5d9a32e7182851de1b03b670569894c722f4178a 100644
--- a/MetaAugment/autoaugment_learners/ucb_learner.py
+++ b/MetaAugment/autoaugment_learners/ucb_learner.py
@@ -147,7 +147,7 @@ class ucb_learner(randomsearch_learner):
             print(self.cnts)
 
             
-    def get_mega_policy(self, number_policies):
+    def get_mega_policy(self, number_policies=5):
         """
         Produces a mega policy, based on the n best subpolicies (evo learner)/policies
         (other learners)
diff --git a/test/MetaAugment/test_aa_learner.py b/test/MetaAugment/test_aa_learner.py
index 96322b6d4ddf805a3182d8b3d9eedcc349245e95..cbb8d952b4f68899bf58557731f7bb7f023bfdd1 100644
--- a/test/MetaAugment/test_aa_learner.py
+++ b/test/MetaAugment/test_aa_learner.py
@@ -82,7 +82,7 @@ def test__test_autoaugment_policy():
                 p_bins=11,
                 m_bins=10,
                 discrete_p_m=True,
-                toy_size=0.004,
+                toy_size=0.002,
                 max_epochs=20,
                 early_stop_num=10
                 )
@@ -157,7 +157,15 @@ def test_exclude_method():
 
 def test_get_mega_policy():
 
-    agent = aal.randomsearch_learner()
+    agent = aal.randomsearch_learner(
+                sp_num=5,
+                p_bins=11,
+                m_bins=10,
+                discrete_p_m=True,
+                toy_size=0.002,
+                max_epochs=20,
+                early_stop_num=10
+                )
 
     child_network_architecture = cn.SimpleNet
     train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
@@ -167,7 +175,9 @@ def test_get_mega_policy():
                             transform=torchvision.transforms.ToTensor())
 
     agent.learn(train_dataset, test_dataset, child_network_architecture, 10)
-    mega_pol = agent.get_mega_policy()
+    mega_pol = agent.get_mega_policy(number_policies=30)
+    mega_pol = agent.get_mega_policy(number_policies=3)
+    mega_pol = agent.get_mega_policy(number_policies=1)
     print("megapol: ", mega_pol)