Commit 19d569df authored 2 years ago by John Carter
ucb1 set for evaluation
parent 06ae5603
Changes: 1 changed file

MetaAugment/UCB1_JC.ipynb (+483 additions, −473 deletions)
The notebook was re-saved with standard Jupyter JSON key ordering (nbformat_minor 0 → 1, fuller language_info block), so the diff touches nearly every line of the file. The one source-level change visible in the diff is in the final cell: toy_size goes from 0.02 to 1, i.e. the UCB1 evaluation now uses the full training and test sets rather than a 2% toy subset. The file at 19d569df renders cell by cell as follows.

%% Cell type:code id: tags:

``` python
import numpy as np
import torch
torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
import torchvision.datasets as datasets

from matplotlib import pyplot as plt
from numpy import save, load
from tqdm import trange
```

%% Cell type:code id: tags:

``` python
"""Define internal NN module that trains on the dataset"""
class LeNet(nn.Module):
    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(img_channels, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(int((((img_height-4)/2-4)/2)*(((img_width-4)/2-4)/2)*16), 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, num_labels)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y
```

%% Cell type:code id: tags:

``` python
"""Define internal NN module that trains on the dataset"""
class EasyNet(nn.Module):
    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        self.fc1 = nn.Linear(img_height*img_width*img_channels, 2048)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(2048, num_labels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        y = x.view(x.shape[0], -1)
        y = self.fc1(y)
        y = self.relu1(y)
        y = self.fc2(y)
        y = self.relu2(y)
        return y
```

%% Cell type:code id: tags:

``` python
"""Define internal NN module that trains on the dataset"""
class SimpleNet(nn.Module):
    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        self.fc1 = nn.Linear(img_height*img_width*img_channels, num_labels)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        y = x.view(x.shape[0], -1)
        y = self.fc1(y)
        y = self.relu1(y)
        return y
```

%% Cell type:code id: tags:

``` python
"""Make toy dataset"""

def create_toy(train_dataset, test_dataset, batch_size, n_samples):

    # shuffle and take first n_samples %age of training dataset
    shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset))
    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)
    indices_train = torch.arange(int(n_samples*len(train_dataset)))
    reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)

    # shuffle and take first n_samples %age of test dataset
    shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset))
    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
    indices_test = torch.arange(int(n_samples*len(test_dataset)))
    reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)

    # push into DataLoader
    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)

    return train_loader, test_loader
```

%% Cell type:code id: tags:

``` python
"""Randomly generate 10 policies"""
"""Each policy has 5 sub-policies"""
"""For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes"""

def generate_policies(num_policies, num_sub_policies):

    policies = np.zeros([num_policies,num_sub_policies,6])

    # Policies array will be 10x5x6
    for policy in range(num_policies):
        for sub_policy in range(num_sub_policies):
            # pick two sub_policy transformations (0=rotate, 1=shear, 2=scale)
            policies[policy, sub_policy, 0] = np.random.randint(0,3)
            policies[policy, sub_policy, 1] = np.random.randint(0,3)
            while policies[policy, sub_policy, 0] == policies[policy, sub_policy, 1]:
                policies[policy, sub_policy, 1] = np.random.randint(0,3)

            # pick probabilities
            policies[policy, sub_policy, 2] = np.random.randint(0,11) / 10
            policies[policy, sub_policy, 3] = np.random.randint(0,11) / 10

            # pick magnitudes
            for transformation in range(2):
                if policies[policy, sub_policy, transformation] <= 1:
                    policies[policy, sub_policy, transformation + 4] = np.random.randint(-4,5)*5
                elif policies[policy, sub_policy, transformation] == 2:
                    policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10

    return policies
```

%% Cell type:code id: tags:

``` python
"""Pick policy and sub-policy"""
"""Each row of data should have a different sub-policy but for now, this will do"""

def sample_sub_policy(policies, policy, num_sub_policies):
    sub_policy = np.random.randint(0,num_sub_policies)

    degrees = 0
    shear = 0
    scale = 1

    # check for rotations
    if policies[policy, sub_policy][0] == 0:
        if np.random.uniform() < policies[policy, sub_policy][2]:
            degrees = policies[policy, sub_policy][4]
    elif policies[policy, sub_policy][1] == 0:
        if np.random.uniform() < policies[policy, sub_policy][3]:
            degrees = policies[policy, sub_policy][5]

    # check for shears
    if policies[policy, sub_policy][0] == 1:
        if np.random.uniform() < policies[policy, sub_policy][2]:
            shear = policies[policy, sub_policy][4]
    elif policies[policy, sub_policy][1] == 1:
        if np.random.uniform() < policies[policy, sub_policy][3]:
            shear = policies[policy, sub_policy][5]

    # check for scales
    if policies[policy, sub_policy][0] == 2:
        if np.random.uniform() < policies[policy, sub_policy][2]:
            scale = policies[policy, sub_policy][4]
    elif policies[policy, sub_policy][1] == 2:
        if np.random.uniform() < policies[policy, sub_policy][3]:
            scale = policies[policy, sub_policy][5]

    return degrees, shear, scale
```

%% Cell type:code id: tags:

``` python
"""Sample policy, open and apply above transformations"""
def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet):

    # get number of policies and sub-policies
    num_policies = len(policies)
    num_sub_policies = len(policies[0])

    #Initialize vector weights, counts and regret
    q_values = [0]*num_policies
    cnts = [0]*num_policies
    q_plus_cnt = [0]*num_policies
    total_count = 0

    best_q_values = []

    for policy in trange(iterations):

        # get the action to try (either initially in order or using best q_plus_cnt value)
        if policy >= num_policies:
            this_policy = np.argmax(q_plus_cnt)
        else:
            this_policy = policy

        # get info of transformation for this sub-policy
        degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)

        # create transformations using above info
        transform = torchvision.transforms.Compose(
            [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),
            torchvision.transforms.ToTensor()])

        # open data and apply these transformations
        if ds == "MNIST":
            train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
            test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
        elif ds == "KMNIST":
            train_dataset = datasets.KMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
            test_dataset = datasets.KMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
        elif ds == "FashionMNIST":
            train_dataset = datasets.FashionMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
            test_dataset = datasets.FashionMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
        elif ds == "CIFAR10":
            train_dataset = datasets.CIFAR10(root='./MetaAugment/train', train=True, download=True, transform=transform)
            test_dataset = datasets.CIFAR10(root='./MetaAugment/test', train=False, download=True, transform=transform)
        elif ds == "CIFAR100":
            train_dataset = datasets.CIFAR100(root='./MetaAugment/train', train=True, download=True, transform=transform)
            test_dataset = datasets.CIFAR100(root='./MetaAugment/test', train=False, download=True, transform=transform)

        # check sizes of images
        img_height = len(train_dataset[0][0][0])
        img_width = len(train_dataset[0][0][0][0])
        img_channels = len(train_dataset[0][0])

        # check output labels
        if ds == "CIFAR10" or ds == "CIFAR100":
            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
        else:
            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()

        # create toy dataset from above uploaded data
        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)

        # create model
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if IsLeNet == "LeNet":
            model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
        elif IsLeNet == "EasyNet":
            model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
        else:
            model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
        sgd = optim.SGD(model.parameters(), lr=1e-1)
        cost = nn.CrossEntropyLoss()

        # set variables for best validation accuracy and early stop count
        best_acc = 0
        early_stop_cnt = 0
        total_val = 0

        # train model and check validation accuracy each epoch
        for _epoch in range(max_epochs):

            # train model
            model.train()
            for idx, (train_x, train_label) in enumerate(train_loader):
                train_x, train_label = train_x.to(device), train_label.to(device) # new code
                label_np = np.zeros((train_label.shape[0], num_labels))
                sgd.zero_grad()
                predict_y = model(train_x.float())
                loss = cost(predict_y, train_label.long())
                loss.backward()
                sgd.step()

            # check validation accuracy on validation set
            correct = 0
            _sum = 0
            model.eval()
            for idx, (test_x, test_label) in enumerate(test_loader):
                test_x, test_label = test_x.to(device), test_label.to(device) # new code
                predict_y = model(test_x.float()).detach()
                #predict_ys = np.argmax(predict_y, axis=-1)
                predict_ys = torch.argmax(predict_y, axis=-1) # changed np to torch
                #label_np = test_label.numpy()
                _ = predict_ys == test_label
                #correct += np.sum(_.numpy(), axis=-1)
                correct += np.sum(_.cpu().numpy(), axis=-1) # added .cpu()
                _sum += _.shape[0]

            acc = correct / _sum

            if average_validation[0] <= _epoch <= average_validation[1]:
                total_val += acc

            # update best validation accuracy if it was higher, otherwise increase early stop count
            if acc > best_acc :
                best_acc = acc
                early_stop_cnt = 0
            else:
                early_stop_cnt += 1

            # exit if validation gets worse over 10 runs and using early stopping
            if early_stop_cnt >= early_stop_num and early_stop_flag:
                break

            # exit if using fixed epoch length
            if _epoch >= average_validation[1] and not early_stop_flag:
                best_acc = total_val / (average_validation[1] - average_validation[0] + 1)
                break

        # update q_values
        if policy < num_policies:
            q_values[this_policy] += best_acc
        else:
            q_values[this_policy] = (q_values[this_policy]*cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)

        best_q_value = max(q_values)
        best_q_values.append(best_q_value)

        if (policy+1) % 10 == 0:
            print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))

        # update counts
        cnts[this_policy] += 1
        total_count += 1

        # update q_plus_cnt values every turn after the initial sweep through
        if policy >= num_policies - 1:
            for i in range(num_policies):
                q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])

    return q_values, best_q_values
```

%% Cell type:code id: tags:

``` python
%%time

batch_size = 32       # size of batch the inner NN is trained with
learning_rate = 1e-1  # fix learning rate
ds = "MNIST"          # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
toy_size = 1          # total propeortion of training and test set we use
max_epochs = 100      # max number of epochs that is run if early stopping is not hit
early_stop_num = 10   # max number of worse validation scores before early stopping is triggered
early_stop_flag = True        # implement early stopping or not
average_validation = [15,25]  # if not implementing early stopping, what epochs are we averaging over
num_policies = 5      # fix number of policies
num_sub_policies = 5  # fix number of sub-policies in a policy
iterations = 100      # total iterations, should be more than the number of policies
IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet

# generate random policies at start
policies = generate_policies(num_policies, num_sub_policies)

q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet)

plt.plot(best_q_values)

best_q_values = np.array(best_q_values)
save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
```

%% Output (stored in the committed notebook)

    100%|██████████| 10/10 [01:28<00:00, 8.84s/it]

    Iteration: 10,	Q-Values: [0.77, 0.74, 0.8, 0.72, 0.77], Best Policy: 0.8
    CPU times: user 1min 21s, sys: 694 ms, total: 1min 22s
    Wall time: 1min 28s

    [Figure: matplotlib plot of best_q_values over UCB1 iterations]
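A quick note on the algorithm for anyone reviewing this change: run_UCB1 evaluates each policy once in order and thereafter always picks the policy with the largest q_values[i] + sqrt(2*ln(total_count)/cnts[i]), i.e. plain UCB1 over augmentation policies with best validation accuracy as the reward. A minimal standalone sketch of that selection step, using made-up numbers rather than anything from this run:

``` python
import numpy as np

# Illustrative state after a few evaluations: mean validation accuracy and
# evaluation count per augmentation policy ("arm"). Values are hypothetical.
q_values = np.array([0.77, 0.74, 0.80])
cnts = np.array([4, 3, 5])
total_count = cnts.sum()

# UCB1 score = empirical mean + exploration bonus; arms tried less often get a
# larger bonus, so the search keeps revisiting uncertain policies.
q_plus_cnt = q_values + np.sqrt(2 * np.log(total_count) / cnts)

print(np.round(q_plus_cnt, 3))     # [1.885 2.027 1.797]
print(int(np.argmax(q_plus_cnt)))  # policy 1 is evaluated next
```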
%% Cell type:code id: tags:
%% Cell type:code id: tags:
```
```
python
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
torch.optim
as
optim
import
torch.utils.data
as
data_utils
import
torch.utils.data
as
data_utils
import
torchvision
import
torchvision
import
torchvision.datasets
as
datasets
import
torchvision.datasets
as
datasets
from
matplotlib
import
pyplot
as
plt
from
matplotlib
import
pyplot
as
plt
from
numpy
import
save
,
load
from
numpy
import
save
,
load
from
tqdm
import
trange
from
tqdm
import
trange
```
```
%% Cell type:code id: tags:

``` python
"""Define internal NN module that trains on the dataset"""
class LeNet(nn.Module):
    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(img_channels, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(int((((img_height-4)/2-4)/2)*(((img_width-4)/2-4)/2)*16), 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, num_labels)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y
```
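The `fc1` input size above follows from two 5×5 valid convolutions and two 2×2 max-pools. As a quick sanity check on that formula (an illustrative sketch, not part of the committed notebook), a 28×28 MNIST input flattens to 256 features:

``` python
# Illustrative check only: each 5x5 valid conv trims 4 pixels, each 2x2 max-pool halves the size.
img_height = img_width = 28              # assumed MNIST input size
h = ((img_height - 4) / 2 - 4) / 2       # 28 -> 24 -> 12 -> 8 -> 4
w = ((img_width - 4) / 2 - 4) / 2
assert int(h * w * 16) == 256            # so fc1 is effectively nn.Linear(256, 120) for MNIST
```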
%% Cell type:code id: tags:

``` python
"""Define internal NN module that trains on the dataset"""
class EasyNet(nn.Module):
    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        self.fc1 = nn.Linear(img_height*img_width*img_channels, 2048)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(2048, num_labels)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        y = x.view(x.shape[0], -1)
        y = self.fc1(y)
        y = self.relu1(y)
        y = self.fc2(y)
        y = self.relu2(y)
        return y
```
%% Cell type:code id: tags:

``` python
"""Define internal NN module that trains on the dataset"""
class SimpleNet(nn.Module):
    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        self.fc1 = nn.Linear(img_height*img_width*img_channels, num_labels)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        y = x.view(x.shape[0], -1)
        y = self.fc1(y)
        y = self.relu1(y)
        return y
```
%% Cell type:code id: tags:

``` python
"""Make toy dataset"""
def create_toy(train_dataset, test_dataset, batch_size, n_samples):
    # shuffle and take first n_samples %age of training dataset
    shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset))
    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)
    indices_train = torch.arange(int(n_samples*len(train_dataset)))
    reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)

    # shuffle and take first n_samples %age of test dataset
    shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset))
    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
    indices_test = torch.arange(int(n_samples*len(test_dataset)))
    reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)

    # push into DataLoader
    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)

    return train_loader, test_loader
```
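A minimal usage sketch of `create_toy` (illustrative only; it assumes the imports above and downloads MNIST into the notebook's `./MetaAugment` folders):

``` python
# Illustrative only: build a 1% toy split of MNIST with batch size 32.
to_tensor = torchvision.transforms.ToTensor()
train_ds = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=to_tensor)
test_ds = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=to_tensor)
train_loader, test_loader = create_toy(train_ds, test_ds, batch_size=32, n_samples=0.01)
print(len(train_loader.dataset), len(test_loader.dataset))  # 600 train and 100 test images
```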
%% Cell type:code id: tags:

``` python
"""Randomly generate 10 policies"""
"""Each policy has 5 sub-policies"""
"""For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes"""
def generate_policies(num_policies, num_sub_policies):

    policies = np.zeros([num_policies, num_sub_policies, 6])

    # Policies array will be 10x5x6
    for policy in range(num_policies):
        for sub_policy in range(num_sub_policies):
            # pick two sub_policy transformations (0=rotate, 1=shear, 2=scale)
            policies[policy, sub_policy, 0] = np.random.randint(0,3)
            policies[policy, sub_policy, 1] = np.random.randint(0,3)
            while policies[policy, sub_policy, 0] == policies[policy, sub_policy, 1]:
                policies[policy, sub_policy, 1] = np.random.randint(0,3)

            # pick probabilities
            policies[policy, sub_policy, 2] = np.random.randint(0,11) / 10
            policies[policy, sub_policy, 3] = np.random.randint(0,11) / 10

            # pick magnitudes
            for transformation in range(2):
                if policies[policy, sub_policy, transformation] <= 1:
                    policies[policy, sub_policy, transformation + 4] = np.random.randint(-4,5)*5
                elif policies[policy, sub_policy, transformation] == 2:
                    policies[policy, sub_policy, transformation + 4] = np.random.randint(5,15)/10

    return policies
```
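To make the 6-column layout concrete, here is an illustrative read-out of one sub-policy row (the unpacking below is a sketch, not notebook code):

``` python
# Illustrative only: inspect one randomly generated sub-policy.
policies = generate_policies(num_policies=5, num_sub_policies=5)
print(policies.shape)                       # (5, 5, 6): policies x sub-policies x parameters
op1, op2, p1, p2, mag1, mag2 = policies[0, 0]
# op1, op2: distinct transformation ids (0=rotate, 1=shear, 2=scale)
# p1, p2:   application probabilities in {0.0, 0.1, ..., 1.0}
# mag1, mag2: -20..20 degrees/shear in steps of 5, or a scale factor in [0.5, 1.4]
```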
%% Cell type:code id: tags:

``` python
"""Pick policy and sub-policy"""
"""Each row of data should have a different sub-policy but for now, this will do"""
def sample_sub_policy(policies, policy, num_sub_policies):
    sub_policy = np.random.randint(0, num_sub_policies)

    degrees = 0
    shear = 0
    scale = 1

    # check for rotations
    if policies[policy, sub_policy][0] == 0:
        if np.random.uniform() < policies[policy, sub_policy][2]:
            degrees = policies[policy, sub_policy][4]
    elif policies[policy, sub_policy][1] == 0:
        if np.random.uniform() < policies[policy, sub_policy][3]:
            degrees = policies[policy, sub_policy][5]

    # check for shears
    if policies[policy, sub_policy][0] == 1:
        if np.random.uniform() < policies[policy, sub_policy][2]:
            shear = policies[policy, sub_policy][4]
    elif policies[policy, sub_policy][1] == 1:
        if np.random.uniform() < policies[policy, sub_policy][3]:
            shear = policies[policy, sub_policy][5]

    # check for scales
    if policies[policy, sub_policy][0] == 2:
        if np.random.uniform() < policies[policy, sub_policy][2]:
            scale = policies[policy, sub_policy][4]
    elif policies[policy, sub_policy][1] == 2:
        if np.random.uniform() < policies[policy, sub_policy][3]:
            scale = policies[policy, sub_policy][5]

    return degrees, shear, scale
```
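A short sketch of how a sampled sub-policy feeds the affine transform (illustrative only; it mirrors the `Compose` call inside `run_UCB1` below):

``` python
# Illustrative only: draw one augmentation setting from the first policy.
policies = generate_policies(5, 5)
degrees, shear, scale = sample_sub_policy(policies, 0, 5)
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomAffine(degrees=(degrees, degrees),
                                        shear=(shear, shear),
                                        scale=(scale, scale)),
    torchvision.transforms.ToTensor()])
```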
%% Cell type:code id: tags:

``` python
"""Sample policy, open and apply above transformations"""
def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet):

    # get number of policies and sub-policies
    num_policies = len(policies)
    num_sub_policies = len(policies[0])

    # Initialize vector weights, counts and regret
    q_values = [0]*num_policies
    cnts = [0]*num_policies
    q_plus_cnt = [0]*num_policies
    total_count = 0

    best_q_values = []

    for policy in trange(iterations):

        # get the action to try (either initially in order or using best q_plus_cnt value)
        if policy >= num_policies:
            this_policy = np.argmax(q_plus_cnt)
        else:
            this_policy = policy

        # get info of transformation for this sub-policy
        degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)

        # create transformations using above info
        transform = torchvision.transforms.Compose(
            [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),
             torchvision.transforms.ToTensor()])

        # open data and apply these transformations
        if ds == "MNIST":
            train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
            test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
        elif ds == "KMNIST":
            train_dataset = datasets.KMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
            test_dataset = datasets.KMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
        elif ds == "FashionMNIST":
            train_dataset = datasets.FashionMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
            test_dataset = datasets.FashionMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
        elif ds == "CIFAR10":
            train_dataset = datasets.CIFAR10(root='./MetaAugment/train', train=True, download=True, transform=transform)
            test_dataset = datasets.CIFAR10(root='./MetaAugment/test', train=False, download=True, transform=transform)
        elif ds == "CIFAR100":
            train_dataset = datasets.CIFAR100(root='./MetaAugment/train', train=True, download=True, transform=transform)
            test_dataset = datasets.CIFAR100(root='./MetaAugment/test', train=False, download=True, transform=transform)

        # check sizes of images
        img_height = len(train_dataset[0][0][0])
        img_width = len(train_dataset[0][0][0][0])
        img_channels = len(train_dataset[0][0])

        # check output labels
        if ds == "CIFAR10" or ds == "CIFAR100":
            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
        else:
            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()

        # create toy dataset from above uploaded data
        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)

        # create model
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if IsLeNet == "LeNet":
            model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
        elif IsLeNet == "EasyNet":
            model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
        else:
            model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
        sgd = optim.SGD(model.parameters(), lr=1e-1)
        cost = nn.CrossEntropyLoss()

        # set variables for best validation accuracy and early stop count
        best_acc = 0
        early_stop_cnt = 0
        total_val = 0

        # train model and check validation accuracy each epoch
        for _epoch in range(max_epochs):

            # train model
            model.train()
            for idx, (train_x, train_label) in enumerate(train_loader):
                train_x, train_label = train_x.to(device), train_label.to(device) # new code
                label_np = np.zeros((train_label.shape[0], num_labels))
                sgd.zero_grad()
                predict_y = model(train_x.float())
                loss = cost(predict_y, train_label.long())
                loss.backward()
                sgd.step()

            # check validation accuracy on validation set
            correct = 0
            _sum = 0
            model.eval()
            for idx, (test_x, test_label) in enumerate(test_loader):
                test_x, test_label = test_x.to(device), test_label.to(device) # new code
                predict_y = model(test_x.float()).detach()
                #predict_ys = np.argmax(predict_y, axis=-1)
                predict_ys = torch.argmax(predict_y, axis=-1) # changed np to torch
                #label_np = test_label.numpy()
                _ = predict_ys == test_label
                #correct += np.sum(_.numpy(), axis=-1)
                correct += np.sum(_.cpu().numpy(), axis=-1) # added .cpu()
                _sum += _.shape[0]

            acc = correct / _sum

            if average_validation[0] <= _epoch <= average_validation[1]:
                total_val += acc

            # update best validation accuracy if it was higher, otherwise increase early stop count
            if acc > best_acc:
                best_acc = acc
                early_stop_cnt = 0
            else:
                early_stop_cnt += 1

            # exit if validation gets worse over 10 runs and using early stopping
            if early_stop_cnt >= early_stop_num and early_stop_flag:
                break

            # exit if using fixed epoch length
            if _epoch >= average_validation[1] and not early_stop_flag:
                best_acc = total_val / (average_validation[1] - average_validation[0] + 1)
                break

        # update q_values
        if policy < num_policies:
            q_values[this_policy] += best_acc
        else:
            q_values[this_policy] = (q_values[this_policy]*cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)

        best_q_value = max(q_values)
        best_q_values.append(best_q_value)

        if (policy+1) % 10 == 0:
            print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))

        # update counts
        cnts[this_policy] += 1
        total_count += 1

        # update q_plus_cnt values every turn after the initial sweep through
        if policy >= num_policies - 1:
            for i in range(num_policies):
                q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])

    return q_values, best_q_values
```
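The arm selection above is the standard UCB1 rule: after one initial pass through all policies, the next policy is the one maximising its mean validation accuracy plus the exploration bonus sqrt(2·ln N / n_i). A standalone sketch of that bookkeeping with made-up numbers (not notebook output):

``` python
# Illustrative UCB1 bookkeeping for three policies after 12 evaluations in total.
q_values = [0.78, 0.74, 0.80]   # running mean validation accuracy per policy
cnts = [5, 3, 4]                # how often each policy has been evaluated
total_count = sum(cnts)
q_plus_cnt = [q + np.sqrt(2 * np.log(total_count) / n) for q, n in zip(q_values, cnts)]
next_policy = int(np.argmax(q_plus_cnt))   # under-explored policies get a larger bonus
```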
%% Cell type:code id: tags:

``` python
%%time

batch_size = 32 # size of batch the inner NN is trained with
learning_rate = 1e-1 # fix learning rate
ds = "MNIST" # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
toy_size = 1 # total proportion of training and test set we use
max_epochs = 100 # max number of epochs that is run if early stopping is not hit
early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
early_stop_flag = True # implement early stopping or not
average_validation = [15,25] # if not implementing early stopping, what epochs are we averaging over
num_policies = 5 # fix number of policies
num_sub_policies = 5 # fix number of sub-policies in a policy
iterations = 100 # total iterations, should be more than the number of policies
IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet

# generate random policies at start
policies = generate_policies(num_policies, num_sub_policies)

q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet)

plt.plot(best_q_values)

best_q_values = np.array(best_q_values)
save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
#best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
```
%% Output

100%|██████████| 10/10 [01:28<00:00, 8.84s/it]

Iteration: 10, 	Q-Values: [0.77, 0.74, 0.8, 0.72, 0.77], Best Policy: 0.8

CPU times: user 1min 21s, sys: 694 ms, total: 1min 22s
Wall time: 1min 28s

...