From 63a7084cb93c8a25a82bca967a24d39ae1cc5f4e Mon Sep 17 00:00:00 2001
From: Sun Jin Kim <sk2521@ic.ac.uk>
Date: Mon, 11 Apr 2022 18:04:43 +0900
Subject: [PATCH] John: Add EasyNet

---
 MetaAugment/Baseline_JC.ipynb | 270 +++++-------------
 MetaAugment/UCB1_JC.ipynb     | 524 ++++++++++++++++------------------
 2 files changed, 319 insertions(+), 475 deletions(-)

diff --git a/MetaAugment/Baseline_JC.ipynb b/MetaAugment/Baseline_JC.ipynb
index d979dc8a..5e0523f0 100644
--- a/MetaAugment/Baseline_JC.ipynb
+++ b/MetaAugment/Baseline_JC.ipynb
@@ -62,7 +62,33 @@
     },
     {
       "cell_type": "code",
+      "source": [
+        "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
+        "class EasyNet(nn.Module):\n",
+        "    def __init__(self):\n",
+        "        super().__init__()\n",
+        "        self.fc1 = nn.Linear(784, 2048)\n",
+        "        self.relu1 = nn.ReLU()\n",
+        "        self.fc2 = nn.Linear(2048, 10)\n",
+        "        self.relu2 = nn.ReLU()\n",
+        "\n",
+        "    def forward(self, x):\n",
+        "        y = x.view(x.shape[0], -1)\n",
+        "        y = self.fc1(y)\n",
+        "        y = self.relu1(y)\n",
+        "        y = self.fc2(y)\n",
+        "        y = self.relu2(y)\n",
+        "        return y"
+      ],
+      "metadata": {
+        "id": "ukf2-C94UWzs"
+      },
       "execution_count": 3,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 4,
       "metadata": {
         "id": "xujQtvVWBgMH"
       },
@@ -93,13 +119,13 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 4,
+      "execution_count": 5,
       "metadata": {
         "id": "vu_4I4qkbx73"
       },
       "outputs": [],
       "source": [
-        "def run_baseline(batch_size=32, toy_size=0.02, max_epochs=100, early_stop_num=10, early_stop_flag=True, average_validation=[15,25]):\n",
+        "def run_baseline(batch_size=32, toy_size=0.02, max_epochs=100, early_stop_num=10, early_stop_flag=True, average_validation=[15,25], IsLeNet=True):\n",
         "\n",
         "    # create transformations using above info\n",
         "    transform = torchvision.transforms.Compose([\n",
@@ -113,7 +139,10 @@
         "    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n",
         "\n",
         "    # create model\n",
-        "    model = LeNet()\n",
+        "    if IsLeNet:\n",
+        "        model = LeNet()\n",
+        "    else:\n",
+        "        model = EasyNet()\n",
         "    sgd = optim.SGD(model.parameters(), lr=1e-1)\n",
         "    cost = nn.CrossEntropyLoss()\n",
         "\n",
@@ -171,196 +200,20 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 5,
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/"
-        },
-        "id": "KVhYheLfBP33",
-        "outputId": "8009d87f-7e39-40e3-c6ef-8f3a12f9433f"
-      },
-      "outputs": [
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw/train-images-idx3-ubyte.gz\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "9913344it [00:04, 2462502.04it/s]                             \n"
-          ]
-        },
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Extracting ./MetaAugment/train/MNIST/raw/train-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
-            "\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "29696it [00:00, 3785722.37it/s]          \n"
-          ]
-        },
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Extracting ./MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
-            "\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "1649664it [00:00, 3348476.95it/s]                             \n"
-          ]
-        },
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Extracting ./MetaAugment/train/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
-            "\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "5120it [00:00, 2935726.11it/s]          \n"
-          ]
-        },
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Extracting ./MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
-            "\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw/train-images-idx3-ubyte.gz\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "9913344it [00:04, 2338660.11it/s]                             \n"
-          ]
-        },
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Extracting ./MetaAugment/test/MNIST/raw/train-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
-            "\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "29696it [00:00, 33554432.00it/s]         "
-          ]
-        },
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Extracting ./MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
-            "\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "\n",
-            "1649664it [00:00, 2786152.46it/s]                             \n"
-          ]
-        },
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Extracting ./MetaAugment/test/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
-            "\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
-            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
-          ]
-        },
-        {
-          "name": "stderr",
-          "output_type": "stream",
-          "text": [
-            "5120it [00:00, 4789214.20it/s]          \n"
-          ]
-        },
-        {
-          "name": "stdout",
-          "output_type": "stream",
-          "text": [
-            "Extracting ./MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
-            "\n",
-            "0\tBest accuracy: 18.00%\n",
-            "10\tBest accuracy: 75.50%\n",
-            "20\tBest accuracy: 78.00%\n",
-            "30\tBest accuracy: 95.00%\n",
-            "40\tBest accuracy: 95.50%\n",
-            "50\tBest accuracy: 94.00%\n",
-            "60\tBest accuracy: 85.00%\n",
-            "70\tBest accuracy: 85.50%\n",
-            "80\tBest accuracy: 62.50%\n",
-            "90\tBest accuracy: 76.00%\n",
-            "Average best accuracy: 79.86%\n",
-            "\n",
-            "0\tAverage accuracy: 93.50%\n",
-            "10\tAverage accuracy: 93.45%\n",
-            "20\tAverage accuracy: 46.95%\n",
-            "30\tAverage accuracy: 71.41%\n",
-            "40\tAverage accuracy: 73.68%\n",
-            "50\tAverage accuracy: 64.50%\n",
-            "60\tAverage accuracy: 72.50%\n",
-            "70\tAverage accuracy: 94.36%\n",
-            "80\tAverage accuracy: 84.77%\n",
-            "90\tAverage accuracy: 92.14%\n",
-            "Average average accuracy: 80.92%\n",
-            "\n"
-          ]
-        }
-      ],
       "source": [
         "batch_size = 32               # size of batch the inner NN is trained with\n",
-        "toy_size = 0.02               # total propeortion of training and test set we use\n",
+        "toy_size = 0.05               # total propeortion of training and test set we use\n",
         "max_epochs = 100              # max number of epochs that is run if early stopping is not hit\n",
         "early_stop_num = 10           # max number of worse validation scores before early stopping is triggered\n",
         "early_stop_flag = True        # implement early stopping or not\n",
         "average_validation = [15,25]  # if not implementing early stopping, what epochs are we averaging over\n",
         "num_iterations = 100          # how many iterations are we averaging over\n",
+        "IsLeNet = True                # using LeNet or EasyNet\n",
         "\n",
         "# run using early stopping\n",
         "best_accuracies = []\n",
         "for baselines in range(num_iterations):\n",
-        "    best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation)\n",
+        "    best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, IsLeNet)\n",
         "    best_accuracies.append(best_acc)\n",
         "    if baselines % 10 == 0:\n",
         "        print(\"{}\\tBest accuracy: {:.2f}%\".format(baselines, best_acc*100))\n",
@@ -370,19 +223,52 @@
         "early_stop_flag = False\n",
         "best_accuracies = []\n",
         "for baselines in range(num_iterations):\n",
-        "    best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation)\n",
+        "    best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, IsLeNet)\n",
         "    best_accuracies.append(best_acc)\n",
         "    if baselines % 10 == 0:\n",
         "        print(\"{}\\tAverage accuracy: {:.2f}%\".format(baselines, best_acc*100))\n",
         "print(\"Average average accuracy: {:.2f}%\\n\".format(np.mean(best_accuracies)*100))"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "KVhYheLfBP33",
+        "outputId": "39c42079-a3cb-492e-8e26-68818eeac808"
+      },
+      "execution_count": 6,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "0\tBest accuracy: 95.60%\n",
+            "10\tBest accuracy: 85.40%\n",
+            "20\tBest accuracy: 86.40%\n",
+            "30\tBest accuracy: 95.40%\n",
+            "40\tBest accuracy: 97.00%\n",
+            "50\tBest accuracy: 80.40%\n",
+            "60\tBest accuracy: 95.60%\n",
+            "70\tBest accuracy: 96.40%\n",
+            "80\tBest accuracy: 86.20%\n",
+            "90\tBest accuracy: 95.40%\n",
+            "Average best accuracy: 84.65%\n",
+            "\n",
+            "0\tAverage accuracy: 78.45%\n",
+            "10\tAverage accuracy: 58.02%\n",
+            "20\tAverage accuracy: 38.60%\n",
+            "30\tAverage accuracy: 65.15%\n",
+            "40\tAverage accuracy: 77.22%\n",
+            "50\tAverage accuracy: 79.09%\n",
+            "60\tAverage accuracy: 95.55%\n",
+            "70\tAverage accuracy: 86.33%\n",
+            "80\tAverage accuracy: 85.98%\n",
+            "90\tAverage accuracy: 78.20%\n",
+            "Average average accuracy: 83.31%\n",
+            "\n"
+          ]
+        }
       ]
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {},
-      "outputs": [],
-      "source": []
     }
   ],
   "metadata": {
@@ -406,9 +292,9 @@
       "name": "python",
       "nbconvert_exporter": "python",
       "pygments_lexer": "ipython3",
-      "version": "3.9.7"
+      "version": "3.7.7"
     }
   },
   "nbformat": 4,
   "nbformat_minor": 0
