diff --git a/MetaAugment/UCB1_JC.ipynb b/MetaAugment/UCB1_JC.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..a6795294a3e7f9ee0b333f01d1e22d2486baa326
--- /dev/null
+++ b/MetaAugment/UCB1_JC.ipynb
@@ -0,0 +1,341 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "UCB1.ipynb",
+      "provenance": [],
+      "collapsed_sections": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "source": [
+        "import numpy as np\n",
+        "import torch\n",
+        "torch.manual_seed(0)\n",
+        "import torch.nn as nn\n",
+        "import torch.nn.functional as F\n",
+        "import torch.optim as optim\n",
+        "import torch.utils.data as data_utils\n",
+        "import torchvision\n",
+        "import torchvision.datasets as datasets"
+      ],
+      "metadata": {
+        "id": "U_ZJ2LqDiu_v"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "class LeNet(nn.Module):\n",
+        "    def __init__(self):\n",
+        "        super().__init__()\n",
+        "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
+        "        self.relu1 = nn.ReLU()\n",
+        "        self.pool1 = nn.MaxPool2d(2)\n",
+        "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
+        "        self.relu2 = nn.ReLU()\n",
+        "        self.pool2 = nn.MaxPool2d(2)\n",
+        "        self.fc1 = nn.Linear(256, 120)\n",
+        "        self.relu3 = nn.ReLU()\n",
+        "        self.fc2 = nn.Linear(120, 84)\n",
+        "        self.relu4 = nn.ReLU()\n",
+        "        self.fc3 = nn.Linear(84, 10)\n",
+        "        self.relu5 = nn.ReLU()\n",
+        "\n",
+        "    def forward(self, x):\n",
+        "        y = self.conv1(x)\n",
+        "        y = self.relu1(y)\n",
+        "        y = self.pool1(y)\n",
+        "        y = self.conv2(y)\n",
+        "        y = self.relu2(y)\n",
+        "        y = self.pool2(y)\n",
+        "        y = y.view(y.shape[0], -1)\n",
+        "        y = self.fc1(y)\n",
+        "        y = self.relu3(y)\n",
+        "        y = self.fc2(y)\n",
+        "        y = self.relu4(y)\n",
+        "        y = self.fc3(y)\n",
+        "        y = self.relu5(y)\n",
+        "        return y"
+      ],
+      "metadata": {
+        "id": "4ksS_duLFADW"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"Make toy dataset\"\"\"\n",
+        "\n",
+        "def create_toy(train_dataset, test_dataset, batch_size, n_samples):\n",
+        "    # shuffle and take first n_samples %age of training dataset\n",
+        "    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist())\n",
+        "    indices_train = torch.arange(int(n_samples*len(train_dataset)))\n",
+        "    reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)\n",
+        "\n",
+        "    # shuffle and take first n_samples %age of test dataset\n",
+        "    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, torch.randperm(len(test_dataset)).tolist())\n",
+        "    indices_test = torch.arange(int(n_samples*len(test_dataset)))\n",
+        "    reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)\n",
+        "\n",
+        "    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)\n",
+        "    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)\n",
+        "\n",
+        "    return train_loader, test_loader"
+      ],
+      "metadata": {
+        "id": "xujQtvVWBgMH"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"Randomly generate 10 policies\"\"\"\n",
+        "\"\"\"Each policy has 5 sub-policies\"\"\"\n",
+        "\"\"\"For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes\"\"\"\n",
+        "\n",
+        "def generate_policies(num_policies, num_sub_policies):\n",
+        "    \n",
+        "    policies = np.zeros([num_policies,num_sub_policies,6])\n",
+        "\n",
+        "    # Policies array will be 10x5x6\n",
+        "    for policy in range(num_policies):\n",
+        "        for sub_policy in range(num_sub_policies):\n",
+        "            # pick two sub_policy transformations (0=rotate, 1=shear, 2=scale)\n",
+        "            policies[policy, sub_policy, 0] = np.random.randint(0,3)\n",
+        "            policies[policy, sub_policy, 1] = np.random.randint(0,3)\n",
+        "            while policies[policy, sub_policy, 0] == policies[policy, sub_policy, 1]:\n",
+        "                policies[policy, sub_policy, 1] = np.random.randint(0,3)\n",
+        "\n",
+        "            # pick probabilities\n",
+        "            policies[policy, sub_policy, 2] = np.random.randint(0,11) / 10\n",
+        "            policies[policy, sub_policy, 3] = np.random.randint(0,11) / 10\n",
+        "\n",
+        "            # pick magnitudes\n",
+        "            for transformation in range(2):\n",
+        "                if policies[policy, sub_policy, transformation] <= 1:\n",
+        "                    policies[policy, sub_policy, transformation + 4] = np.random.randint(-4,5)*5\n",
+        "                elif policies[policy, sub_policy, transformation] == 2:\n",
+        "                    policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10\n",
+        "\n",
+        "    return policies"
+      ],
+      "metadata": {
+        "id": "Iql-c88jGGWy"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"Pick policy and sub-policy\"\"\"\n",
+        "\"\"\"Each row of data should have a different sub-policy but for now, this will do\"\"\"\n",
+        "\n",
+        "def sample_sub_policy(policies, policy, num_sub_policies):\n",
+        "    sub_policy = np.random.randint(0,num_sub_policies)\n",
+        "\n",
+        "    degrees = 0\n",
+        "    shear = 0\n",
+        "    scale = 1\n",
+        "\n",
+        "    if policies[policy, sub_policy][0] == 0:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
+        "            degrees = policies[policy, sub_policy][4]\n",
+        "    elif policies[policy, sub_policy][1] == 0:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
+        "            degrees = policies[policy, sub_policy][5]\n",
+        "\n",
+        "    if policies[policy, sub_policy][0] == 1:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
+        "            shear = policies[policy, sub_policy][4]\n",
+        "    elif policies[policy, sub_policy][1] == 1:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
+        "            shear = policies[policy, sub_policy][5]\n",
+        "\n",
+        "    if policies[policy, sub_policy][0] == 2:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
+        "            scale = policies[policy, sub_policy][4]\n",
+        "    elif policies[policy, sub_policy][1] == 2:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
+        "            scale = policies[policy, sub_policy][5]\n",
+        "\n",
+        "    return degrees, shear, scale"
+      ],
+      "metadata": {
+        "id": "QE2VWI8o731X"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "execution_count": null,
+      "metadata": {
+        "id": "vu_4I4qkbx73"
+      },
+      "outputs": [],
+      "source": [
+        "\"\"\"Sample policy, open and apply above transformations\"\"\"\n",
+        "def run_UCB1(q_values, cnts, total_count, q_plus_cnt, policies, num_policies, num_sub_policies, initial_iteration, batch_size, toy_size, iterations):\n",
+        "\n",
+        "    #Pull each bandit arm just once\n",
+        "    if initial_iteration:\n",
+        "        iterations = num_policies\n",
+        "\n",
+        "    for policy in range(iterations):\n",
+        "        # sample policy and get transformations\n",
+        "        if not initial_iteration:\n",
+        "            this_policy = np.argmax(q_plus_cnt)\n",
+        "        else:\n",
+        "            this_policy = policy\n",
+        "\n",
+        "        degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)\n",
+        "\n",
+        "        # create transformations\n",
+        "        transform = torchvision.transforms.Compose(\n",
+        "            [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),\n",
+        "            torchvision.transforms.ToTensor()])\n",
+        "\n",
+        "        # open data and apply these transformations\n",
+        "        train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
+        "        test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
+        "\n",
+        "\n",
+        "        \"\"\"Make toy dataset\"\"\"\n",
+        "        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n",
+        "\n",
+        "\n",
+        "        \"\"\" Run model\"\"\"\n",
+        "        model = LeNet()\n",
+        "        sgd = optim.SGD(model.parameters(), lr=1e-1)\n",
+        "        cost = nn.CrossEntropyLoss()\n",
+        "\n",
+        "        best_acc = 0\n",
+        "        early_stop_cnt = 0\n",
+        "\n",
+        "        # choose how many past best validation accuracy we go\n",
+        "        early_stop_num = 10\n",
+        "\n",
+        "        # choose max number of epochs\n",
+        "        epoch = 100\n",
+        "\n",
+        "        for _epoch in range(epoch):\n",
+        "            model.train()\n",
+        "            for idx, (train_x, train_label) in enumerate(train_loader):\n",
+        "                label_np = np.zeros((train_label.shape[0], 10))\n",
+        "                sgd.zero_grad()\n",
+        "                predict_y = model(train_x.float())\n",
+        "                loss = cost(predict_y, train_label.long())\n",
+        "                loss.backward()\n",
+        "                sgd.step()\n",
+        "\n",
+        "            correct = 0\n",
+        "            _sum = 0\n",
+        "            model.eval()\n",
+        "            for idx, (test_x, test_label) in enumerate(test_loader):\n",
+        "                predict_y = model(test_x.float()).detach()\n",
+        "                predict_ys = np.argmax(predict_y, axis=-1)\n",
+        "                label_np = test_label.numpy()\n",
+        "                _ = predict_ys == test_label\n",
+        "                correct += np.sum(_.numpy(), axis=-1)\n",
+        "                _sum += _.shape[0]\n",
+        "            \n",
+        "            acc = correct / _sum\n",
+        "            if acc > best_acc :\n",
+        "                best_acc = acc\n",
+        "                early_stop_cnt = 0\n",
+        "            else:\n",
+        "                early_stop_cnt += 1\n",
+        "\n",
+        "            # exit if validation gets worse for \n",
+        "            if early_stop_cnt >= early_stop_num:\n",
+        "                break\n",
+        "\n",
+        "        # update q_values\n",
+        "        if initial_iteration:\n",
+        "            q_values[this_policy] += best_acc\n",
+        "        else:\n",
+        "            q_values[this_policy] = (q_values[this_policy]*cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)\n",
+        "\n",
+        "        # update counts\n",
+        "        cnts[this_policy] += 1\n",
+        "        total_count += 1\n",
+        "\n",
+        "        # update q_plus_cnt values\n",
+        "        if not initial_iteration:\n",
+        "            for i in range(num_policies):\n",
+        "                q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])\n",
+        "\n",
+        "        #print(q_values)\n",
+        "\n",
+        "    if initial_iteration:\n",
+        "        for i in range(num_policies):\n",
+        "            q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])\n",
+        "\n",
+        "    return q_values, cnts, total_count, q_plus_cnt"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "%%time\n",
+        "\n",
+        "batch_size = 32\n",
+        "toy_size = 0.02\n",
+        "total_iterations = 50\n",
+        "\n",
+        "num_policies = 10\n",
+        "num_sub_policies = 5\n",
+        "policies = generate_policies(num_policies, num_sub_policies)\n",
+        "\n",
+        "#Initialize vector weights, counts and regret\n",
+        "q_values = [0]*num_policies\n",
+        "cnts = [0]*num_policies\n",
+        "q_plus_cnt = [0]*num_policies\n",
+        "total_count = 0\n",
+        "\n",
+        "q_values, cnts, total_count, q_plus_cnt = run_UCB1(q_values, cnts, total_count, q_plus_cnt, policies, num_policies, num_sub_policies, True, batch_size, toy_size, 0)\n",
+        "print(q_values)\n",
+        "q_values, cnts, total_count, q_plus_cnt = run_UCB1(q_values, cnts, total_count, q_plus_cnt, policies, num_policies, num_sub_policies, False, batch_size, toy_size , total_iterations)\n",
+        "print(q_values)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "doHUtJ_tEiA6",
+        "outputId": "0f25a17b-aab2-4d59-ecea-2e36c7bd5592"
+      },
+      "execution_count": null,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "[0.81, 0.94, 0.835, 0.94, 0.775, 0.78, 0.96, 0.935, 0.97, 0.76]\n",
+            "[0.722, 0.8578571428571429, 0.7966666666666665, 0.8950000000000001, 0.7766666666666667, 0.8558333333333333, 0.8383333333333334, 0.688, 0.8041666666666666, 0.8766666666666668]\n",
+            "CPU times: user 14min 46s, sys: 10.9 s, total: 14min 57s\n",
+            "Wall time: 14min 58s\n"
+          ]
+        }
+      ]
+    }
+  ]
+}
\ No newline at end of file