MetaRL · Commit 63a7084c

Authored 3 years ago by Sun Jin Kim
Parent: 3a200057

John: Add EasyNet
Changes: 2 changed files, with 319 additions and 475 deletions

  MetaAugment/Baseline_JC.ipynb   +78  −192
  MetaAugment/UCB1_JC.ipynb       +241 −283
MetaAugment/Baseline_JC.ipynb (+78, −192)
@@ -62,7 +62,33 @@
   },
   {
    "cell_type": "code",
+   "source": [
+    "\"\"\"Define internal NN module that trains on the dataset\"\"\"\n",
+    "class EasyNet(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.fc1 = nn.Linear(784, 2048)\n",
+    "        self.relu1 = nn.ReLU()\n",
+    "        self.fc2 = nn.Linear(2048, 10)\n",
+    "        self.relu2 = nn.ReLU()\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        y = x.view(x.shape[0], -1)\n",
+    "        y = self.fc1(y)\n",
+    "        y = self.relu1(y)\n",
+    "        y = self.fc2(y)\n",
+    "        y = self.relu2(y)\n",
+    "        return y"
+   ],
+   "metadata": {
+    "id": "ukf2-C94UWzs"
+   },
    "execution_count": 3,
+   "outputs": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
    "metadata": {
     "id": "xujQtvVWBgMH"
    },
@@ -93,13 +119,13 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "metadata": {
     "id": "vu_4I4qkbx73"
    },
    "outputs": [],
    "source": [
-    "def run_baseline(batch_size=32, toy_size=0.02, max_epochs=100, early_stop_num=10, early_stop_flag=True, average_validation=[15,25]):\n",
+    "def run_baseline(batch_size=32, toy_size=0.02, max_epochs=100, early_stop_num=10, early_stop_flag=True, average_validation=[15,25], IsLeNet=True):\n",
     "\n",
     "    # create transformations using above info\n",
     "    transform = torchvision.transforms.Compose([\n",
@@ -113,7 +139,10 @@
     "    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)\n",
     "\n",
     "    # create model\n",
-    "    model = LeNet()\n",
+    "    if IsLeNet:\n",
+    "        model = LeNet()\n",
+    "    else:\n",
+    "        model = EasyNet()\n",
     "    sgd = optim.SGD(model.parameters(), lr=1e-1)\n",
     "    cost = nn.CrossEntropyLoss()\n",
     "\n",
@@ -171,196 +200,20 @@
...
@@ -171,196 +200,20 @@
},
},
{
{
"cell_type": "code",
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KVhYheLfBP33",
"outputId": "8009d87f-7e39-40e3-c6ef-8f3a12f9433f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"9913344it [00:04, 2462502.04it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MetaAugment/train/MNIST/raw/train-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"29696it [00:00, 3785722.37it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MetaAugment/train/MNIST/raw/train-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"1649664it [00:00, 3348476.95it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MetaAugment/train/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"5120it [00:00, 2935726.11it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MetaAugment/train/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MetaAugment/train/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw/train-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"9913344it [00:04, 2338660.11it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MetaAugment/test/MNIST/raw/train-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"29696it [00:00, 33554432.00it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MetaAugment/test/MNIST/raw/train-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"1649664it [00:00, 2786152.46it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MetaAugment/test/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
"\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"5120it [00:00, 4789214.20it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Extracting ./MetaAugment/test/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MetaAugment/test/MNIST/raw\n",
"\n",
"0\tBest accuracy: 18.00%\n",
"10\tBest accuracy: 75.50%\n",
"20\tBest accuracy: 78.00%\n",
"30\tBest accuracy: 95.00%\n",
"40\tBest accuracy: 95.50%\n",
"50\tBest accuracy: 94.00%\n",
"60\tBest accuracy: 85.00%\n",
"70\tBest accuracy: 85.50%\n",
"80\tBest accuracy: 62.50%\n",
"90\tBest accuracy: 76.00%\n",
"Average best accuracy: 79.86%\n",
"\n",
"0\tAverage accuracy: 93.50%\n",
"10\tAverage accuracy: 93.45%\n",
"20\tAverage accuracy: 46.95%\n",
"30\tAverage accuracy: 71.41%\n",
"40\tAverage accuracy: 73.68%\n",
"50\tAverage accuracy: 64.50%\n",
"60\tAverage accuracy: 72.50%\n",
"70\tAverage accuracy: 94.36%\n",
"80\tAverage accuracy: 84.77%\n",
"90\tAverage accuracy: 92.14%\n",
"Average average accuracy: 80.92%\n",
"\n"
]
}
],
"source": [
"source": [
"batch_size = 32 # size of batch the inner NN is trained with\n",
"batch_size = 32 # size of batch the inner NN is trained with\n",
"toy_size = 0.0
2
# total propeortion of training and test set we use\n",
"toy_size = 0.0
5
# total propeortion of training and test set we use\n",
"max_epochs = 100 # max number of epochs that is run if early stopping is not hit\n",
"max_epochs = 100 # max number of epochs that is run if early stopping is not hit\n",
"early_stop_num = 10 # max number of worse validation scores before early stopping is triggered\n",
"early_stop_num = 10 # max number of worse validation scores before early stopping is triggered\n",
"early_stop_flag = True # implement early stopping or not\n",
"early_stop_flag = True # implement early stopping or not\n",
"average_validation = [15,25] # if not implementing early stopping, what epochs are we averaging over\n",
"average_validation = [15,25] # if not implementing early stopping, what epochs are we averaging over\n",
"num_iterations = 100 # how many iterations are we averaging over\n",
"num_iterations = 100 # how many iterations are we averaging over\n",
"IsLeNet = True # using LeNet or EasyNet\n",
"\n",
"\n",
"# run using early stopping\n",
"# run using early stopping\n",
"best_accuracies = []\n",
"best_accuracies = []\n",
"for baselines in range(num_iterations):\n",
"for baselines in range(num_iterations):\n",
" best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation)\n",
" best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation
, IsLeNet
)\n",
" best_accuracies.append(best_acc)\n",
" best_accuracies.append(best_acc)\n",
" if baselines % 10 == 0:\n",
" if baselines % 10 == 0:\n",
" print(\"{}\\tBest accuracy: {:.2f}%\".format(baselines, best_acc*100))\n",
" print(\"{}\\tBest accuracy: {:.2f}%\".format(baselines, best_acc*100))\n",
@@ -370,19 +223,52 @@
     "early_stop_flag = False\n",
     "best_accuracies = []\n",
     "for baselines in range(num_iterations):\n",
-    "    best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation)\n",
+    "    best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, IsLeNet)\n",
     "    best_accuracies.append(best_acc)\n",
     "    if baselines % 10 == 0:\n",
     "        print(\"{}\\tAverage accuracy: {:.2f}%\".format(baselines, best_acc*100))\n",
     "print(\"Average average accuracy: {:.2f}%\\n\".format(np.mean(best_accuracies)*100))"
-   ]
-  }
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "KVhYheLfBP33",
+    "outputId": "39c42079-a3cb-492e-8e26-68818eeac808"
+   },
+   "execution_count": 6,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "0\tBest accuracy: 95.60%\n",
+      "10\tBest accuracy: 85.40%\n",
+      "20\tBest accuracy: 86.40%\n",
+      "30\tBest accuracy: 95.40%\n",
+      "40\tBest accuracy: 97.00%\n",
+      "50\tBest accuracy: 80.40%\n",
+      "60\tBest accuracy: 95.60%\n",
+      "70\tBest accuracy: 96.40%\n",
+      "80\tBest accuracy: 86.20%\n",
+      "90\tBest accuracy: 95.40%\n",
+      "Average best accuracy: 84.65%\n",
+      "\n",
+      "0\tAverage accuracy: 78.45%\n",
+      "10\tAverage accuracy: 58.02%\n",
+      "20\tAverage accuracy: 38.60%\n",
+      "30\tAverage accuracy: 65.15%\n",
+      "40\tAverage accuracy: 77.22%\n",
+      "50\tAverage accuracy: 79.09%\n",
+      "60\tAverage accuracy: 95.55%\n",
+      "70\tAverage accuracy: 86.33%\n",
+      "80\tAverage accuracy: 85.98%\n",
+      "90\tAverage accuracy: 78.20%\n",
+      "Average average accuracy: 83.31%\n",
+      "\n"
+     ]
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
 ],
 "metadata": {
@@ -406,9 +292,9 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.9.7"
+  "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
\ No newline at end of file
%% Cell type:code id: tags:

``` python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
import torchvision.datasets as datasets
```
%% Cell type:code id: tags:

``` python
"""Define internal NN module that trains on the dataset"""
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y
```
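The `256` input size of `fc1` follows from tracing a 28×28 MNIST image through the network: conv(5) → 24×24, pool(2) → 12×12, conv(5) → 8×8, pool(2) → 4×4 over 16 channels, i.e. 16·4·4 = 256. A minimal shape check, assuming the cells above have been run (hypothetical, not part of the commit):

``` python
# Hypothetical sanity check: trace a fake MNIST batch through LeNet
# and confirm the output shape that the flattened feature map produces.
model = LeNet()
dummy = torch.zeros(8, 1, 28, 28)  # batch of 8 single-channel 28x28 images
print(model(dummy).shape)          # torch.Size([8, 10])
```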
%% Cell type:code id: tags:

``` python
"""Define internal NN module that trains on the dataset"""
class EasyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 2048)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(2048, 10)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        y = x.view(x.shape[0], -1)
        y = self.fc1(y)
        y = self.relu1(y)
        y = self.fc2(y)
        y = self.relu2(y)
        return y
```
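EasyNet, the cell this commit adds, flattens each 28×28 image to a 784-vector and classifies it with a single 2048-unit hidden layer, trading LeNet's convolutional structure for raw width. Note that, like LeNet above, it ends in a ReLU even though `nn.CrossEntropyLoss` expects unbounded logits, so negative logits are clipped to zero. A rough size comparison, assuming the cells above have been run (hypothetical, not part of the commit):

``` python
# Hypothetical comparison: parameter counts of the two inner networks.
n_lenet = sum(p.numel() for p in LeNet().parameters())      # ~44k parameters
n_easynet = sum(p.numel() for p in EasyNet().parameters())  # ~1.63M parameters
print(n_lenet, n_easynet)
```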
%% Cell type:code id: tags:

``` python
"""Make toy dataset"""
def create_toy(train_dataset, test_dataset, batch_size, n_samples):
    # shuffle and take first n_samples %age of training dataset
    shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset))
    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)
    indices_train = torch.arange(int(n_samples*len(train_dataset)))
    reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)

    # shuffle and take first n_samples %age of test dataset
    shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset))
    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
    indices_test = torch.arange(int(n_samples*len(test_dataset)))
    reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)

    # push into DataLoader
    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)

    return train_loader, test_loader
```
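Despite its name, `n_samples` is a fraction rather than a count, and the fixed-seed permutations (100 for train, 1000 for test) make the subsets reproducible across runs. A brief usage sketch, assuming the cells above have been run (hypothetical, not part of the commit):

``` python
# Hypothetical usage: loaders over 2% of MNIST (1200 train / 200 test images).
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_ds = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
test_ds = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
train_loader, test_loader = create_toy(train_ds, test_ds, batch_size=32, n_samples=0.02)
print(len(train_loader.dataset), len(test_loader.dataset))  # 1200 200
```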
%% Cell type:code id: tags:

``` python
def run_baseline(batch_size=32, toy_size=0.02, max_epochs=100, early_stop_num=10, early_stop_flag=True, average_validation=[15,25], IsLeNet=True):

    # create transformations using above info
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()])

    # open data and apply these transformations
    train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)

    # create toy dataset from above uploaded data
    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)

    # create model
    if IsLeNet:
        model = LeNet()
    else:
        model = EasyNet()
    sgd = optim.SGD(model.parameters(), lr=1e-1)
    cost = nn.CrossEntropyLoss()

    # set variables for best validation accuracy and early stop count
    best_acc = 0
    early_stop_cnt = 0
    total_val = 0

    # train model and check validation accuracy each epoch
    for _epoch in range(max_epochs):

        # train model
        model.train()
        for idx, (train_x, train_label) in enumerate(train_loader):
            label_np = np.zeros((train_label.shape[0], 10))
            sgd.zero_grad()
            predict_y = model(train_x.float())
            loss = cost(predict_y, train_label.long())
            loss.backward()
            sgd.step()

        # check validation accuracy on validation set
        correct = 0
        _sum = 0
        model.eval()
        for idx, (test_x, test_label) in enumerate(test_loader):
            predict_y = model(test_x.float()).detach()
            predict_ys = np.argmax(predict_y, axis=-1)
            label_np = test_label.numpy()
            _ = predict_ys == test_label
            correct += np.sum(_.numpy(), axis=-1)
            _sum += _.shape[0]

        acc = correct / _sum

        # update the total validation
        if average_validation[0] <= _epoch <= average_validation[1]:
            total_val += acc

        # update best validation accuracy if it was higher, otherwise increase early stop count
        if acc > best_acc:
            best_acc = acc
            early_stop_cnt = 0
        else:
            early_stop_cnt += 1

        # exit if validation gets worse over 10 runs and using early stopping
        if early_stop_cnt >= early_stop_num and early_stop_flag:
            return best_acc

        # exit if using fixed epoch length
        if _epoch >= average_validation[1] and not early_stop_flag:
            return total_val / (average_validation[1] - average_validation[0] + 1)
```
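The new `IsLeNet` flag is the hook this commit adds: the same training and evaluation loop can now benchmark either inner network. A minimal sketch of a single EasyNet run, assuming the cells above have been run (not part of the commit):

``` python
# Hypothetical single run with the new EasyNet inner network.
acc = run_baseline(batch_size=32, toy_size=0.02, max_epochs=100,
                   early_stop_num=10, early_stop_flag=True,
                   average_validation=[15, 25], IsLeNet=False)
print("EasyNet best accuracy: {:.2f}%".format(acc * 100))
```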
%% Cell type:code id: tags:

``` python
batch_size = 32 # size of batch the inner NN is trained with
toy_size = 0.05 # total proportion of training and test set we use
max_epochs = 100 # max number of epochs that is run if early stopping is not hit
early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
early_stop_flag = True # implement early stopping or not
average_validation = [15,25] # if not implementing early stopping, what epochs are we averaging over
num_iterations = 100 # how many iterations are we averaging over
IsLeNet = True # using LeNet or EasyNet

# run using early stopping
best_accuracies = []
for baselines in range(num_iterations):
    best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, IsLeNet)
    best_accuracies.append(best_acc)
    if baselines % 10 == 0:
        print("{}\tBest accuracy: {:.2f}%".format(baselines, best_acc*100))
print("Average best accuracy: {:.2f}%\n".format(np.mean(best_accuracies)*100))

# run using average validation losses
early_stop_flag = False
best_accuracies = []
for baselines in range(num_iterations):
    best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, IsLeNet)
    best_accuracies.append(best_acc)
    if baselines % 10 == 0:
        print("{}\tAverage accuracy: {:.2f}%".format(baselines, best_acc*100))
print("Average average accuracy: {:.2f}%\n".format(np.mean(best_accuracies)*100))
```
%% Output

0 Best accuracy: 95.60%
10 Best accuracy: 85.40%
20 Best accuracy: 86.40%
30 Best accuracy: 95.40%
40 Best accuracy: 97.00%
50 Best accuracy: 80.40%
60 Best accuracy: 95.60%
70 Best accuracy: 96.40%
80 Best accuracy: 86.20%
90 Best accuracy: 95.40%
Average best accuracy: 84.65%

0 Average accuracy: 78.45%
10 Average accuracy: 58.02%
20 Average accuracy: 38.60%
30 Average accuracy: 65.15%
40 Average accuracy: 77.22%
50 Average accuracy: 79.09%
60 Average accuracy: 95.55%
70 Average accuracy: 86.33%
80 Average accuracy: 85.98%
90 Average accuracy: 78.20%
Average average accuracy: 83.31%
%% Cell type:code id: tags:

``` python

```
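The commit also leaves a fresh empty cell at the end of the notebook. A natural follow-up to put there, sketched under the assumption that the cells above have been run (hypothetical, not part of the commit), is a head-to-head comparison of the two inner networks:

``` python
# Hypothetical head-to-head: average a few baseline runs per inner network.
for use_lenet, name in [(True, "LeNet"), (False, "EasyNet")]:
    accs = [run_baseline(IsLeNet=use_lenet) for _ in range(10)]
    print("{}: mean best accuracy {:.2f}%".format(name, np.mean(accs) * 100))
```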
MetaAugment/UCB1_JC.ipynb (+241, −283)

This diff is collapsed.