diff --git a/MetaAugment/autoaugment_learners/gru_learner.py b/MetaAugment/autoaugment_learners/gru_learner.py index 377064a2c107c573bf8ae9a89630b22d8ee51d6c..e23178e917b47a8c3f5e477603b3c9eb4673808e 100644 --- a/MetaAugment/autoaugment_learners/gru_learner.py +++ b/MetaAugment/autoaugment_learners/gru_learner.py @@ -156,10 +156,11 @@ class gru_learner(aa_learner): cont_optim.step() # save the history every 1 epochs as a pickle - if _%1==1: - with open('gru_logs.pkl', 'wb') as file: - pickle.dump(self.history, file) - + with open('gru_logs.pkl', 'wb') as file: + pickle.dump(self.history, file) + with open('gru_learner.pkl', 'wb') as file: + pickle.dump(self, file) + @@ -182,4 +183,4 @@ if __name__=='__main__': learner = gru_learner(discrete_p_m=False) newpol = learner.generate_new_policy() learner.learn(train_dataset, test_dataset, child_network, toy_flag=True) - pprint(learner.history) \ No newline at end of file + pprint(learner.history)