-}
+}
\ No newline at end of file
diff --git a/MetaAugment/UCB1_JC.ipynb b/MetaAugment/UCB1_JC.ipynb
index d3bbda39..196e2ce4 100644
--- a/MetaAugment/UCB1_JC.ipynb
+++ b/MetaAugment/UCB1_JC.ipynb
@@ -1,12 +1,24 @@
 {
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "UCB1.ipynb",
+      "provenance": [],
+      "collapsed_sections": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU"
+  },
   "cells": [
     {
       "cell_type": "code",
-      "execution_count": 1,
-      "metadata": {
-        "id": "U_ZJ2LqDiu_v"
-      },
-      "outputs": [],
       "source": [
         "import numpy as np\n",
         "import torch\n",
@@ -18,17 +30,116 @@
         "import torchvision\n",
         "import torchvision.datasets as datasets\n",
         "\n",
-        "import child_networks\n",
-        "from main import *"
-      ]
+        "from matplotlib import pyplot as plt\n",
+        "from numpy import save, load"
+      ],
+      "metadata": {
+        "id": "U_ZJ2LqDiu_v"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": 3,
+      "source": [
+        "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
+        "class LeNet(nn.Module):\n",
+        "    def __init__(self):\n",
+        "        super().__init__()\n",
+        "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
+        "        self.relu1 = nn.ReLU()\n",
+        "        self.pool1 = nn.MaxPool2d(2)\n",
+        "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
+        "        self.relu2 = nn.ReLU()\n",
+        "        self.pool2 = nn.MaxPool2d(2)\n",
+        "        self.fc1 = nn.Linear(256, 120)\n",
+        "        self.relu3 = nn.ReLU()\n",
+        "        self.fc2 = nn.Linear(120, 84)\n",
+        "        self.relu4 = nn.ReLU()\n",
+        "        self.fc3 = nn.Linear(84, 10)\n",
+        "        self.relu5 = nn.ReLU()\n",
+        "\n",
+        "    def forward(self, x):\n",
+        "        y = self.conv1(x)\n",
+        "        y = self.relu1(y)\n",
+        "        y = self.pool1(y)\n",
+        "        y = self.conv2(y)\n",
+        "        y = self.relu2(y)\n",
+        "        y = self.pool2(y)\n",
+        "        y = y.view(y.shape[0], -1)\n",
+        "        y = self.fc1(y)\n",
+        "        y = self.relu3(y)\n",
+        "        y = self.fc2(y)\n",
+        "        y = self.relu4(y)\n",
+        "        y = self.fc3(y)\n",
+        "        y = self.relu5(y)\n",
+        "        return y"
+      ],
       "metadata": {
-        "id": "Iql-c88jGGWy"
+        "id": "4ksS_duLFADW"
       },
-      "outputs": [],
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
+        "class EasyNet(nn.Module):\n",
+        "    def __init__(self):\n",
+        "        super().__init__()\n",
+        "        self.fc1 = nn.Linear(784, 2048)\n",
+        "        self.relu1 = nn.ReLU()\n",
+        "        self.fc2 = nn.Linear(2048, 10)\n",
+        "        self.relu2 = nn.ReLU()\n",
+        "\n",
+        "    def forward(self, x):\n",
+        "        y = x.view(x.shape[0], -1)\n",
+        "        y = self.fc1(y)\n",
+        "        y = self.relu1(y)\n",
+        "        y = self.fc2(y)\n",
+        "        y = self.relu2(y)\n",
+        "        return y"
+      ],
+      "metadata": {
+        "id": "LckxnUXGfxjW"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"Make toy dataset\"\"\"\n",
+        "\n",
+        "def create_toy(train_dataset, test_dataset, batch_size, n_samples):\n",
+        "    \n",
+        "    # shuffle and take first n_samples %age of training dataset\n",
+        "    shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset))\n",
+        "    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)\n",
+        "    indices_train = torch.arange(int(n_samples*len(train_dataset)))\n",
+        "    reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)\n",
+        "\n",
+        "    # shuffle and take first n_samples %age of test dataset\n",
+        "    shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset))\n",
+        "    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)\n",
+        "    indices_test = torch.arange(int(n_samples*len(test_dataset)))\n",
+        "    reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)\n",
+        "\n",
+        "    # push into DataLoader\n",
+        "    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)\n",
+        "    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)\n",
+        "\n",
+        "    return train_loader, test_loader"
+      ],
+      "metadata": {
+        "id": "xujQtvVWBgMH"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
       "source": [
         "\"\"\"Randomly generate 10 policies\"\"\"\n",
         "\"\"\"Each policy has 5 sub-policies\"\"\"\n",
@@ -59,15 +170,15 @@
         "                    policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10\n",
         "\n",
         "    return policies"
-      ]
+      ],
+      "metadata": {
+        "id": "Iql-c88jGGWy"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": 4,
-      "metadata": {
-        "id": "QE2VWI8o731X"
-      },
-      "outputs": [],
       "source": [
         "\"\"\"Pick policy and sub-policy\"\"\"\n",
         "\"\"\"Each row of data should have a different sub-policy but for now, this will do\"\"\"\n",
@@ -104,18 +215,23 @@
         "            scale = policies[policy, sub_policy][5]\n",
         "\n",
         "    return degrees, shear, scale"
-      ]
+      ],
+      "metadata": {
+        "id": "QE2VWI8o731X"
+      },
+      "execution_count": null,
+      "outputs": []
     },
     {
       "cell_type": "code",
-      "execution_count": 6,
+      "execution_count": null,
       "metadata": {
         "id": "vu_4I4qkbx73"
       },
       "outputs": [],
       "source": [
         "\"\"\"Sample policy, open and apply above transformations\"\"\"\n",
-        "def run_UCB1(policies, batch_size, toy_size, max_epochs, early_stop_num, iterations):\n",
+        "def run_UCB1(policies, batch_size, toy_size, max_epochs, early_stop_num, iterations, IsLeNet):\n",
         "\n",
         "    # get number of policies and sub-policies\n",
         "    num_policies = len(policies)\n",
@@ -127,6 +243,8 @@
         "    q_plus_cnt = [0]*num_policies\n",
         "    total_count = 0\n",
         "\n",
+        "    best_q_values = []\n",
+        "\n",
         "    for policy in range(iterations):\n",
         "\n",
         "        # get the action to try (either initially in order or using best q_plus_cnt value)\n",
@@ -144,18 +262,60 @@
         "            torchvision.transforms.ToTensor()])\n",
         "\n",
         "        # open data and apply these transformations\n",
-        "        train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=transform)\n",
-        "        test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=torchvision.transforms.ToTensor())\n",
+        "        train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
+        "        test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
         "\n",
         "        # create toy dataset from above uploaded data\n",
         "        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n",
         "\n",
         "        # create model\n",
-        "        child_network = child_networks.lenet()\n",
-        "        sgd = optim.SGD(child_network.parameters(), lr=1e-1)\n",
+        "        if IsLeNet:\n",
+        "            model = LeNet()\n",
+        "        else:\n",
+        "            model = EasyNet()\n",
+        "        sgd = optim.SGD(model.parameters(), lr=1e-1)\n",
         "        cost = nn.CrossEntropyLoss()\n",
         "\n",
-        "        best_acc = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100)\n",
+        "        # set variables for best validation accuracy and early stop count\n",
+        "        best_acc = 0\n",
+        "        early_stop_cnt = 0\n",
+        "\n",
+        "        # train model and check validation accuracy each epoch\n",
+        "        for _epoch in range(max_epochs):\n",
+        "\n",
+        "            # train model\n",
+        "            model.train()\n",
+        "            for idx, (train_x, train_label) in enumerate(train_loader):\n",
+        "                label_np = np.zeros((train_label.shape[0], 10))\n",
+        "                sgd.zero_grad()\n",
+        "                predict_y = model(train_x.float())\n",
+        "                loss = cost(predict_y, train_label.long())\n",
+        "                loss.backward()\n",
+        "                sgd.step()\n",
+        "\n",
+        "            # check validation accuracy on validation set\n",
+        "            correct = 0\n",
+        "            _sum = 0\n",
+        "            model.eval()\n",
+        "            for idx, (test_x, test_label) in enumerate(test_loader):\n",
+        "                predict_y = model(test_x.float()).detach()\n",
+        "                predict_ys = np.argmax(predict_y, axis=-1)\n",
+        "                label_np = test_label.numpy()\n",
+        "                _ = predict_ys == test_label\n",
+        "                correct += np.sum(_.numpy(), axis=-1)\n",
+        "                _sum += _.shape[0]\n",
+        "            \n",
+        "            # update best validation accuracy if it was higher, otherwise increase early stop count\n",
+        "            acc = correct / _sum\n",
+        "            if acc > best_acc :\n",
+        "                best_acc = acc\n",
+        "                early_stop_cnt = 0\n",
+        "            else:\n",
+        "                early_stop_cnt += 1\n",
+        "\n",
+        "            # exit if validation gets worse over 10 runs\n",
+        "            if early_stop_cnt >= early_stop_num:\n",
+        "                break\n",
         "\n",
         "        # update q_values\n",
         "        if policy < num_policies:\n",
@@ -163,7 +323,11 @@
         "        else:\n",
         "            q_values[this_policy] = (q_values[this_policy]*cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)\n",
         "\n",
-        "        print(q_values)\n",
+        "        best_q_value = max(q_values)\n",
+        "        best_q_values.append(best_q_value)\n",
+        "\n",
+        "        if (policy+1) % 10 == 0:\n",
+        "            print(\"Iteration: {},\\tQ-Values: {}, Best Policy: {}\".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))\n",
         "\n",
         "        # update counts\n",
         "        cnts[this_policy] += 1\n",
@@ -174,277 +338,71 @@
         "            for i in range(num_policies):\n",
         "                q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])\n",
         "\n",
-        "    return q_values"
+        "    return q_values, best_q_values"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": 7,
+      "source": [
+        "batch_size = 32       # size of batch the inner NN is trained with\n",
+        "toy_size = 0.05       # total propeortion of training and test set we use\n",
+        "max_epochs = 100      # max number of epochs that is run if early stopping is not hit\n",
+        "early_stop_num = 10   # max number of worse validation scores before early stopping is triggered\n",
+        "num_policies = 5      # fix number of policies\n",
+        "num_sub_policies = 5  # fix number of sub-policies in a policy\n",
+        "iterations = 100      # total iterations, should be more than the number of policies\n",
+        "IsLeNet = True        # using LeNet or EasyNet\n",
+        "\n",
+        "# generate random policies at start\n",
+        "policies = generate_policies(num_policies, num_sub_policies)\n",
+        "\n",
+        "q_values, best_q_values = run_UCB1(policies, batch_size, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)\n",
+        "\n",
+        "plt.plot(best_q_values)\n",
+        "\n",
+        "best_q_values = np.array(best_q_values)\n",
+        "save('best_q_values_LeNet_5percent.npy', best_q_values)\n",
+        "#best_q_values = load('best_q_values_LeNet_5percent.npy', allow_pickle=True)"
+      ],
       "metadata": {
         "colab": {
-          "base_uri": "https://localhost:8080/"
+          "base_uri": "https://localhost:8080/",
+          "height": 447
         },
         "id": "doHUtJ_tEiA6",
-        "outputId": "6735e812-f7be-4f8b-cec2-52a069f7731b"
+        "outputId": "3cc290fb-2d45-4fac-b0e6-6fb6c0490b78"
       },
+      "execution_count": null,
       "outputs": [
         {
-          "name": "stdout",
           "output_type": "stream",
+          "name": "stdout",
           "text": [
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "[0, 0, 0, 0, 0.5, 0, 0, 0, 0, 0.5]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0, 0, 0, 0, 0.25, 0, 0, 0, 0, 0.5]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0, 0, 0, 0, 0.25, 0, 0, 0, 0, 0.25]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0.0, 0, 0, 0, 0.25, 0, 0, 0, 0, 0.25]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0.0, 0.0, 0, 0, 0.25, 0, 0, 0, 0, 0.25]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0.0, 0.0, 0.0, 0, 0.25, 0, 0, 0, 0, 0.25]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0.0, 0.0, 0.0, 0.0, 0.25, 0, 0, 0, 0, 0.25]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0, 0, 0, 0.25]\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "main.train_child_network best accuracy:  0.5\n",
-            "[0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.25, 0, 0, 0.25]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.25, 0.0, 0, 0.25]\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "main.train_child_network best accuracy:  0\n",
-            "[0.0, 0.0, 0.0, 0.0, 0.25, 0.0, 0.25, 0.0, 0.0, 0.25]\n",
-            "Wall time: 3.92 s\n"
+            "Iteration: 10,\tQ-Values: [0.79, 0.85, 0.87, 0.69, 0.83], Best Policy: 0.87\n",
+            "Iteration: 20,\tQ-Values: [0.82, 0.86, 0.88, 0.82, 0.86], Best Policy: 0.88\n",
+            "Iteration: 30,\tQ-Values: [0.84, 0.89, 0.77, 0.84, 0.87], Best Policy: 0.89\n",
+            "Iteration: 40,\tQ-Values: [0.83, 0.9, 0.83, 0.84, 0.89], Best Policy: 0.9\n",
+            "Iteration: 50,\tQ-Values: [0.84, 0.91, 0.82, 0.85, 0.87], Best Policy: 0.91\n",
+            "Iteration: 60,\tQ-Values: [0.83, 0.92, 0.83, 0.86, 0.88], Best Policy: 0.92\n",
+            "Iteration: 70,\tQ-Values: [0.84, 0.92, 0.83, 0.85, 0.88], Best Policy: 0.92\n",
+            "Iteration: 80,\tQ-Values: [0.85, 0.92, 0.83, 0.85, 0.87], Best Policy: 0.92\n",
+            "Iteration: 90,\tQ-Values: [0.85, 0.91, 0.81, 0.85, 0.87], Best Policy: 0.91\n",
+            "Iteration: 100,\tQ-Values: [0.84, 0.91, 0.83, 0.85, 0.83], Best Policy: 0.91\n"
           ]
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              "<Figure size 432x288 with 1 Axes>"
+            ],
+            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXxc1X338c9Po82LvAsb7zbIgM1iG2NIWJISFkMWCE1TQ7O0oSFJgTYpbR/SUJqSJ2mePt2SloeUJIRAQigBmhjqpw5hKUnAYBvjHS8YsCVv8iIjWdaMZubXP+aONNqsGS2WffR9v156aebeczXncs1XR+eee465OyIiEq6iga6AiIj0LwW9iEjgFPQiIoFT0IuIBE5BLyISuOKBrkB748aN8+nTpw90NURETiqrVq3a7+6Vne074YJ++vTprFy5cqCrISJyUjGzd7rap64bEZHAKehFRAKnoBcRCZyCXkQkcAp6EZHAKehFRAKnoBcRCdygDvrl2w+wesehga6GiEi/GrRB39Sc4vM/WsVtj6wmldac/CISrkEb9E+v3U1dYzM1dUd5/o19A10dEZF+M2iD/uGX3+b0U4YzfkQZDy3v8slhEZGT3qAM+jU761hTfZhPvWcaNy2cxotbanl7/5GBrpaISL8YlEH/0MvvMKw0xkfnTeLGhVMoLjJ+pFa9iARq0AX9oSMJnlq7i4/On0RFeQmnjCjn6rMn8NjKnRxNpAa6eiIifS6voDezRWa22cy2mdmdneyfZmbPmtlaM3vBzCbn7EuZ2evR15K+rHxPPLZyJ4lkmk9eNL1l26cumsa7TUmWrKkZuIqJiPSTboPezGLAvcA1wGzgRjOb3a7Y3wMPufu5wD3A3+bsO+ruc6Ovj/RRvXvE3fnJqztYOH0MZ0yoaNm+cMYYzhhfwSOv7BjA2omI9I98WvQLgW3uvt3dE8CjwHXtyswGnoteP9/J/hPCmurDvH2gkY8tmNxmu5nxsfMns6b6MG/ppqyIBCafoJ8E7Mx5Xx1ty7UGuCF6/VGgwszGRu/LzWylmS03s+s7+wAzuyUqs7K2traA6hdmyeu7KI0VcfWcCR32ffDcUwF4as2ufvt8EZGB0Fc3Y/8MeJ+ZrQbeB9QA2Tub09x9AXAT8M9mdlr7g939fndf4O4LKis7XfKw11Jp5+m1u3jfGZWMHFLSYf/EUUNYOH0MS9bswl1PyopIOPIJ+hpgSs77ydG2Fu6+y91vcPd5wFeibXXR95ro+3bgBWBe76tduFfeOsC++jgfOW9il2U+PHci2/Y1sGl3/XGsmYhI/8on6FcAVWY2w8xKgcVAm9EzZjbOzLI/68vAA9H20WZWli0DXAxs7KvKF+KpNbsYWhrjirPGd1nm2rMnECsylqj7RkQC0m3Qu3sSuA1YBmwCHnP3DWZ2j5llR9G8H9hsZluA8cDXo+1nASvNbA2Zm7TfdPfjHvSJZJql6/Zw5ezxDCmNdVlu7PAyLjl9HE+p+0ZEAlKcTyF3Xwosbbft7pzXjwOPd3LcS8A5vaxjr/1qay2HjzYfs9sm6yPnTeSOn67htR2HOH/amONQOxGR/jUonoxdsmYXI4eUcGlV9zd6r5oznrLiIpa8ru4bEQnDoAj6F7fUcsVZ4ykt7v50K8pLuLRqHC9u3X8caiYi0v+CD3p35/DRZiaMLMv7mFNHDuFQY6IfayUicvwEH/RNzWnSDsPK8rodAUBFeTH1TUndkBWRIAQf9A3xJADDCwr6ElJpp6k53V/VEhE5boIP+sZEJuiHlRbWogeob2rulzqJiBxPwQd9tkU/rKzr8fPtZYP+3aZkv9RJROR4Cj7oj8QzU+4U2kcPatGLSBgGQdBnW/SF9dED1KtFLyIBCD/oe9VHr6AXkZNf+EHfoz76TIu+Ia6uGxE5+QUf9A1RH31hwyvVoheRcOSffiepxqhFP7SArptsN49G3chA21ffxJY9DQUdM2ZYKbMnjuiX+qyvOcyfP76WTbvf7XT/2ZNG8L1PXcCEkeX98vnSM8EHfUMiSWmsKK95brJiRcbwsmKNupEB0xBP8m///Sbf/dX2Hj24d+c1Z/L593VYzK3HmlNp7n1+G//63DbGDi/l9stPp8isTZlkOs0PX3qH377vJR66eSGnVQ7P++evrznM0nW7aU7lf65FZnz8gikFfc5gFXzQH4knC+qfz8pOgyDSGw3xJA+9/DYHG/KfOymZdp5eu5v9DXE+fN5Eblw4hZJY/g2Vh15+h2/+/zdojCf50pWzsHaB3J102vnPdbu574U32VffBEA8maa+Kcn1cyfyNx85m5FDOy7HCXDN2afy6Qde5Xe+8zL/8PHzGF9x7Jb9joONPPjSWyzffpDiIiuoQXa0OUV13VHuvWl+/ic3SA2CoE8VNLQyKxP0atFLz724pZYvP7mOmrqjDDvGgjedmTNpJN/91PnMmzq64M+dP3U0Q0qK+PZz2zjYmOA9M8flfWxDvJkf/OZt3thTz6zxw7l6zoSWfe+bVclVOe87c/akkTz+hffyye+/wh/8YEVenzlxZDl/ee2Z/O4FUztdz7krd/1sHU+sqqExkSyoa3YwCv6/zpF4sqChlVkV5SUtT9WKZO053MSPlr9DopsuhupDjSxdt4fTKofxxBfec1wXsYkVGd+84VyGlhbz4Etv86PlOwo6fsa4YXxr8Vw+dO5EYkWF/TWQPf6p2y7h1bcP0t28gMPKYlw0c2xBf7Fkfejcifxo+Q6e3bSPD+exqNBgFn7QJ3redXPwiKYqllY1dUe58f7lVB9qpKz42P+mSmLGrb91GrdfXkV5SeH//nqrqMj46kfm8JmLZ9CUTOV/nMH0scMo7kHw5ho9rLTNXwP94YLpYzilooyn1uxS0Hcj+KBviKcYUV74aQ4vK+adA439UCM5Ge082MiN313O4aPNPPlHFzN3yqiBrlJepo4dOtBV6DexIuPac07lkVd3UN/U3PL8i3QUfNA3xpOcOqLwoV4V5SXqow/UK9sP8Js3DxR0zBOrqqlvaubHf3gh504+OUJ+MPjweRN58KW3eWbjXm6YP3mgq3PCCj7oM6NuCj/NEeXFGkcfoGc37eVzD68imS5sUZkJI8p55LMXcfakkf1UM+mJ+VNHMWnUEJ5eu/uEDvpddUfZ3xDvttzQ0hinn1LR558ffNA3xJMM72EffSKZJp5MddsfKyeHX2/dzxd+/BpnnTqCH3/2QkboT/2TnpnxwXNP5Qe/eYu6xgSjhpYOdJXaWFd9mP/3wjb+a8Oebm9MA8ydMoqf3Xpxn9cj6KB3dxoTPR1eGc1305SkbLiC/mSxbV8Db+8/0mH7wcYEf/3zDcwcN4yHPrNQIR+QD517Kve/uJ2v/Md6po/L/57EpFFDuWH+pF7dLF++/QD3v7i90xF6jYkk62vepaK8mC+87zTOn9b9UNkRBQwvLUTQQR9Ppkmmvcfj6CEz383Y4fkvLC4D5/WddfzOd16iOdV502lm5TAevvlCRg87sVp90jvnTBrJ+dNGs2zDnryPcSCVdr797FZu/a3T+PgFUwr6y31X3VG+sXQTT6/dzfgRZcwYN6xDmRHlJdx5zZn83oVTB/xGcdBB3zJzZYEPq0DrJGh6OvbkUNeY4NYfv8YpFeX8y03zKCnqODywavzwARnqKP3LzHjiC+8t+LiX3tzPP/5iC3/18w187elNBT0zEE+mKIkV8cUrqvjcZacxpAcZczwFHvSFry6V1br4iEbenOjcnT/76Vr21Tfx08+/96QZ+igD672njeM9nx/Lr7ft59db91PI7fmy4iJ+94IpTB59cgxfDTvoE4WvLpWldWNPPG/vP8LBxo4Psb2wuZZfbtrL3R+arZCXgpgZl1ZVcmlV5UBXpV+FHfQ9WEYwa4Ra9CeUZRv28LmHV3W5/+o54/mDi6cfvwqJnETySkAzWwR8C4gB33P3b7bbPw14AKgEDgKfcPfqaN+ngbuiov/b3X/YR3XvVvZOeE+HV+b+DBk4Bxri/OWT65gzcQR/fvUZHfaXxIq4YPqYgmdpFBksug16M4sB9wJXAtXACjNb4u4bc4r9PfCQu//QzC4H/hb4pJmNAf4aWEDmRveq6NhDfX0inWlM9LyPfrhWmTohuDt/9fP11DcleeSzczljQt8/TCISunxmLloIbHP37e6eAB4FrmtXZjbwXPT6+Zz9VwPPuPvBKNyfARb1vtr5aYgXvjB4VkmsiPKSInXdDLCn1u5m6bo9fPHKKoW8SA/lk4CTgJ0576uBC9uVWQPcQKZ756NAhZmN7eLYSe0/wMxuAW4BmDp1ar5171Zv+ughO9+NWvR9Zfn2Azz88jt4AeMbfr11P3OnjOKWS2f2Y81EwtZXN2P/DPhXM/t94EWgBsh7blR3vx+4H2DBggWFTUJyDK1dNz0b46pVpvpOMpXmL59cR21DnAkFTDI3a3wFf/exc3s9ba7IYJZP0NcAU3LeT462tXD3XWRa9JjZcOC33b3OzGqA97c79oVe1LcgDfFkZnmyHoZERXkJ9YPoZuyvt+6n7mjH4YsxMy6dVdnyEFlP/Pz1XWzff4TvfOJ8Fp3dv/OUi0hb+fyfuwKoMrMZZAJ+MXBTbgEzGwccdPc08GUyI3AAlgHfMLPsJA9XRfuPi+zMlT0djTFiEC0nuGZnHZ/4/itd7j9/2mge+eyFPZrgLZlK8+3ntjJn4giunjO+N9UUkR7oNujdPWlmt5EJ7RjwgLtvMLN7gJXuvoRMq/1vzczJdN3cGh170My+RuaXBcA97n6wH86jU5mZK3veCq0oL2b34aY+rNGJa9mGPcSKjP/4o/cypN00Aa/tOMT/emIdd/3Hev7uY+cW/IvzydU1vHOgke99aoGGQIoMgLxS0N2XAkvbbbs75/XjwONdHPsArS3846oxnupx/zxk5rsZLC36X2zcy0Uzx3S6qEbV+Apq6pr49rNbOWNCBX9YwI3R5lSabz+7lXMnj+QDZ53Sl1UWkTyF/WRsL1eHHyyjbt6sbWDbvgY+edG0Lst88QNVbN1bzzeWbmLjrnfzXsy5tiFO9aGjfO36s9WaFxkgQQd9X3TdNCZSJFPpoEd9PLNxLwBXzu66/7yoyPiHj5/H0R+neKnAZfiuPWcC758V9lwiIieyoIO+MZ7ilIqezyWfncHySDzFyKHhBv0vNuzh3MkjmThqyDHLDS0t5sE/WHicaiUifSXooG+IJ3v0VGxW6wyWzYwcmt/CAYlkmsdW7uRo4tiPERQVGdfNnci4AV7UZN+7TazeWccdV84a0HqISP8JOuiPJHq2MHjWiB7Md/PC5n3c9bP1eZXd924TX772rB7Vra88s2kv7nDVHI1tFwlV0EGfGXXT81McXlb4VMVv7KkHYMVXrjjmqjM3P7iCF7fuP+ZDBbc+8hq/6GJ5tCvOGs99nzi/zbbHVu7krp+tx7tZhbi8JMbvnD+FWy6byS827GX62KFUnTL8mMeIyMkr2KBPJNMkUukeTVGcVdGDFv3mvfVMHTOUym7uDVw2q5L/u2wztfXxTsvGkyme2bCXeVNHs6DdosKr3jnELzftpak51WZpvGXr9zBqSAkfO3/yMT+7+tBRfvjy2/xo+Tuk3Ln5khkaESMSsGCDPjuhWe+GV0ZBH8+/Rb95Tz2zxnc/y+JlVZmg/822/Vw/r8M8b2zc9S6JVJrPXDydRWef2mbfsg17eOXhg2zYdZjzp40BMtP5rt5ZxwfOPIW/WHRmt5//51efwXf++02WbdjL9XM7fr6IhCPYoSSti470bhw9QEOeLfp4MsVb+49wZh7T6c6ZOILRQ0t4cWttp/tf21EHwLypozvsmzc181DT6qgMwDsHGjl4JNFp+c5MGTOUr3/0HFbedQWzJ47I6xgROTkFG/S9WXQkq9B1Y9/cd4RU2pmVR9AXFRkXnz4usyhxJ33qq3ccYtKoIYzvZKbHUyrKmTx6SJugX70zs5bL/GlaM1VE2go26LMt+qG96KMvKy6iJGZ599Fv2Zu5EXtGHl03AJdWjWNffZwtexs67Fu9o465U7sO7XlTR7N6x6E25YeVxqg6RYtziEhbwQb9kT7oujGzaBqE/Pro39hTT0nMmDFuWF7lL4lWnv9Vu+6bfe82UVN3lHlTjhH0U0ax63ATe6JJ11bvqOO8KaOIFemmqoi0FWzQNyZ6voxgrkIWH9myt56Z44ZTWpzff9ZJo4Yws3IYv9q6v8321Tu77p/Pau2nP8TRRIpNu99t2SYikivYoG+I9251qayKAuak37ynvuB1TS+rquSVtw7Q1Nz6JO3qHXWUxIw5x7hJOnviCEpjRazeWcf6XYdJpp15U/K7ESsig0uwQd/b9WKzKspKWvr7j6W+qZmauqMFB/2lVeNoak6z6p3c/vZDzJ44ss0Y+fbKimPMmTSC1TsOtfTVH6tPX0QGr3CDPtH7PnrIv+sme0M13xuxWRfOHEtJzPjJqzuAzGpMa6sPMz+P0J4/dTRrqw/z6luHmDZ26IDPmyMiJ6Zwgz6eJFZklOXZX96V4eXFHDySYPOe+g5f+xviLeU2R1MfFNqiH15WzK2/dTpPr93NE6uqeWNPPUebU3mNh583dRTxZJrnN+875o1bERncAn4yNsXQ0livH+2vHF7Gvvo4V//zix32lcaK+PfPXcS8qaPZsreeYaUxJnUz1W9nbr+8ipffPMBf/Xw9v3tBZh32fII7+8sglfa8H5QSkcEn2KDv7aIjWX/0/tOZO2UU7R9pcodvLN3Enz62hv/840t4Y8+7VI2voKgHwxtjRca3Fs/jmm+9yA9+8zbjhpcxeXT3vzAmjiznlIrMLyKNuBGRrgQb9I29nKI4a+TQEq4559RO940ZVspN31vON5ZuYvOeeq6a3fOpfieMLOcfPn4en3lwJfOmjsrrLxEzY97UUbywuZYzJ2gaAxHpXLBB3xBPMewY0wT3hfecNpY/vGQG3/3VW0Dh/fPtXX7meP7lxnmcVpn/lMF3XHUGN8yfnPfYfREZfIIN+iPxvmnRd+eOq87gxS372by38DH0nfnweRMLKj9rfEVes2WKyOAVbDPweAV9eUmMf7lpHtfNnch83RAVkRNQuC36RN/cjM3HrPEVfGvxvOPyWSIihQq4RZ8ZXikiMtgFHPTHr0UvInIiCzLoU2knnkwfc3FuEZHBIsigTyTTQGbiLxGRwS6voDezRWa22cy2mdmdneyfambPm9lqM1trZtdG26eb2VEzez36+k5fn0BnskGvseUiInmMujGzGHAvcCVQDawwsyXuvjGn2F3AY+5+n5nNBpYC06N9b7r73L6t9rElUlHQx7TakohIPk3ehcA2d9/u7gngUeC6dmUcyD6DPxLY1XdVLFxL0KtFLyKSV9BPAnbmvK+OtuX6KvAJM6sm05q/PWffjKhL57/N7NLOPsDMbjGzlWa2sra2trMiBWmOum5KYgp6EZG+SsIbgQfdfTJwLfCwmRUBu4Gp7j4P+FPgETPrMPuWu9/v7gvcfUFlZWWvK6MWvYhIq3ySsAaYkvN+crQt183AYwDu/jJQDoxz97i7H4i2rwLeBGb1ttLdabkZqxa9iEheQb8CqDKzGWZWCiwGlrQrswP4AICZnUUm6GvNrDK6mYuZzQSqgO19VfmuZFv0JWrRi4h0P+rG3ZNmdhuwDIgBD7j7BjO7B1jp7kuAO4DvmtmXyNyY/X13dzO7DLjHzJqBNPB5dz/Yb2cTaRlHrxa9iEh+k5q5+1IyN1lzt92d83ojcHEnxz0BPNHLOhasWS16EZEWQSah+uhFRFoFmYTNGnUjItIiyCSMaxy9iEiLIJOwdVKzIE9PRKQgQSZhc8oBtehFRCDQoE8kU4D66EVEINCgb23Ra/ZKEZEgg15z3YiItAoyCeMaRy8i0iLIJGxOpSmJGWbquhERCTLoE8m0WvMiIpEg07A5ldY8NyIikSDTUC16EZFWQaZhIpnWiBsRkUiQaZhIqUUvIpIVZBqqRS8i0irINMwMrwzy1EREChZkGiZSatGLiGQFmYbNSVcfvYhIJMg0jGscvYhIiyDTUOPoRURaBZmGzak0pcWa50ZEBAINerXoRURaBZmGGl4pItIqyDTUA1MiIq2CTEMFvYhIqyDTUHPdiIi0Ci4N3V1PxoqI5MgrDc1skZltNrNtZnZnJ/unmtnzZrbazNaa2bU5+74cHbfZzK7uy8p3JpV23NHNWBGRSHF3BcwsBtwLXAlUAyvMbIm7b8wpdhfwmLvfZ2azgaXA9Oj1YmAOMBH4pZnNcvdUX59IViIVLQyuFr2ICJBfi34hsM3dt7t7AngUuK5dGQdGRK9HArui19cBj7p73N3fArZFP6/fNCcdQH30IiKRfNJwErAz5311tC3XV4FPmFk1mdb87QUci5ndYmYrzWxlbW1tnlXvXDyV+WNBc92IiGT0VRreCDzo7pOBa4GHzSzvn+3u97v7AndfUFlZ2auKJJKZrpsytehFRIA8+uiBGmBKzvvJ0bZcNwOLANz9ZTMrB8bleWyfak5lum5KNNeNiAiQX4t+BVBlZjPMrJTMzdUl7crsAD4AYGZnAeVAbVRusZmVmdkMoAp4ta8q35lsi740FuvPjxEROWl026J396SZ3QYsA2LAA+6+wczuAVa6+xLgDuC7ZvYlMjdmf9/dHdhgZo8BG4EkcGt/jriBzDw3oFE3IiJZ+XTd4O5Lydxkzd12d87rjcDFXRz7deDrvahjQeJRi74kpq4bEREI8MnYlq4btehFRIAAg76l60ajbkREgACDXi16EZG2gkvDbItec92IiGQEl4aa60ZEpK3g0jCeVB+9iEiu4NJQ4+hFRNoKLg0TatGLiLQRXBq23IxVi15EBAgw6NWiFxFpK7g0TGRnr9QUCCIiQIhBn0xTGivCTEEvIgKhBr3650VEWgSXiM2ptLptRERyBBf0atGLiLQVXCI2pxT0IiK5gkvEeCqtCc1ERHIEl4jZUTciIpIRXCKq60ZEpK3gElEtehGRtoJLxGb10YuItBFcImp4pYhIW8ElYlxBLyLSRnCJ2JxSH72ISK7gEjGhUTciIm0El4jNSddcNyIiOYILerXoRUTayisRzWyRmW02s21mdmcn+//JzF6PvraYWV3OvlTOviV9WfnONCfTlMZi/f0xIiInjeLuCphZDLgXuBKoBlaY2RJ335gt4+5fyil/OzAv50ccdfe5fVflY4un0pQUq+tGRCQrnxb9QmCbu2939wTwKHDdMcrfCPykLypXKHcnkUxTplE3IiIt8knEScDOnPfV0bYOzGwaMAN4LmdzuZmtNLPlZnZ9j2uah2Q6u16sgl5EJKvbrpsCLQYed/dUzrZp7l5jZjOB58xsnbu/mXuQmd0C3AIwderUHn94IpkG0M1YEZEc+SRiDTAl5/3kaFtnFtOu28bda6Lv24EXaNt/ny1zv7svcPcFlZWVeVSpc82pTNCrRS8i0iqfRFwBVJnZDDMrJRPmHUbPmNmZwGjg5Zxto82sLHo9DrgY2Nj+2L6iFr2ISEfddt24e9LMbgOWATHgAXffYGb3ACvdPRv6i4FH3d1zDj8L+DczS5P5pfLN3NE6fS2uoBcR6SCvPnp3Xwosbbft7nbvv9rJcS8B5/SifgXJdt1orhsRkVZBJWIipRa9iEh7QSVic1LDK0VE2gsqEROpzKhOtehFRFoFlYiJqEWvPnoRkVZBJWJrH73muhERyQor6LPDKzV7pYhIi6CCvuXJWLXoRURaBBX0rS36oE5LRKRXgkpEjaMXEekoqERUi15EpKOgElGTmomIdBRUImqaYhGRjoJKRLXoRUQ6CioRsy364iINrxQRyQoq6OOpNKXFRZgp6EVEsoIK+kQyTZn650VE2ggqFZtTaUrUPy8i0kZQqZhIpjWGXkSknaBSsTnlmudGRKSdoIJeLXoRkY6CSsVEKk1psaYoFhHJFVbQJ9OUxtR1IyKSK7yg16gbEZE2gkrF5lRa89yIiLQTVComUmrRi4i0F1QqatSNiEhHQaViQk/Gioh0EFQqaq4bEZGO8kpFM1tkZpvNbJuZ3dnJ/n8ys9ejry1mVpez79NmtjX6+nRfVr493YwVEemouLsCZhYD7gWuBKqBFWa2xN03Zsu4+5dyyt8OzItejwH+GlgAOLAqOvZQn55FRMMrRUQ6yicVFwLb3H27uyeAR4HrjlH+RuAn0eurgWfc/WAU7s8Ai3pT4WNpTrla9CIi7eSTipOAnTnvq6NtHZjZNGAG8Fwhx5rZLWa20sxW1tbW5lPvTqlFLyLSUV+n4mLgcXdPFXKQu9/v7gvcfUFlZWWPPtjdNY5eRKQT+aRiDTAl5/3kaFtnFtPabVPosb3SnHIAzXUjItJOPkG/AqgysxlmVkomzJe0L2RmZwKjgZdzNi8DrjKz0WY2Grgq2tbnEtHC4GrRi4i01e2oG3dPmtltZAI6Bjzg7hvM7B5gpbtnQ38x8Ki7e86xB83sa2R+WQDc4+4H+/YUMpqTmaDXzVgRkba6DXoAd18KLG237e5277/axbEPAA/0sH55KyoyPnjuqcysHN7fHyUiclLJK+hPBiOHlHDvTfMHuhoiIicc9XOIiAROQS8iEjgFvYhI4BT0IiKBU9CLiAROQS8iEjgFvYhI4BT0IiKBs5wZC04IZlYLvNOLHzEO2N9H1TlZDMZzhsF53oPxnGFwnneh5zzN3Tud/veEC/reMrOV7r5goOtxPA3Gc4bBed6D8ZxhcJ53X56zum5ERAKnoBcRCVyIQX//QFdgAAzGc4bBed6D8ZxhcJ53n51zcH30IiLSVogtehERyaGgFxEJXDBBb2aLzGyzmW0zszsHuj79xcymmNnzZrbRzDaY2Z9E28eY2TNmtjX6Pnqg69rXzCxmZqvN7Ono/QwzeyW65v8erWkcFDMbZWaPm9kbZrbJzN4T+rU2sy9F/7bXm9lPzKw8xGttZg+Y2T4zW5+zrdNraxnfjs5/rZkVtMpSEEFvZjHgXuAaYDZwo5nNHtha9ZskcIe7zwYuAm6NzvVO4Fl3rwKejd6H5k+ATTnv/w/wT+5+OnAIuHlAatW/vgX8l7ufCU/7w1MAAAKxSURBVJxH5vyDvdZmNgn4Y2CBu59NZp3qxYR5rR8EFrXb1tW1vQaoir5uAe4r5IOCCHpgIbDN3be7ewJ4FLhugOvUL9x9t7u/Fr2uJ/M//iQy5/vDqNgPgesHpob9w8wmAx8Evhe9N+By4PGoSIjnPBK4DPg+gLsn3L2OwK81mSVOh5hZMTAU2E2A19rdXwQOttvc1bW9DnjIM5YDo8zs1Hw/K5SgnwTszHlfHW0LmplNB+YBrwDj3X13tGsPMH6AqtVf/hn4CyAdvR8L1Ll7Mnof4jWfAdQCP4i6rL5nZsMI+Fq7ew3w98AOMgF/GFhF+Nc6q6tr26uMCyXoBx0zGw48AXzR3d/N3eeZMbPBjJs1sw8B+9x91UDX5TgrBuYD97n7POAI7bppArzWo8m0XmcAE4FhdOzeGBT68tqGEvQ1wJSc95OjbUEysxIyIf9jd38y2rw3+6dc9H3fQNWvH1wMfMTM3ibTLXc5mb7rUdGf9xDmNa8Gqt39lej942SCP+RrfQXwlrvXunsz8CSZ6x/6tc7q6tr2KuNCCfoVQFV0Z76UzM2bJQNcp34R9U1/H9jk7v+Ys2sJ8Ono9aeBnx/vuvUXd/+yu0929+lkru1z7v57wPPAx6JiQZ0zgLvvAXaa2RnRpg8AGwn4WpPpsrnIzIZG/9az5xz0tc7R1bVdAnwqGn1zEXA4p4une+4exBdwLbAFeBP4ykDXpx/P8xIyf86tBV6Pvq4l02f9LLAV+CUwZqDr2k/n/37g6ej1TOBVYBvwU6BsoOvXD+c7F1gZXe+fAaNDv9bA3wBvAOuBh4GyEK818BMy9yGayfz1dnNX1xYwMiML3wTWkRmVlPdnaQoEEZHAhdJ1IyIiXVDQi4gETkEvIhI4Bb2ISOAU9CIigVPQi4gETkEvIhK4/wG/OKSaw1gzfwAAAABJRU5ErkJggg==\n"
+          },
+          "metadata": {
+            "needs_background": "light"
+          }
         }
-      ],
-      "source": [
-        "%%time\n",
-        "\n",
-        "batch_size = 32       # size of batch inner NN is trained with\n",
-        "toy_size = 0.0002       # total propeortion of training and test set we use\n",
-        "max_epochs = 100      # max number of epochs that is run if early stopping is not hit\n",
-        "early_stop_num = 10   # max number of worse validation scores before early stopping\n",
-        "iterations = 20       # total iterations, should be more than the number of policies\n",
-        "\n",
-        "# generate policies and sub-policies\n",
-        "num_policies = 10\n",
-        "num_sub_policies = 5\n",
-        "policies = generate_policies(num_policies, num_sub_policies)\n",
-        "\n",
-        "q_values = run_UCB1(policies, batch_size, toy_size, max_epochs, early_stop_num, iterations)\n",
-        "#print(q_values)"
       ]
     }
-  ],
-  "metadata": {
-    "colab": {
-      "collapsed_sections": [],
-      "name": "UCB1.ipynb",
-      "provenance": []
-    },
-    "kernelspec": {
-      "display_name": "Python 3",
-      "name": "python3"
-    },
-    "language_info": {
-      "codemirror_mode": {
-        "name": "ipython",
-        "version": 3
-      },
-      "file_extension": ".py",
-      "mimetype": "text/x-python",
-      "name": "python",
-      "nbconvert_exporter": "python",
-      "pygments_lexer": "ipython3",
-      "version": "3.8.8"
-    }
-  },
-  "nbformat": 4,
-  "nbformat_minor": 0
-}
+  ]
+}
\ No newline at end of file
-- 
GitLab