diff --git a/MetaAugment/UCB1_JC.ipynb b/MetaAugment/UCB1_JC.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..8cdc7a9b77d317acf2df35a8818a1cd09ae79ff9
--- /dev/null
+++ b/MetaAugment/UCB1_JC.ipynb
@@ -0,0 +1,341 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "name": "UCB1.ipynb",
+      "provenance": [],
+      "collapsed_sections": []
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    }
+  },
+  "cells": [
+    {
+      "cell_type": "code",
+      "source": [
+        "import numpy as np\n",
+        "import torch\n",
+        "import torch.nn as nn\n",
+        "import torch.nn.functional as F\n",
+        "import torch.optim as optim\n",
+        "import torch.utils.data as data_utils\n",
+        "import torchvision\n",
+        "import torchvision.datasets as datasets\n",
+        "\n",
+        "torch.manual_seed(0)  # fix the inner network's initialisation for repeatable runs"
+      ],
+      "metadata": {
+        "id": "U_ZJ2LqDiu_v"
+      },
+      "execution_count": 1,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"Define the internal NN module that trains on the dataset\"\"\"\n",
+        "class LeNet(nn.Module):\n",
+        "    def __init__(self):\n",
+        "        super().__init__()\n",
+        "        self.conv1 = nn.Conv2d(1, 6, 5)\n",
+        "        self.relu1 = nn.ReLU()\n",
+        "        self.pool1 = nn.MaxPool2d(2)\n",
+        "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
+        "        self.relu2 = nn.ReLU()\n",
+        "        self.pool2 = nn.MaxPool2d(2)\n",
+        "        self.fc1 = nn.Linear(256, 120)\n",
+        "        self.relu3 = nn.ReLU()\n",
+        "        self.fc2 = nn.Linear(120, 84)\n",
+        "        self.relu4 = nn.ReLU()\n",
+        "        self.fc3 = nn.Linear(84, 10)\n",
+        "\n",
+        "    def forward(self, x):\n",
+        "        y = self.conv1(x)\n",
+        "        y = self.relu1(y)\n",
+        "        y = self.pool1(y)\n",
+        "        y = self.conv2(y)\n",
+        "        y = self.relu2(y)\n",
+        "        y = self.pool2(y)\n",
+        "        y = y.view(y.shape[0], -1)\n",
+        "        y = self.fc1(y)\n",
+        "        y = self.relu3(y)\n",
+        "        y = self.fc2(y)\n",
+        "        y = self.relu4(y)\n",
+        "        # fc3 outputs raw logits: CrossEntropyLoss applies log-softmax itself,\n",
+        "        # so no ReLU here (it would clamp the logits to be non-negative)\n",
+        "        y = self.fc3(y)\n",
+        "        return y"
+      ],
+      "metadata": {
+        "id": "4ksS_duLFADW"
+      },
+      "execution_count": 2,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"Make a toy dataset\"\"\"\n",
+        "\n",
+        "def create_toy(train_dataset, test_dataset, batch_size, n_samples):\n",
+        "    # shuffle the training set and keep the first n_samples fraction of it\n",
+        "    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist())\n",
+        "    indices_train = torch.arange(int(n_samples*len(train_dataset)))\n",
+        "    reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)\n",
+        "\n",
+        "    # shuffle the test set and keep the first n_samples fraction of it\n",
+        "    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, torch.randperm(len(test_dataset)).tolist())\n",
+        "    indices_test = torch.arange(int(n_samples*len(test_dataset)))\n",
+        "    reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)\n",
+        "\n",
+        "    # push both into DataLoaders\n",
+        "    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)\n",
+        "    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)\n",
+        "\n",
+        "    return train_loader, test_loader"
+      ],
+      "metadata": {
+        "id": "xujQtvVWBgMH"
+      },
+      "execution_count": 3,
+      "outputs": []
+    },
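+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"Optional sanity check of create_toy (illustrative only; not used by the search below).\n",
+        "A minimal sketch assuming plain MNIST with ToTensor: 2% of the 60k/10k images -> 1200/200.\n",
+        "The names below (plain_transform, mnist_train, ...) are throwaway demo names.\"\"\"\n",
+        "\n",
+        "plain_transform = torchvision.transforms.ToTensor()\n",
+        "mnist_train = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=plain_transform)\n",
+        "mnist_test = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=plain_transform)\n",
+        "toy_train, toy_test = create_toy(mnist_train, mnist_test, batch_size=32, n_samples=0.02)\n",
+        "print(len(toy_train.dataset), len(toy_test.dataset))  # expect 1200 200"
+      ],
+      "metadata": {},
+      "execution_count": null,
+      "outputs": []
+    },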
+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"Randomly generate 10 policies\"\"\"\n",
+        "\"\"\"Each policy has 5 sub-policies\"\"\"\n",
+        "\"\"\"For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes\"\"\"\n",
+        "\n",
+        "def generate_policies(num_policies, num_sub_policies):\n",
+        "\n",
+        "    # policies array is num_policies x num_sub_policies x 6 (10x5x6 below)\n",
+        "    policies = np.zeros([num_policies, num_sub_policies, 6])\n",
+        "\n",
+        "    for policy in range(num_policies):\n",
+        "        for sub_policy in range(num_sub_policies):\n",
+        "            # pick two distinct transformations (0=rotate, 1=shear, 2=scale)\n",
+        "            policies[policy, sub_policy, 0] = np.random.randint(0,3)\n",
+        "            policies[policy, sub_policy, 1] = np.random.randint(0,3)\n",
+        "            while policies[policy, sub_policy, 0] == policies[policy, sub_policy, 1]:\n",
+        "                policies[policy, sub_policy, 1] = np.random.randint(0,3)\n",
+        "\n",
+        "            # pick probabilities in {0.0, 0.1, ..., 1.0}\n",
+        "            policies[policy, sub_policy, 2] = np.random.randint(0,11) / 10\n",
+        "            policies[policy, sub_policy, 3] = np.random.randint(0,11) / 10\n",
+        "\n",
+        "            # pick magnitudes: degrees in {-20, ..., 20} (step 5) for rotate/shear,\n",
+        "            # a factor in {0.5, ..., 1.4} (step 0.1) for scale\n",
+        "            for transformation in range(2):\n",
+        "                if policies[policy, sub_policy, transformation] <= 1:\n",
+        "                    policies[policy, sub_policy, transformation + 4] = np.random.randint(-4,5)*5\n",
+        "                elif policies[policy, sub_policy, transformation] == 2:\n",
+        "                    policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10\n",
+        "\n",
+        "    return policies"
+      ],
+      "metadata": {
+        "id": "Iql-c88jGGWy"
+      },
+      "execution_count": 4,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"Pick a sub-policy from a given policy\"\"\"\n",
+        "\"\"\"Ideally each batch of data would get its own sampled sub-policy; one per policy evaluation will do for now\"\"\"\n",
+        "\n",
+        "def sample_sub_policy(policies, policy, num_sub_policies):\n",
+        "    sub_policy = np.random.randint(0,num_sub_policies)\n",
+        "\n",
+        "    # defaults: no rotation, no shear, unit scale\n",
+        "    degrees = 0\n",
+        "    shear = 0\n",
+        "    scale = 1\n",
+        "\n",
+        "    # check for rotations (the two transformations are distinct, so at most one matches)\n",
+        "    if policies[policy, sub_policy][0] == 0:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
+        "            degrees = policies[policy, sub_policy][4]\n",
+        "    elif policies[policy, sub_policy][1] == 0:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
+        "            degrees = policies[policy, sub_policy][5]\n",
+        "\n",
+        "    # check for shears\n",
+        "    if policies[policy, sub_policy][0] == 1:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
+        "            shear = policies[policy, sub_policy][4]\n",
+        "    elif policies[policy, sub_policy][1] == 1:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
+        "            shear = policies[policy, sub_policy][5]\n",
+        "\n",
+        "    # check for scales\n",
+        "    if policies[policy, sub_policy][0] == 2:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][2]:\n",
+        "            scale = policies[policy, sub_policy][4]\n",
+        "    elif policies[policy, sub_policy][1] == 2:\n",
+        "        if np.random.uniform() < policies[policy, sub_policy][3]:\n",
+        "            scale = policies[policy, sub_policy][5]\n",
+        "\n",
+        "    return degrees, shear, scale"
+      ],
+      "metadata": {
+        "id": "QE2VWI8o731X"
+      },
+      "execution_count": 5,
+      "outputs": []
+    },
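+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"Quick illustrative check of the two samplers above (not part of the UCB1 loop).\n",
+        "demo_policies and its small 2x3 grid are throwaway choices; output varies run to run.\"\"\"\n",
+        "\n",
+        "demo_policies = generate_policies(2, 3)\n",
+        "print(demo_policies.shape)  # (2, 3, 6): policy x sub-policy x (op1, op2, p1, p2, mag1, mag2)\n",
+        "for _ in range(3):\n",
+        "    # the (degrees, shear, scale) that would be fed into RandomAffine for policy 0\n",
+        "    print(sample_sub_policy(demo_policies, 0, 3))"
+      ],
+      "metadata": {},
+      "execution_count": null,
+      "outputs": []
+    },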
+    {
+      "cell_type": "code",
+      "execution_count": 6,
+      "metadata": {
+        "id": "vu_4I4qkbx73"
+      },
+      "outputs": [],
+      "source": [
+        "\"\"\"Run the UCB1 bandit: sample a policy, apply its transformations to MNIST and use the best validation accuracy as the reward\"\"\"\n",
+        "def run_UCB1(policies, batch_size, toy_size, max_epochs, early_stop_num, iterations):\n",
+        "\n",
+        "    # get the number of policies and sub-policies\n",
+        "    num_policies = len(policies)\n",
+        "    num_sub_policies = len(policies[0])\n",
+        "\n",
+        "    # initialise per-policy mean rewards, pull counts and UCB scores\n",
+        "    q_values = [0]*num_policies\n",
+        "    cnts = [0]*num_policies\n",
+        "    q_plus_cnt = [0]*num_policies\n",
+        "    total_count = 0\n",
+        "\n",
+        "    for policy in range(iterations):\n",
+        "\n",
+        "        # pick the policy to try (each one once in order, then the best q_plus_cnt value)\n",
+        "        if policy >= num_policies:\n",
+        "            this_policy = np.argmax(q_plus_cnt)\n",
+        "        else:\n",
+        "            this_policy = policy\n",
+        "\n",
+        "        # sample one sub-policy of this policy and get its transformation parameters\n",
+        "        degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)\n",
+        "\n",
+        "        # create the corresponding transformations\n",
+        "        transform = torchvision.transforms.Compose(\n",
+        "            [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),\n",
+        "             torchvision.transforms.ToTensor()])\n",
+        "\n",
+        "        # open the data and apply these transformations\n",
+        "        train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)\n",
+        "        test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)\n",
+        "\n",
+        "        # create a toy dataset from the data loaded above\n",
+        "        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n",
+        "\n",
+        "        # create the model\n",
+        "        model = LeNet()\n",
+        "        sgd = optim.SGD(model.parameters(), lr=1e-1)\n",
+        "        cost = nn.CrossEntropyLoss()\n",
+        "\n",
+        "        # set variables for the best validation accuracy and the early-stop count\n",
+        "        best_acc = 0\n",
+        "        early_stop_cnt = 0\n",
+        "\n",
+        "        # train the model, checking validation accuracy after each epoch\n",
+        "        for _epoch in range(max_epochs):\n",
+        "\n",
+        "            # train\n",
+        "            model.train()\n",
+        "            for train_x, train_label in train_loader:\n",
+        "                sgd.zero_grad()\n",
+        "                predict_y = model(train_x.float())\n",
+        "                loss = cost(predict_y, train_label.long())\n",
+        "                loss.backward()\n",
+        "                sgd.step()\n",
+        "\n",
+        "            # check accuracy on the reduced test set, used here as a validation set\n",
+        "            correct = 0\n",
+        "            _sum = 0\n",
+        "            model.eval()\n",
+        "            with torch.no_grad():\n",
+        "                for test_x, test_label in test_loader:\n",
+        "                    predict_y = model(test_x.float())\n",
+        "                    predict_ys = torch.argmax(predict_y, dim=-1)\n",
+        "                    correct += (predict_ys == test_label).sum().item()\n",
+        "                    _sum += test_label.shape[0]\n",
+        "\n",
+        "            # update the best validation accuracy if it improved, otherwise increase the early-stop count\n",
+        "            acc = correct / _sum\n",
+        "            if acc > best_acc:\n",
+        "                best_acc = acc\n",
+        "                early_stop_cnt = 0\n",
+        "            else:\n",
+        "                early_stop_cnt += 1\n",
+        "\n",
+        "            # exit if validation accuracy has not improved for early_stop_num epochs\n",
+        "            if early_stop_cnt >= early_stop_num:\n",
+        "                break\n",
+        "\n",
+        "        # update q_values with the reward (the best validation accuracy)\n",
+        "        if policy < num_policies:\n",
+        "            q_values[this_policy] += best_acc\n",
+        "        else:\n",
+        "            q_values[this_policy] = (q_values[this_policy]*cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)\n",
+        "\n",
+        "        print(q_values)  # progress: running mean reward per policy\n",
+        "\n",
+        "        # update counts\n",
+        "        cnts[this_policy] += 1\n",
+        "        total_count += 1\n",
+        "\n",
+        "        # update the q_plus_cnt values every turn after the initial sweep through\n",
+        "        if policy >= num_policies - 1:\n",
+        "            for i in range(num_policies):\n",
+        "                q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])\n",
+        "\n",
+        "    return q_values"
+      ]
+    },
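+    {
+      "cell_type": "code",
+      "source": [
+        "\"\"\"How a policy gets chosen after the initial sweep: a sketch with made-up numbers,\n",
+        "mirroring the q_plus_cnt update inside run_UCB1 above.\"\"\"\n",
+        "\n",
+        "fake_q = [0.85, 0.90]   # hypothetical mean rewards\n",
+        "fake_cnts = [1, 9]      # arm 0 has barely been tried\n",
+        "t = sum(fake_cnts)\n",
+        "ucb = [q + np.sqrt(2*np.log(t)/n) for q, n in zip(fake_q, fake_cnts)]\n",
+        "print(ucb)                  # ~[3.00, 1.62]\n",
+        "print(int(np.argmax(ucb)))  # 0: the exploration bonus overrides the lower mean"
+      ],
+      "metadata": {},
+      "execution_count": null,
+      "outputs": []
+    },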
+    {
+      "cell_type": "code",
+      "source": [
+        "%%time\n",
+        "\n",
+        "batch_size = 32      # size of batch the inner NN is trained with\n",
+        "toy_size = 0.02      # total proportion of the training and test sets we use\n",
+        "max_epochs = 100     # max number of epochs that is run if early stopping is not hit\n",
+        "early_stop_num = 10  # max number of worse validation scores before early stopping\n",
+        "iterations = 20      # total iterations, should be more than the number of policies\n",
+        "\n",
+        "# generate the policies and sub-policies\n",
+        "num_policies = 10\n",
+        "num_sub_policies = 5\n",
+        "policies = generate_policies(num_policies, num_sub_policies)\n",
+        "\n",
+        "q_values = run_UCB1(policies, batch_size, toy_size, max_epochs, early_stop_num, iterations)\n",
+        "#print(q_values)"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "id": "doHUtJ_tEiA6",
+        "outputId": "6735e812-f7be-4f8b-cec2-52a069f7731b"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
\ No newline at end of file