From 8c16a607402bc45d460f2c9077046e33053b5696 Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Fri, 22 Apr 2022 19:07:10 +0100
Subject: [PATCH] add benchmark code

---
 benchmark/scripts/04_22_gru_rs.py | 137 ++++++++++++++++++++++++++++++
 1 file changed, 137 insertions(+)
 create mode 100644 benchmark/scripts/04_22_gru_rs.py

diff --git a/benchmark/scripts/04_22_gru_rs.py b/benchmark/scripts/04_22_gru_rs.py
new file mode 100644
index 00000000..c3606347
--- /dev/null
+++ b/benchmark/scripts/04_22_gru_rs.py
@@ -0,0 +1,137 @@
+import torchvision.datasets as datasets
+import torchvision
+import torch
+
+import MetaAugment.child_networks as cn
+import MetaAugment.autoaugment_learners as aal
+
+from pathlib import Path
+
+"""
+testing gru_learner and randomsearch_learner on
+
+  fashionmnist with simple net
+
+ and 
+
+  cifar10 with lenet
+
+"""
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+else:
+    device = torch.device('cpu')
+
+
+def run_benchmark(
+    save_file,
+    total_iter,
+    train_dataset,
+    test_dataset,
+    child_network_architecture,
+    agent_arch,
+    config,
+    ):
+    try:
+        # try to load agent
+        with open(save_file, 'rb') as f:
+            agent = torch.load(f, map_location=device)
+    except FileNotFoundError:
+        # if agent hasn't been saved yet, initialize the agent
+        agent = agent_arch(**config)
+
+
+    # if history is not length total_iter yet(if total_iter
+    # different policies haven't been tested yet), keep running
+    while len(agent.history)<total_iter:
+        print(f'{len(agent.history)} / {total_iter}')
+        # run 1 iteration (test one new policy and update the GRU)
+        agent.learn(
+                    train_dataset=train_dataset,
+                    test_dataset=test_dataset,
+                    child_network_architecture=child_network_architecture,
+                    iterations=1
+                    )
+        # save agent every iteration
+        with open(save_file, 'wb+') as f:
+            torch.save(agent, f)
+
+    print('run_benchmark closing')
+
+
+# aa_learner config
+config = {
+        'sp_num' : 3,
+        'learning_rate' : 1e-1,
+        'toy_flag' : False,
+#         'toy_flag' : True,
+#         'toy_size' : 0.001,
+        'batch_size' : 32,
+        'max_epochs' : 100,
+        'early_stop_num' : 10,
+        }
+total_iter=150
+
+
+# FashionMNIST with SimpleNet
+train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train',
+                            train=True, download=True, transform=None)
+test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', 
+                        train=False, download=True,
+                        transform=torchvision.transforms.ToTensor())
+child_network_architecture = cn.SimpleNet
+
+
+# gru
+run_benchmark(
+    save_file='./benchmark/pickles/04_22_fm_sn_gru.pkl',
+    total_iter=total_iter,
+    train_dataset=train_dataset,
+    test_dataset=test_dataset,
+    child_network_architecture=child_network_architecture,
+    agent_arch=aal.gru_learner,
+    config=config,
+    )
+
+# rs
+run_benchmark(
+    save_file='./benchmark/pickles/04_22_fm_sn_rs.pkl',
+    total_iter=total_iter,
+    train_dataset=train_dataset,
+    test_dataset=test_dataset,
+    child_network_architecture=child_network_architecture,
+    agent_arch=aal.randomsearch_learner,
+    config=config,
+    )
+
+
+# CIFAR10 with LeNet
+train_dataset = datasets.CIFAR10(root='./datasets/cifar10/train',
+                        train=True, download=True, transform=None)
+test_dataset = datasets.CIFAR10(root='./datasets/cifar10/train',
+                        train=False, download=True, 
+                        transform=torchvision.transforms.ToTensor())
+child_network_architecture = cn.SimpleNet
+
+
+# gru
+run_benchmark(
+    save_file='./benchmark/pickles/04_22_cf_ln_gru',
+    total_iter=total_iter,
+    train_dataset=train_dataset,
+    test_dataset=test_dataset,
+    child_network_architecture=child_network_architecture,
+    agent_arch=aal.gru_learner,
+    config=config,
+    )
+
+# rs
+run_benchmark(
+    save_file='./benchmark/pickles/04_22_cf_ln_rs',
+    total_iter=total_iter,
+    train_dataset=train_dataset,
+    test_dataset=test_dataset,
+    child_network_architecture=child_network_architecture,
+    agent_arch=aal.randomsearch_learner,
+    config=config,
+    )
-- 
GitLab