diff --git a/MetaAugment/UCB1_JC.ipynb b/MetaAugment/UCB1_JC.ipynb
index 5710a128c11db4392adb0d1372caa890c01e344e..2ebf29b8cbb7ed6c7cf1bbaffd9b55881910bfb4 100644
--- a/MetaAugment/UCB1_JC.ipynb
+++ b/MetaAugment/UCB1_JC.ipynb
@@ -1,485 +1,495 @@
 {
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "name": "UCB1.ipynb",
-      "provenance": [],
-      "collapsed_sections": []
-    },
-    "kernelspec": {
-      "name": "python3",
-      "display_name": "Python 3"
-    },
-    "language_info": {
-      "name": "python"
-    },
-    "accelerator": "GPU"
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "id": "U_ZJ2LqDiu_v"
+   },
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import torch\n",
+    "torch.manual_seed(0)\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import torch.optim as optim\n",
+    "import torch.utils.data as data_utils\n",
+    "import torchvision\n",
+    "import torchvision.datasets as datasets\n",
+    "\n",
+    "from matplotlib import pyplot as plt\n",
+    "from numpy import save, load\n",
+    "from tqdm import trange"
+   ]
   },
-  "cells": [
-    {
-      "cell_type": "code",
-      "source": [
-        "import numpy as np\n",
-        "import torch\n",
-        "torch.manual_seed(0)\n",
-        "import torch.nn as nn\n",
-        "import torch.nn.functional as F\n",
-        "import torch.optim as optim\n",
-        "import torch.utils.data as data_utils\n",
-        "import torchvision\n",
-        "import torchvision.datasets as datasets\n",
-        "\n",
-        "from matplotlib import pyplot as plt\n",
-        "from numpy import save, load\n",
-        "from tqdm import trange"
-      ],
-      "metadata": {
-        "id": "U_ZJ2LqDiu_v"
-      },
-      "execution_count": 1,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
-        "class LeNet(nn.Module):\n",
-        "    def __init__(self, img_height, img_width, num_labels, img_channels):\n",
-        "        super().__init__()\n",
-        "        self.conv1 = nn.Conv2d(img_channels, 6, 5)\n",
-        "        self.relu1 = nn.ReLU()\n",
-        "        self.pool1 = nn.MaxPool2d(2)\n",
-        "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
-        "        self.relu2 = nn.ReLU()\n",
-        "        self.pool2 = nn.MaxPool2d(2)\n",
-        "        self.fc1 = nn.Linear(int((((img_height-4)/2-4)/2)*(((img_width-4)/2-4)/2)*16), 120)\n",
-        "        self.relu3 = nn.ReLU()\n",
-        "        self.fc2 = nn.Linear(120, 84)\n",
-        "        self.relu4 = nn.ReLU()\n",
-        "        self.fc3 = nn.Linear(84, num_labels)\n",
-        "        self.relu5 = nn.ReLU()\n",
-        "\n",
-        "    def forward(self, x):\n",
-        "        y = self.conv1(x)\n",
-        "        y = self.relu1(y)\n",
-        "        y = self.pool1(y)\n",
-        "        y = self.conv2(y)\n",
-        "        y = self.relu2(y)\n",
-        "        y = self.pool2(y)\n",
-        "        y = y.view(y.shape[0], -1)\n",
-        "        y = self.fc1(y)\n",
-        "        y = self.relu3(y)\n",
-        "        y = self.fc2(y)\n",
-        "        y = self.relu4(y)\n",
-        "        y = self.fc3(y)\n",
-        "        y = self.relu5(y)\n",
-        "        return y"
-      ],
-      "metadata": {
-        "id": "4ksS_duLFADW"
-      },
-      "execution_count": 2,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
-        "class EasyNet(nn.Module):\n",
-        "    def __init__(self, img_height, img_width, num_labels, img_channels):\n",
-        "        super().__init__()\n",
-        "        self.fc1 = nn.Linear(img_height*img_width*img_channels, 2048)\n",
-        "        self.relu1 = nn.ReLU()\n",
-        "        self.fc2 = nn.Linear(2048, num_labels)\n",
-        "        self.relu2 = nn.ReLU()\n",
-        "\n",
-        "    def forward(self, x):\n",
-        "        y = x.view(x.shape[0], -1)\n",
-        "        y = self.fc1(y)\n",
-        "        y = self.relu1(y)\n",
-        "        y = self.fc2(y)\n",
-        "        y = self.relu2(y)\n",
-        "        return y"
-      ],
-      "metadata": {
-        "id": "LckxnUXGfxjW"
-      },
-      "execution_count": 3,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
-        "class SimpleNet(nn.Module):\n",
-        "    def __init__(self, img_height, img_width, num_labels, img_channels):\n",
-        "        super().__init__()\n",
-        "        self.fc1 = nn.Linear(img_height*img_width*img_channels, num_labels)\n",
-        "        self.relu1 = nn.ReLU()\n",
-        "\n",
-        "    def forward(self, x):\n",
-        "        y = x.view(x.shape[0], -1)\n",
-        "        y = self.fc1(y)\n",
-        "        y = self.relu1(y)\n",
-        "        return y"
-      ],
-      "metadata": {
-        "id": "enaD2xbw5hew"
-      },
-      "execution_count": 4,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "\"\"\"Make toy dataset\"\"\"\n",
-        "\n",
-        "def create_toy(train_dataset, test_dataset, batch_size, n_samples):\n",
-        "    \n",
-        "    # shuffle and take first n_samples %age of training dataset\n",
-        "    shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset))\n",
-        "    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)\n",
-        "    indices_train = torch.arange(int(n_samples*len(train_dataset)))\n",
-        "    reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)\n",
-        "\n",
-        "    # shuffle and take first n_samples %age of test dataset\n",
-        "    shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset))\n",
-        "    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)\n",
-        "    indices_test = torch.arange(int(n_samples*len(test_dataset)))\n",
-        "    reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)\n",
-        "\n",
-        "    # push into DataLoader\n",
-        "    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)\n",
-        "    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)\n",
-        "\n",
-        "    return train_loader, test_loader"
-      ],
-      "metadata": {
-        "id": "xujQtvVWBgMH"
-      },
-      "execution_count": 5,
-      "outputs": []
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "id": "4ksS_duLFADW"
+   },
+   "outputs": [],
+   "source": [
+    "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
+    "class LeNet(nn.Module):\n",
+    "    def __init__(self, img_height, img_width, num_labels, img_channels):\n",
+    "        super().__init__()\n",
+    "        self.conv1 = nn.Conv2d(img_channels, 6, 5)\n",
+    "        self.relu1 = nn.ReLU()\n",
+    "        self.pool1 = nn.MaxPool2d(2)\n",
+    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
+    "        self.relu2 = nn.ReLU()\n",
+    "        self.pool2 = nn.MaxPool2d(2)\n",
+    "        self.fc1 = nn.Linear(int((((img_height-4)/2-4)/2)*(((img_width-4)/2-4)/2)*16), 120)\n",
+    "        self.relu3 = nn.ReLU()\n",
+    "        self.fc2 = nn.Linear(120, 84)\n",
+    "        self.relu4 = nn.ReLU()\n",
+    "        self.fc3 = nn.Linear(84, num_labels)\n",
+    "        self.relu5 = nn.ReLU()\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        y = self.conv1(x)\n",
+    "        y = self.relu1(y)\n",
+    "        y = self.pool1(y)\n",
+    "        y = self.conv2(y)\n",
+    "        y = self.relu2(y)\n",
+    "        y = self.pool2(y)\n",
+    "        y = y.view(y.shape[0], -1)\n",
+    "        y = self.fc1(y)\n",
+    "        y = self.relu3(y)\n",
+    "        y = self.fc2(y)\n",
+    "        y = self.relu4(y)\n",
+    "        y = self.fc3(y)\n",
+    "        y = self.relu5(y)\n",
+    "        return y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "id": "LckxnUXGfxjW"
+   },
+   "outputs": [],
+   "source": [
+    "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
+    "class EasyNet(nn.Module):\n",
+    "    def __init__(self, img_height, img_width, num_labels, img_channels):\n",
+    "        super().__init__()\n",
+    "        self.fc1 = nn.Linear(img_height*img_width*img_channels, 2048)\n",
+    "        self.relu1 = nn.ReLU()\n",
+    "        self.fc2 = nn.Linear(2048, num_labels)\n",
+    "        self.relu2 = nn.ReLU()\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        y = x.view(x.shape[0], -1)\n",
+    "        y = self.fc1(y)\n",
+    "        y = self.relu1(y)\n",
+    "        y = self.fc2(y)\n",
+    "        y = self.relu2(y)\n",
+    "        return y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "id": "enaD2xbw5hew"
+   },
+   "outputs": [],
+   "source": [
+    "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
+    "class SimpleNet(nn.Module):\n",
+    "    def __init__(self, img_height, img_width, num_labels, img_channels):\n",
+    "        super().__init__()\n",
+    "        self.fc1 = nn.Linear(img_height*img_width*img_channels, num_labels)\n",
+    "        self.relu1 = nn.ReLU()\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        y = x.view(x.shape[0], -1)\n",
+    "        y = self.fc1(y)\n",
+    "        y = self.relu1(y)\n",
+    "        return y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "id": "xujQtvVWBgMH"
+   },
+   "outputs": [],
+   "source": [
+    "\"\"\"Make toy dataset\"\"\"\n",
+    "\n",
+    "def create_toy(train_dataset, test_dataset, batch_size, n_samples):\n",
+    "    \n",
+    "    # shuffle and take first n_samples %age of training dataset\n",
+    "    shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset))\n",
+    "    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)\n",
+    "    indices_train = torch.arange(int(n_samples*len(train_dataset)))\n",
+    "    reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)\n",
+    "\n",
+    "    # shuffle and take first n_samples %age of test dataset\n",
+    "    shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset))\n",
+    "    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)\n",
+    "    indices_test = torch.arange(int(n_samples*len(test_dataset)))\n",
+    "    reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)\n",
+    "\n",
+    "    # push into DataLoader\n",
+    "    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)\n",
+    "    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)\n",
+    "\n",
+    "    return train_loader, test_loader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "id": "Iql-c88jGGWy"
+   },
+   "outputs": [],
+   "source": [
+    "\"\"\"Randomly generate 10 policies\"\"\"\n",
+    "\"\"\"Each policy has 5 sub-policies\"\"\"\n",
+    "\"\"\"For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes\"\"\"\n",
+    "\n",
+    "def generate_policies(num_policies, num_sub_policies):\n",
+    "    \n",
+    "    policies = np.zeros([num_policies,num_sub_policies,6])\n",
+    "\n",
+    "    # Policies array will be 10x5x6\n",
+    "    for policy in range(num_policies):\n",
+    "        for sub_policy in range(num_sub_policies):\n",
+    "            # pick two sub_policy transformations (0=rotate, 1=shear, 2=scale)\n",
+    "            policies[policy, sub_policy, 0] = np.random.randint(0,3)\n",
+    "            policies[policy, sub_policy, 1] = np.random.randint(0,3)\n",
+    "            while policies[policy, sub_policy, 0] == policies[policy, sub_policy, 1]:\n",
+    "                policies[policy, sub_policy, 1] = np.random.randint(0,3)\n",
+    "\n",
+    "            # pick probabilities\n",
+    "            policies[policy, sub_policy, 2] = np.random.randint(0,11) / 10\n",
+    "            policies[policy, sub_policy, 3] = np.random.randint(0,11) / 10\n",
+    "\n",
+    "            # pick magnitudes\n",
+    "            for transformation in range(2):\n",
+    "                if policies[policy, sub_policy, transformation] <= 1:\n",
+    "                    policies[policy, sub_policy, transformation + 4] = np.random.randint(-4,5)*5\n",
+    "                elif policies[policy, sub_policy, transformation] == 2:\n",
+    "                    policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10\n",
+    "\n",
+    "    return policies"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {
+    "id": "QE2VWI8o731X"
+   },
+   "outputs": [],
+   "source": [
+    "\"\"\"Pick policy and sub-policy\"\"\"\n",
+    "\"\"\"Each row of data should have a different sub-policy but for now, this will do\"\"\"\n",
+    "\n",
+    "def sample_sub_policy(policies, policy, num_sub_policies):\n",
+    "    sub_policy = np.random.randint(0,num_sub_policies)\n",
+    "\n",
+    "    degrees = 0\n",
+    "    shear = 0\n",
+    "    scale = 1\n",
+    "\n",
+    "    # check for rotations\n",
+    "    if policies[policy, sub_policy][0] == 0:\n",
+    "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
+    "            degrees = policies[policy, sub_policy][4]\n",
+    "    elif policies[policy, sub_policy][1] == 0:\n",
+    "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
+    "            degrees = policies[policy, sub_policy][5]\n",
+    "\n",
+    "    # check for shears\n",
+    "    if policies[policy, sub_policy][0] == 1:\n",
+    "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
+    "            shear = policies[policy, sub_policy][4]\n",
+    "    elif policies[policy, sub_policy][1] == 1:\n",
+    "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
+    "            shear = policies[policy, sub_policy][5]\n",
+    "\n",
+    "    # check for scales\n",
+    "    if policies[policy, sub_policy][0] == 2:\n",
+    "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
+    "            scale = policies[policy, sub_policy][4]\n",
+    "    elif policies[policy, sub_policy][1] == 2:\n",
+    "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
+    "            scale = policies[policy, sub_policy][5]\n",
+    "\n",
+    "    return degrees, shear, scale"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "id": "vu_4I4qkbx73"
+   },
+   "outputs": [],
+   "source": [
+    "\"\"\"Sample policy, open and apply above transformations\"\"\"\n",
+    "def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet):\n",
+    "\n",
+    "    # get number of policies and sub-policies\n",
+    "    num_policies = len(policies)\n",
+    "    num_sub_policies = len(policies[0])\n",
+    "\n",
+    "    #Initialize vector weights, counts and regret\n",
+    "    q_values = [0]*num_policies\n",
+    "    cnts = [0]*num_policies\n",
+    "    q_plus_cnt = [0]*num_policies\n",
+    "    total_count = 0\n",
+    "\n",
+    "    best_q_values = []\n",
+    "\n",
+    "    for policy in trange(iterations):\n",
+    "\n",
+    "        # get the action to try (either initially in order or using best q_plus_cnt value)\n",
+    "        if policy >= num_policies:\n",
+    "            this_policy = np.argmax(q_plus_cnt)\n",
+    "        else:\n",
+    "            this_policy = policy\n",
+    "\n",
+    "        # get info of transformation for this sub-policy\n",
+    "        degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)\n",
+    "\n",
+    "        # create transformations using above info\n",
+    "        transform = torchvision.transforms.Compose(\n",
+    "            [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),\n",
+    "            torchvision.transforms.ToTensor()])\n",
+    "\n",
+    "        # open data and apply these transformations\n",
+    "        if ds == \"MNIST\":\n",
+    "            train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
+    "            test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
+    "        elif ds == \"KMNIST\":\n",
+    "            train_dataset = datasets.KMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
+    "            test_dataset = datasets.KMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
+    "        elif ds == \"FashionMNIST\":\n",
+    "            train_dataset = datasets.FashionMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
+    "            test_dataset = datasets.FashionMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
+    "        elif ds == \"CIFAR10\":\n",
+    "            train_dataset = datasets.CIFAR10(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
+    "            test_dataset = datasets.CIFAR10(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
+    "        elif ds == \"CIFAR100\":\n",
+    "            train_dataset = datasets.CIFAR100(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
+    "            test_dataset = datasets.CIFAR100(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
+    "\n",
+    "        # check sizes of images\n",
+    "        img_height = len(train_dataset[0][0][0])\n",
+    "        img_width = len(train_dataset[0][0][0][0])\n",
+    "        img_channels = len(train_dataset[0][0])\n",
+    "\n",
+    "        # check output labels\n",
+    "        if ds == \"CIFAR10\" or ds == \"CIFAR100\":\n",
+    "            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)\n",
+    "        else:\n",
+    "            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()\n",
+    "\n",
+    "        # create toy dataset from above uploaded data\n",
+    "        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n",
+    "\n",
+    "        # create model\n",
+    "        device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+    "        if IsLeNet == \"LeNet\":\n",
+    "            model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n",
+    "        elif IsLeNet == \"EasyNet\":\n",
+    "            model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n",
+    "        else:\n",
+    "            model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n",
+    "        sgd = optim.SGD(model.parameters(), lr=1e-1)\n",
+    "        cost = nn.CrossEntropyLoss()\n",
+    "\n",
+    "        # set variables for best validation accuracy and early stop count\n",
+    "        best_acc = 0\n",
+    "        early_stop_cnt = 0\n",
+    "        total_val = 0\n",
+    "\n",
+    "        # train model and check validation accuracy each epoch\n",
+    "        for _epoch in range(max_epochs):\n",
+    "\n",
+    "            # train model\n",
+    "            model.train()\n",
+    "            for idx, (train_x, train_label) in enumerate(train_loader):\n",
+    "                train_x, train_label = train_x.to(device), train_label.to(device) # new code\n",
+    "                label_np = np.zeros((train_label.shape[0], num_labels))\n",
+    "                sgd.zero_grad()\n",
+    "                predict_y = model(train_x.float())\n",
+    "                loss = cost(predict_y, train_label.long())\n",
+    "                loss.backward()\n",
+    "                sgd.step()\n",
+    "\n",
+    "            # check validation accuracy on validation set\n",
+    "            correct = 0\n",
+    "            _sum = 0\n",
+    "            model.eval()\n",
+    "            for idx, (test_x, test_label) in enumerate(test_loader):\n",
+    "                test_x, test_label = test_x.to(device), test_label.to(device) # new code\n",
+    "                predict_y = model(test_x.float()).detach()\n",
+    "                #predict_ys = np.argmax(predict_y, axis=-1)\n",
+    "                predict_ys = torch.argmax(predict_y, axis=-1) # changed np to torch\n",
+    "                #label_np = test_label.numpy()\n",
+    "                _ = predict_ys == test_label\n",
+    "                #correct += np.sum(_.numpy(), axis=-1)\n",
+    "                correct += np.sum(_.cpu().numpy(), axis=-1) # added .cpu()\n",
+    "                _sum += _.shape[0]\n",
+    "            \n",
+    "            acc = correct / _sum\n",
+    "\n",
+    "            if average_validation[0] <= _epoch <= average_validation[1]:\n",
+    "                total_val += acc\n",
+    "\n",
+    "            # update best validation accuracy if it was higher, otherwise increase early stop count\n",
+    "            if acc > best_acc :\n",
+    "                best_acc = acc\n",
+    "                early_stop_cnt = 0\n",
+    "            else:\n",
+    "                early_stop_cnt += 1\n",
+    "\n",
+    "            # exit if validation gets worse over 10 runs and using early stopping\n",
+    "            if early_stop_cnt >= early_stop_num and early_stop_flag:\n",
+    "                break\n",
+    "\n",
+    "            # exit if using fixed epoch length\n",
+    "            if _epoch >= average_validation[1] and not early_stop_flag:\n",
+    "                best_acc = total_val / (average_validation[1] - average_validation[0] + 1)\n",
+    "                break\n",
+    "\n",
+    "        # update q_values\n",
+    "        if policy < num_policies:\n",
+    "            q_values[this_policy] += best_acc\n",
+    "        else:\n",
+    "            q_values[this_policy] = (q_values[this_policy]*cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)\n",
+    "\n",
+    "        best_q_value = max(q_values)\n",
+    "        best_q_values.append(best_q_value)\n",
+    "\n",
+    "        if (policy+1) % 10 == 0:\n",
+    "            print(\"Iteration: {},\\tQ-Values: {}, Best Policy: {}\".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))\n",
+    "\n",
+    "        # update counts\n",
+    "        cnts[this_policy] += 1\n",
+    "        total_count += 1\n",
+    "\n",
+    "        # update q_plus_cnt values every turn after the initial sweep through\n",
+    "        if policy >= num_policies - 1:\n",
+    "            for i in range(num_policies):\n",
+    "                q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])\n",
+    "\n",
+    "    return q_values, best_q_values"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/",
+     "height": 342
     },
+    "id": "doHUtJ_tEiA6",
+    "outputId": "8195ba17-c95f-4b75-d8dc-19c5d76d5e43"
+   },
+   "outputs": [
     {
-      "cell_type": "code",
-      "source": [
-        "\"\"\"Randomly generate 10 policies\"\"\"\n",
-        "\"\"\"Each policy has 5 sub-policies\"\"\"\n",
-        "\"\"\"For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes\"\"\"\n",
-        "\n",
-        "def generate_policies(num_policies, num_sub_policies):\n",
-        "    \n",
-        "    policies = np.zeros([num_policies,num_sub_policies,6])\n",
-        "\n",
-        "    # Policies array will be 10x5x6\n",
-        "    for policy in range(num_policies):\n",
-        "        for sub_policy in range(num_sub_policies):\n",
-        "            # pick two sub_policy transformations (0=rotate, 1=shear, 2=scale)\n",
-        "            policies[policy, sub_policy, 0] = np.random.randint(0,3)\n",
-        "            policies[policy, sub_policy, 1] = np.random.randint(0,3)\n",
-        "            while policies[policy, sub_policy, 0] == policies[policy, sub_policy, 1]:\n",
-        "                policies[policy, sub_policy, 1] = np.random.randint(0,3)\n",
-        "\n",
-        "            # pick probabilities\n",
-        "            policies[policy, sub_policy, 2] = np.random.randint(0,11) / 10\n",
-        "            policies[policy, sub_policy, 3] = np.random.randint(0,11) / 10\n",
-        "\n",
-        "            # pick magnitudes\n",
-        "            for transformation in range(2):\n",
-        "                if policies[policy, sub_policy, transformation] <= 1:\n",
-        "                    policies[policy, sub_policy, transformation + 4] = np.random.randint(-4,5)*5\n",
-        "                elif policies[policy, sub_policy, transformation] == 2:\n",
-        "                    policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10\n",
-        "\n",
-        "    return policies"
-      ],
-      "metadata": {
-        "id": "Iql-c88jGGWy"
-      },
-      "execution_count": 6,
-      "outputs": []
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 10/10 [01:28<00:00,  8.84s/it]"
+     ]
     },
     {
-      "cell_type": "code",
-      "source": [
-        "\"\"\"Pick policy and sub-policy\"\"\"\n",
-        "\"\"\"Each row of data should have a different sub-policy but for now, this will do\"\"\"\n",
-        "\n",
-        "def sample_sub_policy(policies, policy, num_sub_policies):\n",
-        "    sub_policy = np.random.randint(0,num_sub_policies)\n",
-        "\n",
-        "    degrees = 0\n",
-        "    shear = 0\n",
-        "    scale = 1\n",
-        "\n",
-        "    # check for rotations\n",
-        "    if policies[policy, sub_policy][0] == 0:\n",
-        "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
-        "            degrees = policies[policy, sub_policy][4]\n",
-        "    elif policies[policy, sub_policy][1] == 0:\n",
-        "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
-        "            degrees = policies[policy, sub_policy][5]\n",
-        "\n",
-        "    # check for shears\n",
-        "    if policies[policy, sub_policy][0] == 1:\n",
-        "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
-        "            shear = policies[policy, sub_policy][4]\n",
-        "    elif policies[policy, sub_policy][1] == 1:\n",
-        "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
-        "            shear = policies[policy, sub_policy][5]\n",
-        "\n",
-        "    # check for scales\n",
-        "    if policies[policy, sub_policy][0] == 2:\n",
-        "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
-        "            scale = policies[policy, sub_policy][4]\n",
-        "    elif policies[policy, sub_policy][1] == 2:\n",
-        "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
-        "            scale = policies[policy, sub_policy][5]\n",
-        "\n",
-        "    return degrees, shear, scale"
-      ],
-      "metadata": {
-        "id": "QE2VWI8o731X"
-      },
-      "execution_count": 7,
-      "outputs": []
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Iteration: 10,\tQ-Values: [0.77, 0.74, 0.8, 0.72, 0.77], Best Policy: 0.8\n",
+      "CPU times: user 1min 21s, sys: 694 ms, total: 1min 22s\n",
+      "Wall time: 1min 28s\n"
+     ]
     },
     {
-      "cell_type": "code",
-      "execution_count": 8,
-      "metadata": {
-        "id": "vu_4I4qkbx73"
-      },
-      "outputs": [],
-      "source": [
-        "\"\"\"Sample policy, open and apply above transformations\"\"\"\n",
-        "def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet):\n",
-        "\n",
-        "    # get number of policies and sub-policies\n",
-        "    num_policies = len(policies)\n",
-        "    num_sub_policies = len(policies[0])\n",
-        "\n",
-        "    #Initialize vector weights, counts and regret\n",
-        "    q_values = [0]*num_policies\n",
-        "    cnts = [0]*num_policies\n",
-        "    q_plus_cnt = [0]*num_policies\n",
-        "    total_count = 0\n",
-        "\n",
-        "    best_q_values = []\n",
-        "\n",
-        "    for policy in trange(iterations):\n",
-        "\n",
-        "        # get the action to try (either initially in order or using best q_plus_cnt value)\n",
-        "        if policy >= num_policies:\n",
-        "            this_policy = np.argmax(q_plus_cnt)\n",
-        "        else:\n",
-        "            this_policy = policy\n",
-        "\n",
-        "        # get info of transformation for this sub-policy\n",
-        "        degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)\n",
-        "\n",
-        "        # create transformations using above info\n",
-        "        transform = torchvision.transforms.Compose(\n",
-        "            [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),\n",
-        "            torchvision.transforms.ToTensor()])\n",
-        "\n",
-        "        # open data and apply these transformations\n",
-        "        if ds == \"MNIST\":\n",
-        "            train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
-        "            test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
-        "        elif ds == \"KMNIST\":\n",
-        "            train_dataset = datasets.KMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
-        "            test_dataset = datasets.KMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
-        "        elif ds == \"FashionMNIST\":\n",
-        "            train_dataset = datasets.FashionMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
-        "            test_dataset = datasets.FashionMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
-        "        elif ds == \"CIFAR10\":\n",
-        "            train_dataset = datasets.CIFAR10(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
-        "            test_dataset = datasets.CIFAR10(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
-        "        elif ds == \"CIFAR100\":\n",
-        "            train_dataset = datasets.CIFAR100(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
-        "            test_dataset = datasets.CIFAR100(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
-        "\n",
-        "        # check sizes of images\n",
-        "        img_height = len(train_dataset[0][0][0])\n",
-        "        img_width = len(train_dataset[0][0][0][0])\n",
-        "        img_channels = len(train_dataset[0][0])\n",
-        "\n",
-        "        # check output labels\n",
-        "        if ds == \"CIFAR10\" or ds == \"CIFAR100\":\n",
-        "            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)\n",
-        "        else:\n",
-        "            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()\n",
-        "\n",
-        "        # create toy dataset from above uploaded data\n",
-        "        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n",
-        "\n",
-        "        # create model\n",
-        "        device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
-        "        if IsLeNet == \"LeNet\":\n",
-        "            model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n",
-        "        elif IsLeNet == \"EasyNet\":\n",
-        "            model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n",
-        "        else:\n",
-        "            model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n",
-        "        sgd = optim.SGD(model.parameters(), lr=1e-1)\n",
-        "        cost = nn.CrossEntropyLoss()\n",
-        "\n",
-        "        # set variables for best validation accuracy and early stop count\n",
-        "        best_acc = 0\n",
-        "        early_stop_cnt = 0\n",
-        "        total_val = 0\n",
-        "\n",
-        "        # train model and check validation accuracy each epoch\n",
-        "        for _epoch in range(max_epochs):\n",
-        "\n",
-        "            # train model\n",
-        "            model.train()\n",
-        "            for idx, (train_x, train_label) in enumerate(train_loader):\n",
-        "                train_x, train_label = train_x.to(device), train_label.to(device) # new code\n",
-        "                label_np = np.zeros((train_label.shape[0], num_labels))\n",
-        "                sgd.zero_grad()\n",
-        "                predict_y = model(train_x.float())\n",
-        "                loss = cost(predict_y, train_label.long())\n",
-        "                loss.backward()\n",
-        "                sgd.step()\n",
-        "\n",
-        "            # check validation accuracy on validation set\n",
-        "            correct = 0\n",
-        "            _sum = 0\n",
-        "            model.eval()\n",
-        "            for idx, (test_x, test_label) in enumerate(test_loader):\n",
-        "                test_x, test_label = test_x.to(device), test_label.to(device) # new code\n",
-        "                predict_y = model(test_x.float()).detach()\n",
-        "                #predict_ys = np.argmax(predict_y, axis=-1)\n",
-        "                predict_ys = torch.argmax(predict_y, axis=-1) # changed np to torch\n",
-        "                #label_np = test_label.numpy()\n",
-        "                _ = predict_ys == test_label\n",
-        "                #correct += np.sum(_.numpy(), axis=-1)\n",
-        "                correct += np.sum(_.cpu().numpy(), axis=-1) # added .cpu()\n",
-        "                _sum += _.shape[0]\n",
-        "            \n",
-        "            acc = correct / _sum\n",
-        "\n",
-        "            if average_validation[0] <= _epoch <= average_validation[1]:\n",
-        "                total_val += acc\n",
-        "\n",
-        "            # update best validation accuracy if it was higher, otherwise increase early stop count\n",
-        "            if acc > best_acc :\n",
-        "                best_acc = acc\n",
-        "                early_stop_cnt = 0\n",
-        "            else:\n",
-        "                early_stop_cnt += 1\n",
-        "\n",
-        "            # exit if validation gets worse over 10 runs and using early stopping\n",
-        "            if early_stop_cnt >= early_stop_num and early_stop_flag:\n",
-        "                break\n",
-        "\n",
-        "            # exit if using fixed epoch length\n",
-        "            if _epoch >= average_validation[1] and not early_stop_flag:\n",
-        "                best_acc = total_val / (average_validation[1] - average_validation[0] + 1)\n",
-        "                break\n",
-        "\n",
-        "        # update q_values\n",
-        "        if policy < num_policies:\n",
-        "            q_values[this_policy] += best_acc\n",
-        "        else:\n",
-        "            q_values[this_policy] = (q_values[this_policy]*cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)\n",
-        "\n",
-        "        best_q_value = max(q_values)\n",
-        "        best_q_values.append(best_q_value)\n",
-        "\n",
-        "        if (policy+1) % 10 == 0:\n",
-        "            print(\"Iteration: {},\\tQ-Values: {}, Best Policy: {}\".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))\n",
-        "\n",
-        "        # update counts\n",
-        "        cnts[this_policy] += 1\n",
-        "        total_count += 1\n",
-        "\n",
-        "        # update q_plus_cnt values every turn after the initial sweep through\n",
-        "        if policy >= num_policies - 1:\n",
-        "            for i in range(num_policies):\n",
-        "                q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])\n",
-        "\n",
-        "    return q_values, best_q_values"
-      ]
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
     },
     {
-      "cell_type": "code",
-      "source": [
-        "%%time\n",
-        "\n",
-        "batch_size = 32               # size of batch the inner NN is trained with\n",
-        "learning_rate = 1e-1          # fix learning rate\n",
-        "ds = \"MNIST\"                  # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)\n",
-        "toy_size = 0.02               # total propeortion of training and test set we use\n",
-        "max_epochs = 100              # max number of epochs that is run if early stopping is not hit\n",
-        "early_stop_num = 10           # max number of worse validation scores before early stopping is triggered\n",
-        "early_stop_flag = True        # implement early stopping or not\n",
-        "average_validation = [15,25]  # if not implementing early stopping, what epochs are we averaging over\n",
-        "num_policies = 5              # fix number of policies\n",
-        "num_sub_policies = 5          # fix number of sub-policies in a policy\n",
-        "iterations = 100              # total iterations, should be more than the number of policies\n",
-        "IsLeNet = \"SimpleNet\"         # using LeNet or EasyNet or SimpleNet\n",
-        "\n",
-        "# generate random policies at start\n",
-        "policies = generate_policies(num_policies, num_sub_policies)\n",
-        "\n",
-        "q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet)\n",
-        "\n",
-        "plt.plot(best_q_values)\n",
-        "\n",
-        "best_q_values = np.array(best_q_values)\n",
-        "save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)\n",
-        "#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)"
-      ],
-      "metadata": {
-        "colab": {
-          "base_uri": "https://localhost:8080/",
-          "height": 342
-        },
-        "id": "doHUtJ_tEiA6",
-        "outputId": "8195ba17-c95f-4b75-d8dc-19c5d76d5e43"
-      },
-      "execution_count": 9,
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stderr",
-          "text": [
-            "100%|██████████| 10/10 [01:28<00:00,  8.84s/it]"
-          ]
-        },
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "Iteration: 10,\tQ-Values: [0.77, 0.74, 0.8, 0.72, 0.77], Best Policy: 0.8\n",
-            "CPU times: user 1min 21s, sys: 694 ms, total: 1min 22s\n",
-            "Wall time: 1min 28s\n"
-          ]
-        },
-        {
-          "output_type": "stream",
-          "name": "stderr",
-          "text": [
-            "\n"
-          ]
-        },
-        {
-          "output_type": "display_data",
-          "data": {
-            "text/plain": [
-              "<Figure size 432x288 with 1 Axes>"
-            ],
-            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD7CAYAAABkO19ZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAfGklEQVR4nO3da3BU95nn8e8jtS4gQAJJGFDLRlzGgM3FoCGOwcnYjis2ztpOYjJQm93KVCqeF+PsbDapLc/UrCvl2rzI1m4ymawzVa7Z3VRlZu0F4mScMTGuGXtm1LZjmzsGDGrABombWiAhEJKQ9OyLbtmNEKaBbp3u079PlYru0+d0P93AT6ef8z//Y+6OiIiEV0nQBYiISG4p6EVEQk5BLyIScgp6EZGQU9CLiIScgl5EJOQyCnoze8jMDphZ3MyeHuPxW83sDTPbYWa7zWxNavmDZrbNzPak/rw/229AREQ+nV1rHL2ZlQIHgQeBNuA9YL2770tb53lgh7v/tZktAja7+2wzuws45e7HzexOYIu7N+TqzYiIyJUiGayzEoi7+2EAM3sReAzYl7aOA1NSt6uB4wDuviNtnb3ABDOrcPf+q71YXV2dz549O+M3ICIisG3btoS714/1WCZB3wAcS7vfBnxm1DrfB14zs28DVcAXxnierwLbxwp5M3sSeBLg1ltvZevWrRmUJSIiI8zso6s9lq2DseuBn7t7FFgD/MLMPn5uM7sD+CHwx2Nt7O7Pu3uzuzfX14/5C0lERG5QJkHfDjSm3Y+mlqX7JrABwN3fBiqBOgAziwK/Av69ux+62YJFROT6ZBL07wHzzazJzMqBdcDLo9Y5CjwAYGYLSQZ9h5nVAK8AT7v7m9krW0REMnXNoHf3QeApYAuwH9jg7nvN7FkzezS12neBb5nZLuAF4BueHM7zFDAPeMbMdqZ+pufknYiIyJiuObxyvDU3N7sOxoqIXB8z2+buzWM9pjNjRURCTkEvIhJymYyjlwy4OyfP9bHrWBcHTp5naHg46JJklPoplSyL1nD7jMmUR7SPI8VDQX+Dui9eYk9bN7vauth5rItdx7o43fPJuWBmARYnV0g/FFUeKeGOWVNYGq1hWWMNSxtrmF07EdNfmoSUgj4D/YND7D/Rw65UoO9s6+Jwx4WPH59TV8WqeXUsjVaztLGGhTOnUFlWGmDFMpq703b2IrvaulJ/j938v/eO8fO3PgSgekIZS6LVyeCP1rCksZrpkyuDLVokSzTqZpThYedw4jw7j3WzOxUK+06c49JQ8nOqn1yRCoNkqC9pqKF6Yllg9cqNGxwapvX0+WTwtyXD/8CpHoaGk3/XDTUTWNpYnQz+aA2Lo9VMqtC+keSnTxt1U/RBf7K7L9l6SYX6nrZuevoHAagqL2VJNPnVflljMthnTKnUV/wQuzgwxN7j3al/E93sOtbF0TO9QLIdN3/6JJZ+/G8i2e8vK1W/X4KnoE8515fsq4/01He1dXHqXLKvXlZqLJgx5eM9uGWNNcypn0RpiUK92J25MJDW8kn+AjhzYQCAipF+fyr4l0ZruE39fglAUQZ9/+AQH5zouexg6aFRffWlaS0Y9dUlUyP9/p3HulLtvW72tHdz8dIQkOz3L22sYVm0+uNvhPWTKwKuWsKuKIL+Qv8gW/aeTB0s7Wb/8XMMDCWHONZNSvbVR9ov6qtLto3u9+881s2Bk+dItfsv6/f/27tvU69fsu7Tgj40/9oGBof5Txt2UVVeyuJoNX+0ejbLUntTM6vVV5fcipSWsHDmFBbOnMK6lbcC0DswyN7j55I7H6lfAJv3nOTY2V7+6+OLA65Yiklo9ugBDnWcZ3Ztlfrqkree+r/beftQJ+/8+QNEdBBXsqho5rqZq4OnkuceWTyTzgsDvHvkTNClSBEJVdCL5Ls/uH06E8pKeWXPiaBLkSKioBcZRxPKS3lg4XReff8kg0OaD0nGh4JeZJypfSPjTUEvMs7UvpHxpqAXGWdq38h4U9CLBEDtGxlPCnqRAKh9I+NJQS8SgAnlpdy/cDpb9qp9I7mnoBcJyJcWzyRxXu0byT0FvUhA1L6R8aKgFwmI2jcyXhT0IgFS+0bGg4JeJEBq38h4UNCLBEjtGxkPCnqRgD2i9o3kmIJeJGD3qX0jOaagFwlYevtmaDi/rvgm4ZBR0JvZQ2Z2wMziZvb0GI/famZvmNkOM9ttZmtSy2tTy8+b2f/MdvEiYTHSvnnnSGfQpUgIXTPozawUeA54GFgErDezRaNW+wtgg7vfBawDfpZa3gf8F+B7WatYJIQ+bt/sVvtGsi+TPfqVQNzdD7v7APAi8NiodRyYkrpdDRwHcPcL7h4jGfgichVq30guZRL0DcCxtPttqWXpvg983czagM3At7NSnUgRUftGciVbB2PXAz939yiwBviFmWX83Gb2pJltNbOtHR0dWSpJpLCofSO5kkkYtwONafejqWXpvglsAHD3t4FKoC7TItz9eXdvdvfm+vr6TDcTCRW1byRXMgn694D5ZtZkZuUkD7a+PGqdo8ADAGa2kGTQa9dc5DqpfSO5cM2gd/dB4ClgC7Cf5OiavWb2rJk9mlrtu8C3zGwX8ALwDXd3ADP7EPgR8A0zaxtjxI6IpKh9I7kQyWQld99M8iBr+rJn0m7vA1ZdZdvZN1GfSFFJb988+9idlJZY0CVJCOjMWJE8o/aNZJuCXiTPjLRvNmvuG8kSBb1Inhlp37z6vkbfSHYo6EXykNo3kk0KepE8pPaNZJOCXiQPTSgv5f4Fat9IdijoRfLUI0vUvpHsUNCL5Cm1byRbFPQieUrtG8kWBb1IHlP7RrJBQS+Sx9S+kWxQ0IvkMbVvJBsU9CJ5bo1OnpKbpKAXyXP3LahX+0ZuioJeJM9NLI+ofSM3RUEvUgDUvpGboaAXKQBq38jNUNCLFIBP2jen1L6R66agFykQyfZNP+8eORN0KVJgFPQiBeK+BfVUlpXwyp7jQZciBUZBL1IgJpZHeGDBLWrfyHVT0IsUELVv5EYo6EUKiNo3ciMU9CIFRO0buREKepECo/aNXC8FvUiBUftGrpeCXqTA6OQpuV4KepEC9MjiWWrfSMYU9CIFaKR9o7lvJBMKepECNNK++a2mLpYMKOhFCpTaN5KpjILezB4yswNmFjezp8d4/FYze8PMdpjZbjNbk/bYn6W2O2BmX8xm8SLFTO0bydQ1g97MSoHngIeBRcB6M1s0arW/ADa4+13AOuBnqW0Xpe7fATwE/Cz1fCJyk9S+kUxlske/Eoi7+2F3HwBeBB4btY4DU1K3q4GRAb6PAS+6e7+7HwHiqecTkSxQ+0YykUnQNwDH0u63pZal+z7wdTNrAzYD376ObTGzJ81sq5lt7ejoyLB0EVH7RjKRrYOx64Gfu3sUWAP8wswyfm53f97dm929ub6+PksliYSf2jfh8c8HTvOvB3Ozo5tJGLcDjWn3o6ll6b4JbABw97eBSqAuw21F5CZo7ptw+Mt/bOUn
/9Sak+fOJOjfA+abWZOZlZM8uPryqHWOAg8AmNlCkkHfkVpvnZlVmFkTMB94N1vFiwjcv2C62jcFrrv3Ervbulg1ry4nz3/NoHf3QeApYAuwn+Tomr1m9qyZPZpa7bvAt8xsF/AC8A1P2ktyT38f8CrwJ+4+lIs3IlKs1L4pfG8fTjDscO/83AR9JJOV3H0zyYOs6cueSbu9D1h1lW1/APzgJmoUkWtYs3gmm/ec5N0jZ/js3Nqgy5Hr1NKaYFJFhGWNNTl5fp0ZKxICat8Utlg8wd1zplFWmptIVtCLhIDaN4Xr2JlePursZXWO+vOgoBcJDY2+KUwtrQkAVs/P3dByBb1ISKh9U5hi8Q5mVlcyt74qZ6+hoBcJCbVvCs/QsPNmvJPV8+ows5y9joJeJERG2jfvfaj2TSF4v72b7ouXWJ2jYZUjFPQiITLSvnllt9o3hSAWT/bnc3Wi1AgFvUiIqH1TWFpaO1g0cwp1kypy+joKepGQUfumMPQODLLto7M5Oxs2nYJeJGTUvikM7xw5w6Uhz3l/HhT0IqEzsTzCfberfZPvYq0JyiMl/P7saTl/LQW9SAg9skTtm3wXa02wcvY0Kstyf3VVBb1ICKl9k99On+vjwKmecWnbgIJeJJTUvslvI8Mqczm/TToFvUhIqX2Tv2KtCWqrylk0c8q4vJ6CXiSkNPdNfnJ3YvEE98yro6Qkd9MepFPQi4TUSPtm8x61b/LJwVPnOd3Tz73j1LYBBb1IqOnkqfzT0toBMG4HYkFBLxJqat/knzfjCebUVzGrZsK4vaaCXiTEqirUvsknA4PDvHPkzLiNthmhoBcJObVv8sf2o2fpHRhS0ItIdql9kz9irQlKS4y759aO6+sq6EVCTu2b/NEST7CssYYplWXj+roKepEioPZN8Lp7L7GnrWvc2zagoBcpCmrfBO+tQwmGnXGZf340Bb1IERhp32jum+C0xBNMqoiwtLFm3F9bQS9SJNYsnklHj9o3QYm1Jrh7Ti1lpeMfuwp6kSJx/4LpVETUvgnC0c5ejp7pDaRtAwp6kaJRVaELhwelJT7+0x6kU9CLFBG1b4IRa00wq7qSOXVVgbx+RkFvZg+Z2QEzi5vZ02M8/mMz25n6OWhmXWmP/dDM3k/9/GE2ixeR66P2zfgbGnbeOtTJ6vl1mI3PtMSjXTPozawUeA54GFgErDezRenruPt33H2Zuy8Dfgq8lNr2EWA5sAz4DPA9MxufmfZF5Apq34y/Pe3ddF+8xOr59YHVkMke/Uog7u6H3X0AeBF47FPWXw+8kLq9CPhXdx909wvAbuChmylYRG7OSPvm3SNq34yHWGpa4lXjPO1BukyCvgE4lna/LbXsCmZ2G9AEvJ5atAt4yMwmmlkdcB/QOMZ2T5rZVjPb2tHRcT31i8h1emDhdCZVRHhpe1vQpRSFltYEd8yaQu2kisBqyPbB2HXAJncfAnD314DNwFsk9/LfBoZGb+Tuz7t7s7s319cH9/VGpBhMLI/wyOKZvLLnBBf6B4MuJ9Qu9A+y/ejZwEbbjMgk6Nu5fC88mlo2lnV80rYBwN1/kOrfPwgYcPBGChWR7FnbHKV3YIhXdFA2p949coZLQ86984Ldgc0k6N8D5ptZk5mVkwzzl0evZGYLgKkk99pHlpWaWW3q9hJgCfBaNgoXkRu34rapzKmrYtNWtW9yqaU1QUWkhObZUwOt45pB7+6DwFPAFmA/sMHd95rZs2b2aNqq64AX3T39UH4Z0GJm+4Dnga+nnk9EAmRmfHVFlHc/PMOHiQtBlxNasXgHK5umUVlWGmgdkUxWcvfNJHvt6cueGXX/+2Ns10dy5I2I5JmvLo/yP147wKZtbXzvi7cHXU7onDrXx8FT5/nq8mjQpejMWJFiNaO6ks/9Xj2btrVpTH0OxFoTQHDTHqRT0IsUsbUrGjl5ro9YPBF0KaHzZjxBbVU5C2cEf46ogl6kiH1h0XRqJpaxceuxa68sGXN3YvEE98yro6QkmGkP0inoRYpYRaSUx5c18Nq+U3T1DgRdTmgcPHWe0z393BvAZQPHoqAXKXJPrIgyMDjMy7uOB11KaLS0Bjst8WgKepEid2dDNQtnTmGjxtRnTSyeYE59FbNqJgRdCqCgFxFg7Yooe9q7+eDkuaBLKXj9g0O8c/hM3rRtQEEvIsDjdzVQVmraq8+C7R91cfHSUKDTEo+moBcRplWV84WFt/DrHe0MDA4HXU5Bi8U7KC0x7p4zLehSPqagFxEgOdFZ54UBXv/gdNClFLRYa4K7GmuYXFkWdCkfU9CLCACfm1/P9MkVbNqmMfU3qqt3gN3t3Xkz2maEgl5EAIiUlvDl5Q28caCD0z19QZdTkN461Ik73KugF5F8tXZFI0PDzq93XO2SE/JpWloTTK6IsDRaE3Qpl1HQi8jH5k2fxPJba9iwtY3LZxyXTMTiHdw9t5ZIaX5Fa35VIyKBW9vcSPz0eXYe6wq6lILyUecFjp25mHdtG1DQi8goX1oyk8qyEjZu05j669EyMi1xHp0oNUJBLyKXmVxZxpo7Z/Kbnce5ODAUdDkFI9aaoKFmAk11VUGXcgUFvYhc4YnmKD39g2zZezLoUgrC0LDz1qEEq+fVYRb8tMSjKehF5Ap3N9USnTqBjRpTn5HdbV2c6xvMu/HzIxT0InKFkhLjiRVR3jrUSdvZ3qDLyXux1gRmsCoP+/OgoBeRq3hiRfKi1r/cpjH119IST3DHrClMqyoPupQxKehFZEzRqRO5Z24tm7YfY1gXD7+qC/2D7Dh6ltXz8me2ytEU9CJyVWtXNHLszEV+d6Qz6FLy1jtHOrk05Hk5fn6Egl5EruqLd8xgckWETZqn/qpirZ1UREpYcdvUoEu5KgW9iFzVhPJSvrR0FpvfP0FP36Wgy8lLsXgHK5umUVlWGnQpV6WgF5FP9bXmKH2Xhnll94mgS8k7p871cfDU+bw8Gzadgl5EPtWyxhrmTZ/Ehq0aUz9abGTagzzuz4OCXkSuwcxYuyLK9qNdxE+fD7qcvBKLJ6itKmfhjClBl/KpFPQick1fXt5AaYmxSROdfczdicUTrJpXR0lJ/k17kE5BLyLXNH1yJffdXs9L29sYHNLFwwEOnOqho6c/79s2kGHQm9lDZnbAzOJm9vQYj//YzHamfg6aWVfaY//NzPaa2X4z+yvLxxl/ROSanljRyOme/o+n4y12I/35fB4/P+KaQW9mpcBzwMPAImC9mS1KX8fdv+Puy9x9GfBT4KXUtvcAq4AlwJ3A7wOfz+o7EJFxcf+C6UyrKtdB2ZSW1gRz66uYWT0h6FKuKZM9+pVA3N0Pu/sA8CLw2Kesvx54IXXbgUqgHKgAyoBTN16uiASlPFLC48sa+Mf9pzhzYSDocgLVPzjEO0c6uXd+/k57kC6ToG8A0n+Ft6WWXcHMbgOagNcB3P1t4A3gROpni7vvH2O7J81sq5lt7ejouL53ICLjZm1zlEtDzt/vLO6JzrZ9dJa+S8N5P35+RLYPxq4DNrn7EICZzQMWAlGSvxzuN7N7R2/k7s+
7e7O7N9fXF8ZvSJFitHDmFBY3VLOxyKdEiLUmiJQYd8+tDbqUjGQS9O1AY9r9aGrZWNbxSdsG4MvA79z9vLufB34LfPZGChWR/LC2Ocq+E+d4v7076FICE4snuOvWGiZVRIIuJSOZBP17wHwzazKzcpJh/vLolcxsATAVeDtt8VHg82YWMbMykgdir2jdiEjheHTpLMpLS4p2TP3ZCwPsae/O62mJR7tm0Lv7IPAUsIVkSG9w971m9qyZPZq26jrgRXdPn7h6E3AI2APsAna5+2+yVr2IjLuaieU8eMct/HpnO/2DxXfx8LcOdeKe/9MepMvoe4e7bwY2j1r2zKj73x9juyHgj2+iPhHJQ19rbuSV3Sf4p/2nWbN4ZtDljKtYvIPJlRGWRquDLiVjOjNWRK7b6nl1zKyuZGORjal3d1paE3x2Ti2R0sKJz8KpVETyRmmJ8ZXlDfzLwQ5OdvcFXc64+aizl7azFwvibNh0CnoRuSFPrGhk2OGlHcVzULYlPjItceEciAUFvYjcoKa6KlbOnsamrW1cPgYjvGKtHTTUTGB27cSgS7kuCnoRuWFPNEc5nLjA9qNngy4l5waHhnnrUCf3zq+j0OZmVNCLyA17ZPFMJpaXsuG98Ldvdrd309M3WFDDKkco6EXkhlVVRFizeCb/sPs4vQODQZeTU7HWBGZwz1wFvYgUmbUrolwYGOK3e04GXUpOxeIJ7pg1hWlV5UGXct0U9CJyU1Y2TWN27UQ2bgvvmPoL/YPsOHq2oKY9SKegF5GbYmY8sSLK7w6f4Whnb9Dl5MQ7Rzq5NOQFN35+hIJeRG7aV5ZHMYNNId2rb2lNUBEpYcVtU4Mu5YYo6EXkps2qmcDqeXX8cns7w8PhG1Mfa02wsmkalWWlQZdyQxT0IpIVX2tupL3rIm8d6gy6lKw62d1H6+nzBdu2AQW9iGTJg4tuYUplJHQHZWMj0x4U6IFYUNCLSJZUlpXy2LIGXn3/JN0XLwVdTtbEWjuom1TOghmTgy7lhinoRSRr1jZH6R8c5je7jgddSla4O7F4J6vm1VFSUljTHqRT0ItI1ixuqOb2WyazMSSXGfzgZA+J8/2snle4/XlQ0ItIFpkZa5uj7DrWxcFTPUGXc9Nircn+/L0FNi3xaAp6EcmqL9/VQKTEQnH1qZZ4gnnTJzGjujLoUm6Kgl5Esqp2UgX3L5jOr3a0c2loOOhybljfpSHePdJZ8G0bUNCLSA6sbW4kcX6Afz7QEXQpN2z7R2fpuzRc0OPnRyjoRSTr/uD2euomVRR0+6YlniBSYnxmTm3Qpdw0Bb2IZF1ZaQlfWd7A6x+cJnG+P+hybkisNcHyW6cyqSISdCk3TUEvIjmxdkWUwWHn1zvagy7lup29MMD7x7sL8mpSY1HQi0hOzL9lMksba9hYgBcPf/NQAncU9CIi1/K15igHTvWwp7076FKuS6w1weTKCEsaqoMuJSsU9CKSM/9m6SwqIiVs3Fo4Z8q6Oy2tCe6ZW0ukNBwRGY53ISJ5aUplGQ/dOYO/39lO36WhoMvJyIedvbR3XQzF+PkRCnoRyam1Kxo51zfIa/tOBV1KRmKtybH/qwt82oN0CnoRyal75tbSUDOhYMbUx+IJGmomMLt2YtClZE1GQW9mD5nZATOLm9nTYzz+YzPbmfo5aGZdqeX3pS3faWZ9ZvZ4tt+EiOSvkhLjqyuixOIJjnddDLqcTzU4NMxbhzq5d34dZoU7LfFo1wx6MysFngMeBhYB681sUfo67v4dd1/m7suAnwIvpZa/kbb8fqAXeC3L70FE8tzaFVHc4aXt+X1Qdnd7Nz19g6EZVjkikz36lUDc3Q+7+wDwIvDYp6y/HnhhjOVPAL91997rL1NEClnjtIncPWcaG7fl95j6WGsCM1g1t/iCvgFIb661pZZdwcxuA5qA18d4eB1j/wLAzJ40s61mtrWjo3AnQRKRq1u7opGPOnt598iZoEu5qlhrgjtnVTO1qjzoUrIq2wdj1wGb3P2ycVRmNhNYDGwZayN3f97dm929ub4+PEe6ReQTDy+ewaSKSN5efep8/yDbj54NXdsGMgv6dqAx7X40tWwsV9tr/xrwK3cPzxWDReS6TCyP8KUlM9m85wTn+weDLucK7xzuZHDYuTdE4+dHZBL07wHzzazJzMpJhvnLo1cyswXAVODtMZ7jan17ESkia5uj9A4MsXn3iaBLuUJLa4LKshJWzJ4adClZd82gd/dB4CmSbZf9wAZ332tmz5rZo2mrrgNe9FFHWsxsNslvBP+SraJFpDAtv3Uqc+qr2Lgt/8bUx+IJVjbVUhEpDbqUrMtoomV33wxsHrXsmVH3v3+VbT/kKgdvRaS4mBlrVzTyw1c/4EjiAk11VUGXBMCJ7ovET5/nD5sbr71yAdKZsSIyrr6yvIESg015tFcfa00A4ZmWeDQFvYiMq1umVPL536vnl9vaGRrOjzH1sXiCukkVLJgxOehSckJBLyLjbm1zIyfP9dHSGvx5M8PDzpvxBKvn1YZq2oN0CnoRGXcPLJxOzcSyvBhT/8HJHhLnB0I1W+VohX/VWxEpOBWRUh5f1sDf/u4jHvxRsAPyevqSY/rDNP/8aAp6EQnEN1c3cbZ3gEtDw0GXwvzpk5lRXRl0GTmjoBeRQDROm8hP1t0VdBlFQT16EZGQU9CLiIScgl5EJOQU9CIiIaegFxEJOQW9iEjIKehFREJOQS8iEnKWb1dkN7MO4KObeIo6IJGlcgqdPovL6fO4nD6PT4Ths7jN3cecsCfvgv5mmdlWd28Ouo58oM/icvo8LqfP4xNh/yzUuhERCTkFvYhIyIUx6J8PuoA8os/icvo8LqfP4xOh/ixC16MXEZHLhXGPXkRE0ijoRURCLjRBb2YPmdkBM4ub2dNB1xMkM2s0szfMbJ+Z7TWzPw26pqCZWamZ7TCzfwi6lqCZWY2ZbTKzD8xsv5l9NuiagmRm30n9P3nfzF4ws9BdaioUQW9mpcBzwMPAImC9mS0KtqpADQLfdfdFwN3AnxT55wHwp8D+oIvIEz8BXnX3BcBSivhzMbMG4D8Aze5+J1AKrAu2quwLRdADK4G4ux929wHgReCxgGsKjLufcPftqds9JP8jNwRbVXDMLAo8AvxN0LUEzcyqgc8B/wvA3QfcvSvYqgIXASaYWQSYCBwPuJ6sC0vQNwDH0u63UcTBls7MZgN3Ae8EW0mg/hL4z0DwV6EOXhPQAfyfVCvrb8ysKuiiguLu7cB/B44CJ4Bud38t2KqyLyxBL2Mws0nAL4H/6O7ngq4nCGb2JeC0u28LupY8EQGWA3/t7ncBF4CiPaZlZlNJfvtvAmYBVWb29WCryr6wBH070Jh2P5paVrTMrIxkyP+du78UdD0BWgU8amYfkmzp3W9mfxtsSYFqA9rcfeQb3iaSwV+svgAccfcOd78EvATcE3BNWReWoH8PmG9mTWZWTvJgyssB1xQYMzOSPdj97v6joOsJkrv/mbtH3X02yX8Xr7t76PbYMuXuJ4FjZnZ7atEDwL4ASwraUeBuM5uY+n
/zACE8OB0JuoBscPdBM3sK2ELyqPn/dve9AZcVpFXAvwP2mNnO1LI/d/fNAdYk+ePbwN+ldooOA38UcD2Bcfd3zGwTsJ3kaLUdhHA6BE2BICIScmFp3YiIyFUo6EVEQk5BLyIScgp6EZGQU9CLiIScgl5EJOQU9CIiIff/AWm+N7rRaXZ6AAAAAElFTkSuQmCC\n"
-          },
-          "metadata": {
-            "needs_background": "light"
-          }
-        }
+     "data": {
+      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD7CAYAAABkO19ZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAfGklEQVR4nO3da3BU95nn8e8jtS4gQAJJGFDLRlzGgM3FoCGOwcnYjis2ztpOYjJQm93KVCqeF+PsbDapLc/UrCvl2rzI1m4ymawzVa7Z3VRlZu0F4mScMTGuGXtm1LZjmzsGDGrABombWiAhEJKQ9OyLbtmNEKaBbp3u079PlYru0+d0P93AT6ef8z//Y+6OiIiEV0nQBYiISG4p6EVEQk5BLyIScgp6EZGQU9CLiIScgl5EJOQyCnoze8jMDphZ3MyeHuPxW83sDTPbYWa7zWxNavmDZrbNzPak/rw/229AREQ+nV1rHL2ZlQIHgQeBNuA9YL2770tb53lgh7v/tZktAja7+2wzuws45e7HzexOYIu7N+TqzYiIyJUiGayzEoi7+2EAM3sReAzYl7aOA1NSt6uB4wDuviNtnb3ABDOrcPf+q71YXV2dz549O+M3ICIisG3btoS714/1WCZB3wAcS7vfBnxm1DrfB14zs28DVcAXxnierwLbxwp5M3sSeBLg1ltvZevWrRmUJSIiI8zso6s9lq2DseuBn7t7FFgD/MLMPn5uM7sD+CHwx2Nt7O7Pu3uzuzfX14/5C0lERG5QJkHfDjSm3Y+mlqX7JrABwN3fBiqBOgAziwK/Av69ux+62YJFROT6ZBL07wHzzazJzMqBdcDLo9Y5CjwAYGYLSQZ9h5nVAK8AT7v7m9krW0REMnXNoHf3QeApYAuwH9jg7nvN7FkzezS12neBb5nZLuAF4BueHM7zFDAPeMbMdqZ+pufknYiIyJiuObxyvDU3N7sOxoqIXB8z2+buzWM9pjNjRURCTkEvIhJymYyjlwy4OyfP9bHrWBcHTp5naHg46JJklPoplSyL1nD7jMmUR7SPI8VDQX+Dui9eYk9bN7vauth5rItdx7o43fPJuWBmARYnV0g/FFUeKeGOWVNYGq1hWWMNSxtrmF07EdNfmoSUgj4D/YND7D/Rw65UoO9s6+Jwx4WPH59TV8WqeXUsjVaztLGGhTOnUFlWGmDFMpq703b2IrvaulJ/j938v/eO8fO3PgSgekIZS6LVyeCP1rCksZrpkyuDLVokSzTqZpThYedw4jw7j3WzOxUK+06c49JQ8nOqn1yRCoNkqC9pqKF6Yllg9cqNGxwapvX0+WTwtyXD/8CpHoaGk3/XDTUTWNpYnQz+aA2Lo9VMqtC+keSnTxt1U/RBf7K7L9l6SYX6nrZuevoHAagqL2VJNPnVflljMthnTKnUV/wQuzgwxN7j3al/E93sOtbF0TO9QLIdN3/6JJZ+/G8i2e8vK1W/X4KnoE8515fsq4/01He1dXHqXLKvXlZqLJgx5eM9uGWNNcypn0RpiUK92J25MJDW8kn+AjhzYQCAipF+fyr4l0ZruE39fglAUQZ9/+AQH5zouexg6aFRffWlaS0Y9dUlUyP9/p3HulLtvW72tHdz8dIQkOz3L22sYVm0+uNvhPWTKwKuWsKuKIL+Qv8gW/aeTB0s7Wb/8XMMDCWHONZNSvbVR9ov6qtLto3u9+881s2Bk+dItfsv6/f/27tvU69fsu7Tgj40/9oGBof5Txt2UVVeyuJoNX+0ejbLUntTM6vVV5fcipSWsHDmFBbOnMK6lbcC0DswyN7j55I7H6lfAJv3nOTY2V7+6+OLA65Yiklo9ugBDnWcZ3Ztlfrqkree+r/beftQJ+/8+QNEdBBXsqho5rqZq4OnkuceWTyTzgsDvHvkTNClSBEJVdCL5Ls/uH06E8pKeWXPiaBLkSKioBcZRxPKS3lg4XReff8kg0OaD0nGh4JeZJypfSPjTUEvMs7UvpHxpqAXGWdq38h4U9CLBEDtGxlPCnqRAKh9I+NJQS8SgAnlpdy/cDpb9qp9I7mnoBcJyJcWzyRxXu0byT0FvUhA1L6R8aKgFwmI2jcyXhT0IgFS+0bGg4JeJEBq38h4UNCLBEjtGxkPCnqRgD2i9o3kmIJeJGD3qX0jOaagFwlYevtmaDi/rvgm4ZBR0JvZQ2Z2wMziZvb0GI/famZvmNkOM9ttZmtSy2tTy8+b2f/MdvEiYTHSvnnnSGfQpUgIXTPozawUeA54GFgErDezRaNW+wtgg7vfBawDfpZa3gf8F+B7WatYJIQ+bt/sVvtGsi+TPfqVQNzdD7v7APAi8NiodRyYkrpdDRwHcPcL7h4jGfgichVq30guZRL0DcCxtPttqWXpvg983czagM3At7NSnUgRUftGciVbB2PXAz939yiwBviFmWX83Gb2pJltNbOtHR0dWSpJpLCofSO5kkkYtwONafejqWXpvglsAHD3t4FKoC7TItz9eXdvdvfm+vr6TDcTCRW1byRXMgn694D5ZtZkZuUkD7a+PGqdo8ADAGa2kGTQa9dc5DqpfSO5cM2gd/dB4ClgC7Cf5OiavWb2rJk9mlrtu8C3zGwX8ALwDXd3ADP7EPgR8A0zaxtjxI6IpKh9I7kQyWQld99M8iBr+rJn0m7vA1ZdZdvZN1GfSFFJb988+9idlJZY0CVJCOjMWJE8o/aNZJuCXiTPjLRvNmvuG8kSBb1Inhlp37z6vkbfSHYo6EXykNo3kk0KepE8pPaNZJOCXiQPTSgv5f4Fat9IdijoRfLUI0vUvpHsUNCL5Cm1byRbFPQieUrtG8kWBb1IHlP7RrJBQS+Sx9S+kWxQ0IvkMbVvJBsU9CJ5bo1OnpKbpKAXyXP3LahX+0ZuioJeJM9NLI+ofSM3RUEvUgDUvpGboaAXKQBq38jNUNCLFIBP2jen1L6R66agFykQyfZNP+8eORN0KVJgFPQiBeK+BfVUlpXwyp7jQZciBUZBL1IgJpZHeGDBLWrfyHVT0IsUELVv5EYo6EUKiNo3ciMU9CIFRO0buREKepECo/aNXC8FvUiBUftGrpeCXqTA6OQpuV4KepEC9MjiWWrfSMYU9CIFaKR9o7lvJBMKepECNNK++a2mLpYMKOhFCpTaN5KpjILezB4yswNmFjezp8d4/FYze8PMdpjZbjNbk/bYn6W2O2BmX8xm8SLFTO0bydQ1g97MSoHngIeBRcB6M1s0arW/ADa4+13AOuBnqW0Xpe7fATwE/Cz1fCJyk9S+kUxlske/Eoi7+2F3HwBeBB4btY4DU1K3q4GRAb6PAS+6e7+7HwHiqecTkSxQ+0YykUnQNwDH0u63pZal+z7wdTNrAzYD376ObTGzJ81sq5lt7ejoyLB0EVH7RjKRrYOx64Gfu3sUWAP8wswyfm53f97dm929ub6+PksliYSf2jfh8c8HTvOvB3Ozo5tJGLcDjWn3o6ll6b4JbABw97eBSqAuw21F5CZo7ptw+Mt/bOUn/9Sak+
[… base64-encoded PNG data for the best_q_values plot omitted …]\n",
+      "text/plain": [
+       "<Figure size 432x288 with 1 Axes>"
       ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
     }
-  ]
-}
\ No newline at end of file
+   ],
+   "source": [
+    "%%time\n",
+    "\n",
+    "batch_size = 32               # size of batch the inner NN is trained with\n",
+    "learning_rate = 1e-1          # fix learning rate\n",
+    "ds = \"MNIST\"                  # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)\n",
+    "toy_size = 1               # total propeortion of training and test set we use\n",
+    "max_epochs = 100              # max number of epochs that is run if early stopping is not hit\n",
+    "early_stop_num = 10           # max number of worse validation scores before early stopping is triggered\n",
+    "early_stop_flag = True        # implement early stopping or not\n",
+    "average_validation = [15,25]  # if not implementing early stopping, what epochs are we averaging over\n",
+    "num_policies = 5              # fix number of policies\n",
+    "num_sub_policies = 5          # fix number of sub-policies in a policy\n",
+    "iterations = 100              # total iterations, should be more than the number of policies\n",
+    "IsLeNet = \"SimpleNet\"         # using LeNet or EasyNet or SimpleNet\n",
+    "\n",
+    "# generate random policies at start\n",
+    "policies = generate_policies(num_policies, num_sub_policies)\n",
+    "\n",
+    "q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet)\n",
+    "\n",
+    "plt.plot(best_q_values)\n",
+    "\n",
+    "best_q_values = np.array(best_q_values)\n",
+    "save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)\n",
+    "#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "collapsed_sections": [],
+   "name": "UCB1.ipynb",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}