diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py index 5f8e19bef5308a40fa8e51fef2afb7b382f305d8..c06edec316eed6982272abc685d6e02735e92adf 100644 --- a/MetaAugment/autoaugment_learners/gru_learner.py +++ b/MetaAugment/autoaugment_learners/gru_learner.py @@ -61,7 +61,7 @@ class gru_learner(aa_learner): # GRU-specific attributes that aren't in all other aa_learners's alpha=0.2, cont_mb_size=4, - cont_lr=0.1): + cont_lr=0.03): """ Args: alpha (float, optional): Exploration parameter. It is multiplied to diff --git a/MetaAugment/main.py b/MetaAugment/main.py index 61dec50cc98134d6931d721c79b2d53fd11c750f..5796c94cb349808ad92f028e2c9318ac3fb27316 100644 --- a/MetaAugment/main.py +++ b/MetaAugment/main.py @@ -51,6 +51,7 @@ def train_child_network(child_network, device = torch.device('cpu') child_network = child_network.to(device=device) + total_val=torch.tensor([0.0]) best_acc=torch.tensor([0.0]) early_stop_cnt = 0