diff --git a/MetaAugment/UCB1_JC.ipynb b/MetaAugment/UCB1_JC.ipynb index 5710a128c11db4392adb0d1372caa890c01e344e..2ebf29b8cbb7ed6c7cf1bbaffd9b55881910bfb4 100644 --- a/MetaAugment/UCB1_JC.ipynb +++ b/MetaAugment/UCB1_JC.ipynb @@ -1,485 +1,495 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "UCB1.ipynb", - "provenance": [], - "collapsed_sections": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU" + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "U_ZJ2LqDiu_v" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "torch.manual_seed(0)\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torch.optim as optim\n", + "import torch.utils.data as data_utils\n", + "import torchvision\n", + "import torchvision.datasets as datasets\n", + "\n", + "from matplotlib import pyplot as plt\n", + "from numpy import save, load\n", + "from tqdm import trange" + ] }, - "cells": [ - { - "cell_type": "code", - "source": [ - "import numpy as np\n", - "import torch\n", - "torch.manual_seed(0)\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torch.optim as optim\n", - "import torch.utils.data as data_utils\n", - "import torchvision\n", - "import torchvision.datasets as datasets\n", - "\n", - "from matplotlib import pyplot as plt\n", - "from numpy import save, load\n", - "from tqdm import trange" - ], - "metadata": { - "id": "U_ZJ2LqDiu_v" - }, - "execution_count": 1, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n", - "class LeNet(nn.Module):\n", - " def __init__(self, img_height, img_width, num_labels, img_channels):\n", - " super().__init__()\n", - " self.conv1 = nn.Conv2d(img_channels, 6, 5)\n", - " self.relu1 = nn.ReLU()\n", - " self.pool1 = 
nn.MaxPool2d(2)\n", - " self.conv2 = nn.Conv2d(6, 16, 5)\n", - " self.relu2 = nn.ReLU()\n", - " self.pool2 = nn.MaxPool2d(2)\n", - " self.fc1 = nn.Linear(int((((img_height-4)/2-4)/2)*(((img_width-4)/2-4)/2)*16), 120)\n", - " self.relu3 = nn.ReLU()\n", - " self.fc2 = nn.Linear(120, 84)\n", - " self.relu4 = nn.ReLU()\n", - " self.fc3 = nn.Linear(84, num_labels)\n", - " self.relu5 = nn.ReLU()\n", - "\n", - " def forward(self, x):\n", - " y = self.conv1(x)\n", - " y = self.relu1(y)\n", - " y = self.pool1(y)\n", - " y = self.conv2(y)\n", - " y = self.relu2(y)\n", - " y = self.pool2(y)\n", - " y = y.view(y.shape[0], -1)\n", - " y = self.fc1(y)\n", - " y = self.relu3(y)\n", - " y = self.fc2(y)\n", - " y = self.relu4(y)\n", - " y = self.fc3(y)\n", - " y = self.relu5(y)\n", - " return y" - ], - "metadata": { - "id": "4ksS_duLFADW" - }, - "execution_count": 2, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n", - "class EasyNet(nn.Module):\n", - " def __init__(self, img_height, img_width, num_labels, img_channels):\n", - " super().__init__()\n", - " self.fc1 = nn.Linear(img_height*img_width*img_channels, 2048)\n", - " self.relu1 = nn.ReLU()\n", - " self.fc2 = nn.Linear(2048, num_labels)\n", - " self.relu2 = nn.ReLU()\n", - "\n", - " def forward(self, x):\n", - " y = x.view(x.shape[0], -1)\n", - " y = self.fc1(y)\n", - " y = self.relu1(y)\n", - " y = self.fc2(y)\n", - " y = self.relu2(y)\n", - " return y" - ], - "metadata": { - "id": "LckxnUXGfxjW" - }, - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n", - "class SimpleNet(nn.Module):\n", - " def __init__(self, img_height, img_width, num_labels, img_channels):\n", - " super().__init__()\n", - " self.fc1 = nn.Linear(img_height*img_width*img_channels, num_labels)\n", - " self.relu1 = nn.ReLU()\n", - "\n", - " def forward(self, x):\n", - " y = 
x.view(x.shape[0], -1)\n", - " y = self.fc1(y)\n", - " y = self.relu1(y)\n", - " return y" - ], - "metadata": { - "id": "enaD2xbw5hew" - }, - "execution_count": 4, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "\"\"\"Make toy dataset\"\"\"\n", - "\n", - "def create_toy(train_dataset, test_dataset, batch_size, n_samples):\n", - " \n", - " # shuffle and take first n_samples %age of training dataset\n", - " shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset))\n", - " shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)\n", - " indices_train = torch.arange(int(n_samples*len(train_dataset)))\n", - " reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)\n", - "\n", - " # shuffle and take first n_samples %age of test dataset\n", - " shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset))\n", - " shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)\n", - " indices_test = torch.arange(int(n_samples*len(test_dataset)))\n", - " reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)\n", - "\n", - " # push into DataLoader\n", - " train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)\n", - " test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)\n", - "\n", - " return train_loader, test_loader" - ], - "metadata": { - "id": "xujQtvVWBgMH" - }, - "execution_count": 5, - "outputs": [] + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "4ksS_duLFADW" + }, + "outputs": [], + "source": [ + "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n", + "class LeNet(nn.Module):\n", + " def __init__(self, img_height, img_width, num_labels, img_channels):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv2d(img_channels, 6, 5)\n", + " self.relu1 = nn.ReLU()\n", + " self.pool1 = nn.MaxPool2d(2)\n", 
+ " self.conv2 = nn.Conv2d(6, 16, 5)\n", + " self.relu2 = nn.ReLU()\n", + " self.pool2 = nn.MaxPool2d(2)\n", + " self.fc1 = nn.Linear(int((((img_height-4)/2-4)/2)*(((img_width-4)/2-4)/2)*16), 120)\n", + " self.relu3 = nn.ReLU()\n", + " self.fc2 = nn.Linear(120, 84)\n", + " self.relu4 = nn.ReLU()\n", + " self.fc3 = nn.Linear(84, num_labels)\n", + " self.relu5 = nn.ReLU()\n", + "\n", + " def forward(self, x):\n", + " y = self.conv1(x)\n", + " y = self.relu1(y)\n", + " y = self.pool1(y)\n", + " y = self.conv2(y)\n", + " y = self.relu2(y)\n", + " y = self.pool2(y)\n", + " y = y.view(y.shape[0], -1)\n", + " y = self.fc1(y)\n", + " y = self.relu3(y)\n", + " y = self.fc2(y)\n", + " y = self.relu4(y)\n", + " y = self.fc3(y)\n", + " y = self.relu5(y)\n", + " return y" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "LckxnUXGfxjW" + }, + "outputs": [], + "source": [ + "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n", + "class EasyNet(nn.Module):\n", + " def __init__(self, img_height, img_width, num_labels, img_channels):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(img_height*img_width*img_channels, 2048)\n", + " self.relu1 = nn.ReLU()\n", + " self.fc2 = nn.Linear(2048, num_labels)\n", + " self.relu2 = nn.ReLU()\n", + "\n", + " def forward(self, x):\n", + " y = x.view(x.shape[0], -1)\n", + " y = self.fc1(y)\n", + " y = self.relu1(y)\n", + " y = self.fc2(y)\n", + " y = self.relu2(y)\n", + " return y" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "enaD2xbw5hew" + }, + "outputs": [], + "source": [ + "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n", + "class SimpleNet(nn.Module):\n", + " def __init__(self, img_height, img_width, num_labels, img_channels):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(img_height*img_width*img_channels, num_labels)\n", + " self.relu1 = nn.ReLU()\n", + "\n", + " def forward(self, x):\n", + " y = x.view(x.shape[0], 
-1)\n", + " y = self.fc1(y)\n", + " y = self.relu1(y)\n", + " return y" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "xujQtvVWBgMH" + }, + "outputs": [], + "source": [ + "\"\"\"Make toy dataset\"\"\"\n", + "\n", + "def create_toy(train_dataset, test_dataset, batch_size, n_samples):\n", + " \n", + " # shuffle and take first n_samples %age of training dataset\n", + " shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset))\n", + " shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)\n", + " indices_train = torch.arange(int(n_samples*len(train_dataset)))\n", + " reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)\n", + "\n", + " # shuffle and take first n_samples %age of test dataset\n", + " shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset))\n", + " shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)\n", + " indices_test = torch.arange(int(n_samples*len(test_dataset)))\n", + " reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)\n", + "\n", + " # push into DataLoader\n", + " train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)\n", + " test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)\n", + "\n", + " return train_loader, test_loader" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "Iql-c88jGGWy" + }, + "outputs": [], + "source": [ + "\"\"\"Randomly generate 10 policies\"\"\"\n", + "\"\"\"Each policy has 5 sub-policies\"\"\"\n", + "\"\"\"For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes\"\"\"\n", + "\n", + "def generate_policies(num_policies, num_sub_policies):\n", + " \n", + " policies = np.zeros([num_policies,num_sub_policies,6])\n", + "\n", + " # Policies array will be 10x5x6\n", + " for policy in 
range(num_policies):\n", + " for sub_policy in range(num_sub_policies):\n", + " # pick two sub_policy transformations (0=rotate, 1=shear, 2=scale)\n", + " policies[policy, sub_policy, 0] = np.random.randint(0,3)\n", + " policies[policy, sub_policy, 1] = np.random.randint(0,3)\n", + " while policies[policy, sub_policy, 0] == policies[policy, sub_policy, 1]:\n", + " policies[policy, sub_policy, 1] = np.random.randint(0,3)\n", + "\n", + " # pick probabilities\n", + " policies[policy, sub_policy, 2] = np.random.randint(0,11) / 10\n", + " policies[policy, sub_policy, 3] = np.random.randint(0,11) / 10\n", + "\n", + " # pick magnitudes\n", + " for transformation in range(2):\n", + " if policies[policy, sub_policy, transformation] <= 1:\n", + " policies[policy, sub_policy, transformation + 4] = np.random.randint(-4,5)*5\n", + " elif policies[policy, sub_policy, transformation] == 2:\n", + " policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10\n", + "\n", + " return policies" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "QE2VWI8o731X" + }, + "outputs": [], + "source": [ + "\"\"\"Pick policy and sub-policy\"\"\"\n", + "\"\"\"Each row of data should have a different sub-policy but for now, this will do\"\"\"\n", + "\n", + "def sample_sub_policy(policies, policy, num_sub_policies):\n", + " sub_policy = np.random.randint(0,num_sub_policies)\n", + "\n", + " degrees = 0\n", + " shear = 0\n", + " scale = 1\n", + "\n", + " # check for rotations\n", + " if policies[policy, sub_policy][0] == 0:\n", + " if np.random.uniform() < policies[policy, sub_policy][2]:\n", + " degrees = policies[policy, sub_policy][4]\n", + " elif policies[policy, sub_policy][1] == 0:\n", + " if np.random.uniform() < policies[policy, sub_policy][3]:\n", + " degrees = policies[policy, sub_policy][5]\n", + "\n", + " # check for shears\n", + " if policies[policy, sub_policy][0] == 1:\n", + " if np.random.uniform() < policies[policy, 
sub_policy][2]:\n", + " shear = policies[policy, sub_policy][4]\n", + " elif policies[policy, sub_policy][1] == 1:\n", + " if np.random.uniform() < policies[policy, sub_policy][3]:\n", + " shear = policies[policy, sub_policy][5]\n", + "\n", + " # check for scales\n", + " if policies[policy, sub_policy][0] == 2:\n", + " if np.random.uniform() < policies[policy, sub_policy][2]:\n", + " scale = policies[policy, sub_policy][4]\n", + " elif policies[policy, sub_policy][1] == 2:\n", + " if np.random.uniform() < policies[policy, sub_policy][3]:\n", + " scale = policies[policy, sub_policy][5]\n", + "\n", + " return degrees, shear, scale" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "vu_4I4qkbx73" + }, + "outputs": [], + "source": [ + "\"\"\"Sample policy, open and apply above transformations\"\"\"\n", + "def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet):\n", + "\n", + " # get number of policies and sub-policies\n", + " num_policies = len(policies)\n", + " num_sub_policies = len(policies[0])\n", + "\n", + " #Initialize vector weights, counts and regret\n", + " q_values = [0]*num_policies\n", + " cnts = [0]*num_policies\n", + " q_plus_cnt = [0]*num_policies\n", + " total_count = 0\n", + "\n", + " best_q_values = []\n", + "\n", + " for policy in trange(iterations):\n", + "\n", + " # get the action to try (either initially in order or using best q_plus_cnt value)\n", + " if policy >= num_policies:\n", + " this_policy = np.argmax(q_plus_cnt)\n", + " else:\n", + " this_policy = policy\n", + "\n", + " # get info of transformation for this sub-policy\n", + " degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)\n", + "\n", + " # create transformations using above info\n", + " transform = torchvision.transforms.Compose(\n", + " [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), 
scale=(scale,scale)),\n", + " torchvision.transforms.ToTensor()])\n", + "\n", + " # open data and apply these transformations\n", + " if ds == \"MNIST\":\n", + " train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n", + " test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n", + " elif ds == \"KMNIST\":\n", + " train_dataset = datasets.KMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n", + " test_dataset = datasets.KMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n", + " elif ds == \"FashionMNIST\":\n", + " train_dataset = datasets.FashionMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n", + " test_dataset = datasets.FashionMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n", + " elif ds == \"CIFAR10\":\n", + " train_dataset = datasets.CIFAR10(root='./MetaAugment/train', train=True, download=True, transform=transform)\n", + " test_dataset = datasets.CIFAR10(root='./MetaAugment/test', train=False, download=True, transform=transform)\n", + " elif ds == \"CIFAR100\":\n", + " train_dataset = datasets.CIFAR100(root='./MetaAugment/train', train=True, download=True, transform=transform)\n", + " test_dataset = datasets.CIFAR100(root='./MetaAugment/test', train=False, download=True, transform=transform)\n", + "\n", + " # check sizes of images\n", + " img_height = len(train_dataset[0][0][0])\n", + " img_width = len(train_dataset[0][0][0][0])\n", + " img_channels = len(train_dataset[0][0])\n", + "\n", + " # check output labels\n", + " if ds == \"CIFAR10\" or ds == \"CIFAR100\":\n", + " num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)\n", + " else:\n", + " num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()\n", + "\n", + " # create toy dataset from above uploaded data\n", + " 
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n", + "\n", + " # create model\n", + " device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + " if IsLeNet == \"LeNet\":\n", + " model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n", + " elif IsLeNet == \"EasyNet\":\n", + " model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n", + " else:\n", + " model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n", + " sgd = optim.SGD(model.parameters(), lr=1e-1)\n", + " cost = nn.CrossEntropyLoss()\n", + "\n", + " # set variables for best validation accuracy and early stop count\n", + " best_acc = 0\n", + " early_stop_cnt = 0\n", + " total_val = 0\n", + "\n", + " # train model and check validation accuracy each epoch\n", + " for _epoch in range(max_epochs):\n", + "\n", + " # train model\n", + " model.train()\n", + " for idx, (train_x, train_label) in enumerate(train_loader):\n", + " train_x, train_label = train_x.to(device), train_label.to(device) # new code\n", + " label_np = np.zeros((train_label.shape[0], num_labels))\n", + " sgd.zero_grad()\n", + " predict_y = model(train_x.float())\n", + " loss = cost(predict_y, train_label.long())\n", + " loss.backward()\n", + " sgd.step()\n", + "\n", + " # check validation accuracy on validation set\n", + " correct = 0\n", + " _sum = 0\n", + " model.eval()\n", + " for idx, (test_x, test_label) in enumerate(test_loader):\n", + " test_x, test_label = test_x.to(device), test_label.to(device) # new code\n", + " predict_y = model(test_x.float()).detach()\n", + " #predict_ys = np.argmax(predict_y, axis=-1)\n", + " predict_ys = torch.argmax(predict_y, axis=-1) # changed np to torch\n", + " #label_np = test_label.numpy()\n", + " _ = predict_ys == test_label\n", + " #correct += np.sum(_.numpy(), axis=-1)\n", + " correct += np.sum(_.cpu().numpy(), axis=-1) # added 
.cpu()\n", + " _sum += _.shape[0]\n", + " \n", + " acc = correct / _sum\n", + "\n", + " if average_validation[0] <= _epoch <= average_validation[1]:\n", + " total_val += acc\n", + "\n", + " # update best validation accuracy if it was higher, otherwise increase early stop count\n", + " if acc > best_acc :\n", + " best_acc = acc\n", + " early_stop_cnt = 0\n", + " else:\n", + " early_stop_cnt += 1\n", + "\n", + " # exit if validation gets worse over 10 runs and using early stopping\n", + " if early_stop_cnt >= early_stop_num and early_stop_flag:\n", + " break\n", + "\n", + " # exit if using fixed epoch length\n", + " if _epoch >= average_validation[1] and not early_stop_flag:\n", + " best_acc = total_val / (average_validation[1] - average_validation[0] + 1)\n", + " break\n", + "\n", + " # update q_values\n", + " if policy < num_policies:\n", + " q_values[this_policy] += best_acc\n", + " else:\n", + " q_values[this_policy] = (q_values[this_policy]*cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)\n", + "\n", + " best_q_value = max(q_values)\n", + " best_q_values.append(best_q_value)\n", + "\n", + " if (policy+1) % 10 == 0:\n", + " print(\"Iteration: {},\\tQ-Values: {}, Best Policy: {}\".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))\n", + "\n", + " # update counts\n", + " cnts[this_policy] += 1\n", + " total_count += 1\n", + "\n", + " # update q_plus_cnt values every turn after the initial sweep through\n", + " if policy >= num_policies - 1:\n", + " for i in range(num_policies):\n", + " q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])\n", + "\n", + " return q_values, best_q_values" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 342 }, + "id": "doHUtJ_tEiA6", + "outputId": "8195ba17-c95f-4b75-d8dc-19c5d76d5e43" + }, + "outputs": [ { - "cell_type": "code", - "source": [ - "\"\"\"Randomly generate 10 
policies\"\"\"\n", - "\"\"\"Each policy has 5 sub-policies\"\"\"\n", - "\"\"\"For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes\"\"\"\n", - "\n", - "def generate_policies(num_policies, num_sub_policies):\n", - " \n", - " policies = np.zeros([num_policies,num_sub_policies,6])\n", - "\n", - " # Policies array will be 10x5x6\n", - " for policy in range(num_policies):\n", - " for sub_policy in range(num_sub_policies):\n", - " # pick two sub_policy transformations (0=rotate, 1=shear, 2=scale)\n", - " policies[policy, sub_policy, 0] = np.random.randint(0,3)\n", - " policies[policy, sub_policy, 1] = np.random.randint(0,3)\n", - " while policies[policy, sub_policy, 0] == policies[policy, sub_policy, 1]:\n", - " policies[policy, sub_policy, 1] = np.random.randint(0,3)\n", - "\n", - " # pick probabilities\n", - " policies[policy, sub_policy, 2] = np.random.randint(0,11) / 10\n", - " policies[policy, sub_policy, 3] = np.random.randint(0,11) / 10\n", - "\n", - " # pick magnitudes\n", - " for transformation in range(2):\n", - " if policies[policy, sub_policy, transformation] <= 1:\n", - " policies[policy, sub_policy, transformation + 4] = np.random.randint(-4,5)*5\n", - " elif policies[policy, sub_policy, transformation] == 2:\n", - " policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10\n", - "\n", - " return policies" - ], - "metadata": { - "id": "Iql-c88jGGWy" - }, - "execution_count": 6, - "outputs": [] + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 10/10 [01:28<00:00, 8.84s/it]" + ] }, { - "cell_type": "code", - "source": [ - "\"\"\"Pick policy and sub-policy\"\"\"\n", - "\"\"\"Each row of data should have a different sub-policy but for now, this will do\"\"\"\n", - "\n", - "def sample_sub_policy(policies, policy, num_sub_policies):\n", - " sub_policy = np.random.randint(0,num_sub_policies)\n", - "\n", - " degrees = 0\n", - " shear = 0\n", - " scale = 1\n", - "\n", - " # check for 
rotations\n", - " if policies[policy, sub_policy][0] == 0:\n", - " if np.random.uniform() < policies[policy, sub_policy][2]:\n", - " degrees = policies[policy, sub_policy][4]\n", - " elif policies[policy, sub_policy][1] == 0:\n", - " if np.random.uniform() < policies[policy, sub_policy][3]:\n", - " degrees = policies[policy, sub_policy][5]\n", - "\n", - " # check for shears\n", - " if policies[policy, sub_policy][0] == 1:\n", - " if np.random.uniform() < policies[policy, sub_policy][2]:\n", - " shear = policies[policy, sub_policy][4]\n", - " elif policies[policy, sub_policy][1] == 1:\n", - " if np.random.uniform() < policies[policy, sub_policy][3]:\n", - " shear = policies[policy, sub_policy][5]\n", - "\n", - " # check for scales\n", - " if policies[policy, sub_policy][0] == 2:\n", - " if np.random.uniform() < policies[policy, sub_policy][2]:\n", - " scale = policies[policy, sub_policy][4]\n", - " elif policies[policy, sub_policy][1] == 2:\n", - " if np.random.uniform() < policies[policy, sub_policy][3]:\n", - " scale = policies[policy, sub_policy][5]\n", - "\n", - " return degrees, shear, scale" - ], - "metadata": { - "id": "QE2VWI8o731X" - }, - "execution_count": 7, - "outputs": [] + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration: 10,\tQ-Values: [0.77, 0.74, 0.8, 0.72, 0.77], Best Policy: 0.8\n", + "CPU times: user 1min 21s, sys: 694 ms, total: 1min 22s\n", + "Wall time: 1min 28s\n" + ] }, { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "vu_4I4qkbx73" - }, - "outputs": [], - "source": [ - "\"\"\"Sample policy, open and apply above transformations\"\"\"\n", - "def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet):\n", - "\n", - " # get number of policies and sub-policies\n", - " num_policies = len(policies)\n", - " num_sub_policies = len(policies[0])\n", - "\n", - " #Initialize vector weights, counts and regret\n", - " 
q_values = [0]*num_policies\n", - " cnts = [0]*num_policies\n", - " q_plus_cnt = [0]*num_policies\n", - " total_count = 0\n", - "\n", - " best_q_values = []\n", - "\n", - " for policy in trange(iterations):\n", - "\n", - " # get the action to try (either initially in order or using best q_plus_cnt value)\n", - " if policy >= num_policies:\n", - " this_policy = np.argmax(q_plus_cnt)\n", - " else:\n", - " this_policy = policy\n", - "\n", - " # get info of transformation for this sub-policy\n", - " degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)\n", - "\n", - " # create transformations using above info\n", - " transform = torchvision.transforms.Compose(\n", - " [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),\n", - " torchvision.transforms.ToTensor()])\n", - "\n", - " # open data and apply these transformations\n", - " if ds == \"MNIST\":\n", - " train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n", - " test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n", - " elif ds == \"KMNIST\":\n", - " train_dataset = datasets.KMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n", - " test_dataset = datasets.KMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n", - " elif ds == \"FashionMNIST\":\n", - " train_dataset = datasets.FashionMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n", - " test_dataset = datasets.FashionMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n", - " elif ds == \"CIFAR10\":\n", - " train_dataset = datasets.CIFAR10(root='./MetaAugment/train', train=True, download=True, transform=transform)\n", - " test_dataset = datasets.CIFAR10(root='./MetaAugment/test', train=False, download=True, transform=transform)\n", - " elif ds == 
\"CIFAR100\":\n", - " train_dataset = datasets.CIFAR100(root='./MetaAugment/train', train=True, download=True, transform=transform)\n", - " test_dataset = datasets.CIFAR100(root='./MetaAugment/test', train=False, download=True, transform=transform)\n", - "\n", - " # check sizes of images\n", - " img_height = len(train_dataset[0][0][0])\n", - " img_width = len(train_dataset[0][0][0][0])\n", - " img_channels = len(train_dataset[0][0])\n", - "\n", - " # check output labels\n", - " if ds == \"CIFAR10\" or ds == \"CIFAR100\":\n", - " num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)\n", - " else:\n", - " num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()\n", - "\n", - " # create toy dataset from above uploaded data\n", - " train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n", - "\n", - " # create model\n", - " device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", - " if IsLeNet == \"LeNet\":\n", - " model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n", - " elif IsLeNet == \"EasyNet\":\n", - " model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n", - " else:\n", - " model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)\n", - " sgd = optim.SGD(model.parameters(), lr=1e-1)\n", - " cost = nn.CrossEntropyLoss()\n", - "\n", - " # set variables for best validation accuracy and early stop count\n", - " best_acc = 0\n", - " early_stop_cnt = 0\n", - " total_val = 0\n", - "\n", - " # train model and check validation accuracy each epoch\n", - " for _epoch in range(max_epochs):\n", - "\n", - " # train model\n", - " model.train()\n", - " for idx, (train_x, train_label) in enumerate(train_loader):\n", - " train_x, train_label = train_x.to(device), train_label.to(device) # new code\n", - " label_np = np.zeros((train_label.shape[0], num_labels))\n", - " 
sgd.zero_grad()\n", - " predict_y = model(train_x.float())\n", - " loss = cost(predict_y, train_label.long())\n", - " loss.backward()\n", - " sgd.step()\n", - "\n", - " # check validation accuracy on validation set\n", - " correct = 0\n", - " _sum = 0\n", - " model.eval()\n", - " for idx, (test_x, test_label) in enumerate(test_loader):\n", - " test_x, test_label = test_x.to(device), test_label.to(device) # new code\n", - " predict_y = model(test_x.float()).detach()\n", - " #predict_ys = np.argmax(predict_y, axis=-1)\n", - " predict_ys = torch.argmax(predict_y, axis=-1) # changed np to torch\n", - " #label_np = test_label.numpy()\n", - " _ = predict_ys == test_label\n", - " #correct += np.sum(_.numpy(), axis=-1)\n", - " correct += np.sum(_.cpu().numpy(), axis=-1) # added .cpu()\n", - " _sum += _.shape[0]\n", - " \n", - " acc = correct / _sum\n", - "\n", - " if average_validation[0] <= _epoch <= average_validation[1]:\n", - " total_val += acc\n", - "\n", - " # update best validation accuracy if it was higher, otherwise increase early stop count\n", - " if acc > best_acc :\n", - " best_acc = acc\n", - " early_stop_cnt = 0\n", - " else:\n", - " early_stop_cnt += 1\n", - "\n", - " # exit if validation gets worse over 10 runs and using early stopping\n", - " if early_stop_cnt >= early_stop_num and early_stop_flag:\n", - " break\n", - "\n", - " # exit if using fixed epoch length\n", - " if _epoch >= average_validation[1] and not early_stop_flag:\n", - " best_acc = total_val / (average_validation[1] - average_validation[0] + 1)\n", - " break\n", - "\n", - " # update q_values\n", - " if policy < num_policies:\n", - " q_values[this_policy] += best_acc\n", - " else:\n", - " q_values[this_policy] = (q_values[this_policy]*cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)\n", - "\n", - " best_q_value = max(q_values)\n", - " best_q_values.append(best_q_value)\n", - "\n", - " if (policy+1) % 10 == 0:\n", - " print(\"Iteration: {},\\tQ-Values: {}, Best Policy: 
{}\".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))\n", - "\n", - " # update counts\n", - " cnts[this_policy] += 1\n", - " total_count += 1\n", - "\n", - " # update q_plus_cnt values every turn after the initial sweep through\n", - " if policy >= num_policies - 1:\n", - " for i in range(num_policies):\n", - " q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])\n", - "\n", - " return q_values, best_q_values" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] }, { - "cell_type": "code", - "source": [ - "%%time\n", - "\n", - "batch_size = 32 # size of batch the inner NN is trained with\n", - "learning_rate = 1e-1 # fix learning rate\n", - "ds = \"MNIST\" # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)\n", - "toy_size = 0.02 # total propeortion of training and test set we use\n", - "max_epochs = 100 # max number of epochs that is run if early stopping is not hit\n", - "early_stop_num = 10 # max number of worse validation scores before early stopping is triggered\n", - "early_stop_flag = True # implement early stopping or not\n", - "average_validation = [15,25] # if not implementing early stopping, what epochs are we averaging over\n", - "num_policies = 5 # fix number of policies\n", - "num_sub_policies = 5 # fix number of sub-policies in a policy\n", - "iterations = 100 # total iterations, should be more than the number of policies\n", - "IsLeNet = \"SimpleNet\" # using LeNet or EasyNet or SimpleNet\n", - "\n", - "# generate random policies at start\n", - "policies = generate_policies(num_policies, num_sub_policies)\n", - "\n", - "q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet)\n", - "\n", - "plt.plot(best_q_values)\n", - "\n", - "best_q_values = np.array(best_q_values)\n", - "save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, 
int(toy_size*100), ds), best_q_values)\n", - "#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 342 - }, - "id": "doHUtJ_tEiA6", - "outputId": "8195ba17-c95f-4b75-d8dc-19c5d76d5e43" - }, - "execution_count": 9, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "100%|██████████| 10/10 [01:28<00:00, 8.84s/it]" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Iteration: 10,\tQ-Values: [0.77, 0.74, 0.8, 0.72, 0.77], Best Policy: 0.8\n", - "CPU times: user 1min 21s, sys: 694 ms, total: 1min 22s\n", - "Wall time: 1min 28s\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "\n" - ] - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "<Figure size 432x288 with 1 Axes>" - ], - "image/png": "\n" - }, - "metadata": { - "needs_background": "light" - } - } + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD7CAYAAABkO19ZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAfGklEQVR4nO3da3BU95nn8e8jtS4gQAJJGFDLRlzGgM3FoCGOwcnYjis2ztpOYjJQm93KVCqeF+PsbDapLc/UrCvl2rzI1m4ymawzVa7Z3VRlZu0F4mScMTGuGXtm1LZjmzsGDGrABombWiAhEJKQ9OyLbtmNEKaBbp3u079PlYru0+d0P93AT6ef8z//Y+6OiIiEV0nQBYiISG4p6EVEQk5BLyIScgp6EZGQU9CLiIScgl5EJOQyCnoze8jMDphZ3MyeHuPxW83sDTPbYWa7zWxNavmDZrbNzPak/rw/229AREQ+nV1rHL2ZlQIHgQeBNuA9YL2770tb53lgh7v/tZktAja7+2wzuws45e7HzexOYIu7N+TqzYiIyJUiGayzEoi7+2EAM3sReAzYl7aOA1NSt6uB4wDuviNtnb3ABDOrcPf+q71YXV2dz549O+M3ICIisG3btoS714/1WCZB3wAcS7vfBnxm1DrfB14zs28DVcAXxnierwLbxwp5M3sSeBLg1ltvZevWrRmUJSIiI8zso6s9lq2DseuBn7t7FFgD/MLMPn5uM7sD+CHwx2Nt7O7Pu3uzuzfX14/5C0lERG5QJkHfDjSm3Y+mlqX7JrABwN3fBiqBOgAziwK/Av69ux+62YJFROT6ZBL07wHzzazJzMqBdcDLo9Y5CjwAYGYLSQZ9h5nVAK8AT7v7m9krW0REMnXNoHf3QeApYAuwH9jg7nvN7FkzezS12neBb5nZLuAF4BueHM7zFDAPeMbMdqZ+pufknYiIyJiuObxyvDU3N7sOxoqIXB8z2+buzWM9pjNjRURCTkEvIhJymYyjlwy4OyfP9bHrWBcHTp5naHg46JJklPoplSyL1nD7jMmUR7SPI8VDQX+Dui9eYk9bN7vauth5rItdx7o43fPJuWBmARYnV0g/FFUeKeGOWVNYGq1hWWMNSxtrmF07EdNfmoSUgj4D/YND7D/Rw65UoO9s6+Jwx4WPH59TV8WqeXUsjVaztLGGhTOnUFlWGmDFMpq703b2IrvaulJ/j938v/eO8fO3PgSgekIZS6LVyeCP1rCksZrpkyuDLVokSzTqZpThYedw4jw7j3WzOxUK+06c49JQ8nOqn1yRCoNkqC9pqKF6Yllg9cqNGxwapvX0+WTwtyXD/8CpHoaGk3/XDTUTWNpYnQz+aA2Lo9VMqtC+keSnTxt1U/RBf7K7L9l6SYX6nrZuevoHAagqL2VJNPnVflljMthnTKnUV/wQuzgwxN7j3al/E93sOtbF0TO9QLIdN3/6JJZ+/G8i2e8vK1W/X4KnoE8515fsq4/01He1dXHqXLKvXlZqLJgx5eM9uGWNNcypn0RpiUK92J25MJDW8kn+AjhzYQCAipF+fyr4l0ZruE39fglAUQZ9/+AQH5zouexg6aFRffWlaS0Y9dUlUyP9/p3HulLtvW72tHdz8dIQkOz3L22sYVm0+uNvhPWTKwKuWsKuKIL+Qv8gW/aeTB0s7Wb/8XMMDCWHONZNSvbVR9ov6qtLto3u9+881s2Bk+dItfsv6/f/27tvU69fsu7Tgj40/9oGBof5Txt2UVVeyuJoNX+0ejbLUntTM6vVV5fcipSWsHDmFBbOnMK6lbcC0DswyN7j55I7H6lfAJv3nOTY2V7+6+OLA65Yiklo9ugBDnWcZ3Ztlfrqkree+r/beftQJ+/8+QNEdBBXsqho5rqZq4OnkuceWTyTzgsDvHvkTNClSBEJVdCL5Ls/uH06E8pKeWXPiaBLkSKioBcZRxPKS3lg4XReff8kg0OaD0nGh4JeZJypfSP
jTUEvMs7UvpHxpqAXGWdq38h4U9CLBEDtGxlPCnqRAKh9I+NJQS8SgAnlpdy/cDpb9qp9I7mnoBcJyJcWzyRxXu0byT0FvUhA1L6R8aKgFwmI2jcyXhT0IgFS+0bGg4JeJEBq38h4UNCLBEjtGxkPCnqRgD2i9o3kmIJeJGD3qX0jOaagFwlYevtmaDi/rvgm4ZBR0JvZQ2Z2wMziZvb0GI/famZvmNkOM9ttZmtSy2tTy8+b2f/MdvEiYTHSvnnnSGfQpUgIXTPozawUeA54GFgErDezRaNW+wtgg7vfBawDfpZa3gf8F+B7WatYJIQ+bt/sVvtGsi+TPfqVQNzdD7v7APAi8NiodRyYkrpdDRwHcPcL7h4jGfgichVq30guZRL0DcCxtPttqWXpvg983czagM3At7NSnUgRUftGciVbB2PXAz939yiwBviFmWX83Gb2pJltNbOtHR0dWSpJpLCofSO5kkkYtwONafejqWXpvglsAHD3t4FKoC7TItz9eXdvdvfm+vr6TDcTCRW1byRXMgn694D5ZtZkZuUkD7a+PGqdo8ADAGa2kGTQa9dc5DqpfSO5cM2gd/dB4ClgC7Cf5OiavWb2rJk9mlrtu8C3zGwX8ALwDXd3ADP7EPgR8A0zaxtjxI6IpKh9I7kQyWQld99M8iBr+rJn0m7vA1ZdZdvZN1GfSFFJb988+9idlJZY0CVJCOjMWJE8o/aNZJuCXiTPjLRvNmvuG8kSBb1Inhlp37z6vkbfSHYo6EXykNo3kk0KepE8pPaNZJOCXiQPTSgv5f4Fat9IdijoRfLUI0vUvpHsUNCL5Cm1byRbFPQieUrtG8kWBb1IHlP7RrJBQS+Sx9S+kWxQ0IvkMbVvJBsU9CJ5bo1OnpKbpKAXyXP3LahX+0ZuioJeJM9NLI+ofSM3RUEvUgDUvpGboaAXKQBq38jNUNCLFIBP2jen1L6R66agFykQyfZNP+8eORN0KVJgFPQiBeK+BfVUlpXwyp7jQZciBUZBL1IgJpZHeGDBLWrfyHVT0IsUELVv5EYo6EUKiNo3ciMU9CIFRO0buREKepECo/aNXC8FvUiBUftGrpeCXqTA6OQpuV4KepEC9MjiWWrfSMYU9CIFaKR9o7lvJBMKepECNNK++a2mLpYMKOhFCpTaN5KpjILezB4yswNmFjezp8d4/FYze8PMdpjZbjNbk/bYn6W2O2BmX8xm8SLFTO0bydQ1g97MSoHngIeBRcB6M1s0arW/ADa4+13AOuBnqW0Xpe7fATwE/Cz1fCJyk9S+kUxlske/Eoi7+2F3HwBeBB4btY4DU1K3q4GRAb6PAS+6e7+7HwHiqecTkSxQ+0YykUnQNwDH0u63pZal+z7wdTNrAzYD376ObTGzJ81sq5lt7ejoyLB0EVH7RjKRrYOx64Gfu3sUWAP8wswyfm53f97dm929ub6+PksliYSf2jfh8c8HTvOvB3Ozo5tJGLcDjWn3o6ll6b4JbABw97eBSqAuw21F5CZo7ptw+Mt/bOUn/9Sak+fOJOjfA+abWZOZlZM8uPryqHWOAg8AmNlCkkHfkVpvnZlVmFkTMB94N1vFiwjcv2C62jcFrrv3Ervbulg1ry4nz3/NoHf3QeApYAuwn+Tomr1m9qyZPZpa7bvAt8xsF/AC8A1P2ktyT38f8CrwJ+4+lIs3IlKs1L4pfG8fTjDscO/83AR9JJOV3H0zyYOs6cueSbu9D1h1lW1/APzgJmoUkWtYs3gmm/ec5N0jZ/js3Nqgy5Hr1NKaYFJFhGWNNTl5fp0ZKxICat8Utlg8wd1zplFWmptIVtCLhIDaN4Xr2JlePursZXWO+vOgoBcJDY2+KUwtrQkAVs/P3dByBb1ISKh9U5hi8Q5mVlcyt74qZ6+hoBcJCbVvCs/QsPNmvJPV8+ows5y9joJeJERG2jfvfaj2TSF4v72b7ouXWJ2jYZUjFPQiITLSvnllt9o3hSAWT/bnc3Wi1AgFvUi
IqH1TWFpaO1g0cwp1kypy+joKepGQUfumMPQODLLto7M5Oxs2nYJeJGTUvikM7xw5w6Uhz3l/HhT0IqEzsTzCfberfZPvYq0JyiMl/P7saTl/LQW9SAg9skTtm3wXa02wcvY0Kstyf3VVBb1ICKl9k99On+vjwKmecWnbgIJeJJTUvslvI8Mqczm/TToFvUhIqX2Tv2KtCWqrylk0c8q4vJ6CXiSkNPdNfnJ3YvEE98yro6Qkd9MepFPQi4TUSPtm8x61b/LJwVPnOd3Tz73j1LYBBb1IqOnkqfzT0toBMG4HYkFBLxJqat/knzfjCebUVzGrZsK4vaaCXiTEqirUvsknA4PDvHPkzLiNthmhoBcJObVv8sf2o2fpHRhS0ItIdql9kz9irQlKS4y759aO6+sq6EVCTu2b/NEST7CssYYplWXj+roKepEioPZN8Lp7L7GnrWvc2zagoBcpCmrfBO+tQwmGnXGZf340Bb1IERhp32jum+C0xBNMqoiwtLFm3F9bQS9SJNYsnklHj9o3QYm1Jrh7Ti1lpeMfuwp6kSJx/4LpVETUvgnC0c5ejp7pDaRtAwp6kaJRVaELhwelJT7+0x6kU9CLFBG1b4IRa00wq7qSOXVVgbx+RkFvZg+Z2QEzi5vZ02M8/mMz25n6OWhmXWmP/dDM3k/9/GE2ixeR66P2zfgbGnbeOtTJ6vl1mI3PtMSjXTPozawUeA54GFgErDezRenruPt33H2Zuy8Dfgq8lNr2EWA5sAz4DPA9MxufmfZF5Apq34y/Pe3ddF+8xOr59YHVkMke/Uog7u6H3X0AeBF47FPWXw+8kLq9CPhXdx909wvAbuChmylYRG7OSPvm3SNq34yHWGpa4lXjPO1BukyCvgE4lna/LbXsCmZ2G9AEvJ5atAt4yMwmmlkdcB/QOMZ2T5rZVjPb2tHRcT31i8h1emDhdCZVRHhpe1vQpRSFltYEd8yaQu2kisBqyPbB2HXAJncfAnD314DNwFsk9/LfBoZGb+Tuz7t7s7s319cH9/VGpBhMLI/wyOKZvLLnBBf6B4MuJ9Qu9A+y/ejZwEbbjMgk6Nu5fC88mlo2lnV80rYBwN1/kOrfPwgYcPBGChWR7FnbHKV3YIhXdFA2p949coZLQ86984Ldgc0k6N8D5ptZk5mVkwzzl0evZGYLgKkk99pHlpWaWW3q9hJgCfBaNgoXkRu34rapzKmrYtNWtW9yqaU1QUWkhObZUwOt45pB7+6DwFPAFmA/sMHd95rZs2b2aNqq64AX3T39UH4Z0GJm+4Dnga+nnk9EAmRmfHVFlHc/PMOHiQtBlxNasXgHK5umUVlWGmgdkUxWcvfNJHvt6cueGXX/+2Ns10dy5I2I5JmvLo/yP147wKZtbXzvi7cHXU7onDrXx8FT5/nq8mjQpejMWJFiNaO6ks/9Xj2btrVpTH0OxFoTQHDTHqRT0IsUsbUrGjl5ro9YPBF0KaHzZjxBbVU5C2cEf46ogl6kiH1h0XRqJpaxceuxa68sGXN3YvEE98yro6QkmGkP0inoRYpYRaSUx5c18Nq+U3T1DgRdTmgcPHWe0z393BvAZQPHoqAXKXJPrIgyMDjMy7uOB11KaLS0Bjst8WgKepEid2dDNQtnTmGjxtRnTSyeYE59FbNqJgRdCqCgFxFg7Yooe9q7+eDkuaBLKXj9g0O8c/hM3rRtQEEvIsDjdzVQVmraq8+C7R91cfHSUKDTEo+moBcRplWV84WFt/DrHe0MDA4HXU5Bi8U7KC0x7p4zLehSPqagFxEgOdFZ54UBXv/gdNClFLRYa4K7GmuYXFkWdCkfU9CLCACfm1/P9MkVbNqmMfU3qqt3gN3t3Xkz2maEgl5EAIiUlvDl5Q28caCD0z19QZdTkN461Ik73KugF5F8tXZFI0PDzq93XO2SE/JpWloTTK6IsDRaE3Qpl1HQi8jH5k2fxPJba9iwtY3LZxyXTMTiHdw9t5ZIaX5Fa35VIyK
BW9vcSPz0eXYe6wq6lILyUecFjp25mHdtG1DQi8goX1oyk8qyEjZu05j669EyMi1xHp0oNUJBLyKXmVxZxpo7Z/Kbnce5ODAUdDkFI9aaoKFmAk11VUGXcgUFvYhc4YnmKD39g2zZezLoUgrC0LDz1qEEq+fVYRb8tMSjKehF5Ap3N9USnTqBjRpTn5HdbV2c6xvMu/HzIxT0InKFkhLjiRVR3jrUSdvZ3qDLyXux1gRmsCoP+/OgoBeRq3hiRfKi1r/cpjH119IST3DHrClMqyoPupQxKehFZEzRqRO5Z24tm7YfY1gXD7+qC/2D7Dh6ltXz8me2ytEU9CJyVWtXNHLszEV+d6Qz6FLy1jtHOrk05Hk5fn6Egl5EruqLd8xgckWETZqn/qpirZ1UREpYcdvUoEu5KgW9iFzVhPJSvrR0FpvfP0FP36Wgy8lLsXgHK5umUVlWGnQpV6WgF5FP9bXmKH2Xhnll94mgS8k7p871cfDU+bw8Gzadgl5EPtWyxhrmTZ/Ehq0aUz9abGTagzzuz4OCXkSuwcxYuyLK9qNdxE+fD7qcvBKLJ6itKmfhjClBl/KpFPQick1fXt5AaYmxSROdfczdicUTrJpXR0lJ/k17kE5BLyLXNH1yJffdXs9L29sYHNLFwwEOnOqho6c/79s2kGHQm9lDZnbAzOJm9vQYj//YzHamfg6aWVfaY//NzPaa2X4z+yvLxxl/ROSanljRyOme/o+n4y12I/35fB4/P+KaQW9mpcBzwMPAImC9mS1KX8fdv+Puy9x9GfBT4KXUtvcAq4AlwJ3A7wOfz+o7EJFxcf+C6UyrKtdB2ZSW1gRz66uYWT0h6FKuKZM9+pVA3N0Pu/sA8CLw2Kesvx54IXXbgUqgHKgAyoBTN16uiASlPFLC48sa+Mf9pzhzYSDocgLVPzjEO0c6uXd+/k57kC6ToG8A0n+Ft6WWXcHMbgOagNcB3P1t4A3gROpni7vvH2O7J81sq5lt7ejouL53ICLjZm1zlEtDzt/vLO6JzrZ9dJa+S8N5P35+RLYPxq4DNrn7EICZzQMWAlGSvxzuN7N7R2/k7s+7e7O7N9fXF8ZvSJFitHDmFBY3VLOxyKdEiLUmiJQYd8+tDbqUjGQS9O1AY9r9aGrZWNbxSdsG4MvA79z9vLufB34LfPZGChWR/LC2Ocq+E+d4v7076FICE4snuOvWGiZVRIIuJSOZBP17wHwzazKzcpJh/vLolcxsATAVeDtt8VHg82YWMbMykgdir2jdiEjheHTpLMpLS4p2TP3ZCwPsae/O62mJR7tm0Lv7IPAUsIVkSG9w971m9qyZPZq26jrgRXdPn7h6E3AI2APsAna5+2+yVr2IjLuaieU8eMct/HpnO/2DxXfx8LcOdeKe/9MepMvoe4e7bwY2j1r2zKj73x9juyHgj2+iPhHJQ19rbuSV3Sf4p/2nWbN4ZtDljKtYvIPJlRGWRquDLiVjOjNWRK7b6nl1zKyuZGORjal3d1paE3x2Ti2R0sKJz8KpVETyRmmJ8ZXlDfzLwQ5OdvcFXc64+aizl7azFwvibNh0CnoRuSFPrGhk2OGlHcVzULYlPjItceEciAUFvYjcoKa6KlbOnsamrW1cPgYjvGKtHTTUTGB27cSgS7kuCnoRuWFPNEc5nLjA9qNngy4l5waHhnnrUCf3zq+j0OZmVNCLyA17ZPFMJpaXsuG98Ldvdrd309M3WFDDKkco6EXkhlVVRFizeCb/sPs4vQODQZeTU7HWBGZwz1wFvYgUmbUrolwYGOK3e04GXUpOxeIJ7pg1hWlV5UGXct0U9CJyU1Y2TWN27UQ2bgvvmPoL/YPsOHq2oKY9SKegF5GbYmY8sSLK7w6f4Whnb9Dl5MQ7Rzq5NOQFN35+hIJeRG7aV5ZHMYNNId2rb2lNUBEpYcVtU4Mu5YYo6EXkps2qmcDqeXX8cns7w8PhG1Mfa02wsmkalWWlQZdyQxT0IpIVX2tupL3
rIm8d6gy6lKw62d1H6+nzBdu2AQW9iGTJg4tuYUplJHQHZWMj0x4U6IFYUNCLSJZUlpXy2LIGXn3/JN0XLwVdTtbEWjuom1TOghmTgy7lhinoRSRr1jZH6R8c5je7jgddSla4O7F4J6vm1VFSUljTHqRT0ItI1ixuqOb2WyazMSSXGfzgZA+J8/2snle4/XlQ0ItIFpkZa5uj7DrWxcFTPUGXc9Nircn+/L0FNi3xaAp6EcmqL9/VQKTEQnH1qZZ4gnnTJzGjujLoUm6Kgl5Esqp2UgX3L5jOr3a0c2loOOhybljfpSHePdJZ8G0bUNCLSA6sbW4kcX6Afz7QEXQpN2z7R2fpuzRc0OPnRyjoRSTr/uD2euomVRR0+6YlniBSYnxmTm3Qpdw0Bb2IZF1ZaQlfWd7A6x+cJnG+P+hybkisNcHyW6cyqSISdCk3TUEvIjmxdkWUwWHn1zvagy7lup29MMD7x7sL8mpSY1HQi0hOzL9lMksba9hYgBcPf/NQAncU9CIi1/K15igHTvWwp7076FKuS6w1weTKCEsaqoMuJSsU9CKSM/9m6SwqIiVs3Fo4Z8q6Oy2tCe6ZW0ukNBwRGY53ISJ5aUplGQ/dOYO/39lO36WhoMvJyIedvbR3XQzF+PkRCnoRyam1Kxo51zfIa/tOBV1KRmKtybH/qwt82oN0CnoRyal75tbSUDOhYMbUx+IJGmomMLt2YtClZE1GQW9mD5nZATOLm9nTYzz+YzPbmfo5aGZdqeX3pS3faWZ9ZvZ4tt+EiOSvkhLjqyuixOIJjnddDLqcTzU4NMxbhzq5d34dZoU7LfFo1wx6MysFngMeBhYB681sUfo67v4dd1/m7suAnwIvpZa/kbb8fqAXeC3L70FE8tzaFVHc4aXt+X1Qdnd7Nz19g6EZVjkikz36lUDc3Q+7+wDwIvDYp6y/HnhhjOVPAL91997rL1NEClnjtIncPWcaG7fl95j6WGsCM1g1t/iCvgFIb661pZZdwcxuA5qA18d4eB1j/wLAzJ40s61mtrWjo3AnQRKRq1u7opGPOnt598iZoEu5qlhrgjtnVTO1qjzoUrIq2wdj1wGb3P2ycVRmNhNYDGwZayN3f97dm929ub4+PEe6ReQTDy+ewaSKSN5efep8/yDbj54NXdsGMgv6dqAx7X40tWwsV9tr/xrwK3cPzxWDReS6TCyP8KUlM9m85wTn+weDLucK7xzuZHDYuTdE4+dHZBL07wHzzazJzMpJhvnLo1cyswXAVODtMZ7jan17ESkia5uj9A4MsXn3iaBLuUJLa4LKshJWzJ4adClZd82gd/dB4CmSbZf9wAZ332tmz5rZo2mrrgNe9FFHWsxsNslvBP+SraJFpDAtv3Uqc+qr2Lgt/8bUx+IJVjbVUhEpDbqUrMtoomV33wxsHrXsmVH3v3+VbT/kKgdvRaS4mBlrVzTyw1c/4EjiAk11VUGXBMCJ7ovET5/nD5sbr71yAdKZsSIyrr6yvIESg015tFcfa00A4ZmWeDQFvYiMq1umVPL536vnl9vaGRrOjzH1sXiCukkVLJgxOehSckJBLyLjbm1zIyfP9dHSGvx5M8PDzpvxBKvn1YZq2oN0CnoRGXcPLJxOzcSyvBhT/8HJHhLnB0I1W+VohX/VWxEpOBWRUh5f1sDf/u4jHvxRsAPyevqSY/rDNP/8aAp6EQnEN1c3cbZ3gEtDw0GXwvzpk5lRXRl0GTmjoBeRQDROm8hP1t0VdBlFQT16EZGQU9CLiIScgl5EJOQU9CIiIaegFxEJOQW9iEjIKehFREJOQS8iEnKWb1dkN7MO4KObeIo6IJGlcgqdPovL6fO4nD6PT4Ths7jN3cecsCfvgv5mmdlWd28Ouo58oM/icvo8LqfP4xNh/yzUuhERCTkFvYhIyIUx6J8PuoA8os/icvo8LqfP4xOh/ixC16MXEZHLhXGPXkRE0ijoRURCLjRBb2YPmdkBM4ub2dNB1xMkM2s
0szfMbJ+Z7TWzPw26pqCZWamZ7TCzfwi6lqCZWY2ZbTKzD8xsv5l9NuiagmRm30n9P3nfzF4ws9BdaioUQW9mpcBzwMPAImC9mS0KtqpADQLfdfdFwN3AnxT55wHwp8D+oIvIEz8BXnX3BcBSivhzMbMG4D8Aze5+J1AKrAu2quwLRdADK4G4ux929wHgReCxgGsKjLufcPftqds9JP8jNwRbVXDMLAo8AvxN0LUEzcyqgc8B/wvA3QfcvSvYqgIXASaYWQSYCBwPuJ6sC0vQNwDH0u63UcTBls7MZgN3Ae8EW0mg/hL4z0DwV6EOXhPQAfyfVCvrb8ysKuiiguLu7cB/B44CJ4Bud38t2KqyLyxBL2Mws0nAL4H/6O7ngq4nCGb2JeC0u28LupY8EQGWA3/t7ncBF4CiPaZlZlNJfvtvAmYBVWb29WCryr6wBH070Jh2P5paVrTMrIxkyP+du78UdD0BWgU8amYfkmzp3W9mfxtsSYFqA9rcfeQb3iaSwV+svgAccfcOd78EvATcE3BNWReWoH8PmG9mTWZWTvJgyssB1xQYMzOSPdj97v6joOsJkrv/mbtH3X02yX8Xr7t76PbYMuXuJ4FjZnZ7atEDwL4ASwraUeBuM5uY+n/zACE8OB0JuoBscPdBM3sK2ELyqPn/dve9AZcVpFXAvwP2mNnO1LI/d/fNAdYk+ePbwN+ldooOA38UcD2Bcfd3zGwTsJ3kaLUdhHA6BE2BICIScmFp3YiIyFUo6EVEQk5BLyIScgp6EZGQU9CLiIScgl5EJOQU9CIiIff/AWm+N7rRaXZ6AAAAAElFTkSuQmCC\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" } - ] -} \ No newline at end of file + ], + "source": [ + "%%time\n", + "\n", + "batch_size = 32 # size of batch the inner NN is trained with\n", + "learning_rate = 1e-1 # fix learning rate\n", + "ds = \"MNIST\" # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)\n", + "toy_size = 1 # total proportion of training and test set we use\n", + "max_epochs = 100 # max number of epochs that is run if early stopping is not hit\n", + "early_stop_num = 10 # max number of worse validation scores before early stopping is triggered\n", + "early_stop_flag = True # implement early stopping or not\n", + "average_validation = [15,25] # if not implementing early stopping, what epochs are we averaging over\n", + "num_policies = 5 # fix number of policies\n", + "num_sub_policies = 5 # fix number of sub-policies in a policy\n", + "iterations = 100 # total iterations, should be more than the number of policies\n", + "IsLeNet = \"SimpleNet\" # using LeNet or EasyNet or SimpleNet\n", + "\n", + "# generate
random policies at start\n", + "policies = generate_policies(num_policies, num_sub_policies)\n", + "\n", + "q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet)\n", + "\n", + "plt.plot(best_q_values)\n", + "\n", + "best_q_values = np.array(best_q_values)\n", + "save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)\n", + "#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "UCB1.ipynb", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +}