Skip to content
Snippets Groups Projects
Commit 6b61cb19 authored by John Carter's avatar John Carter
Browse files

baseline code added - jc

parent 5488b922
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id: tags:
``` python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
import torchvision.datasets as datasets
```
%% Cell type:code id: tags:
``` python
"""Define internal NN module that trains on the dataset"""
class LeNet(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(256, 120)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(120, 84)
self.relu4 = nn.ReLU()
self.fc3 = nn.Linear(84, 10)
self.relu5 = nn.ReLU()
def forward(self, x):
y = self.conv1(x)
y = self.relu1(y)
y = self.pool1(y)
y = self.conv2(y)
y = self.relu2(y)
y = self.pool2(y)
y = y.view(y.shape[0], -1)
y = self.fc1(y)
y = self.relu3(y)
y = self.fc2(y)
y = self.relu4(y)
y = self.fc3(y)
y = self.relu5(y)
return y
```
%% Cell type:code id: tags:
``` python
"""Make toy dataset"""
def create_toy(train_dataset, test_dataset, batch_size, n_samples):
# shuffle and take first n_samples %age of training dataset
shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset))
shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)
indices_train = torch.arange(int(n_samples*len(train_dataset)))
reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)
# shuffle and take first n_samples %age of test dataset
shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset))
shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
indices_test = torch.arange(int(n_samples*len(test_dataset)))
reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)
# push into DataLoader
train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)
return train_loader, test_loader
```
%% Cell type:code id: tags:
``` python
def run_baseline(batch_size=32, toy_size=0.02, max_epochs=100, early_stop_num=10, early_stop_flag=True, average_validation=[15,25]):
# create transformations using above info
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()])
# open data and apply these transformations
train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
# create toy dataset from above uploaded data
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
# create model
model = LeNet()
sgd = optim.SGD(model.parameters(), lr=1e-1)
cost = nn.CrossEntropyLoss()
# set variables for best validation accuracy and early stop count
best_acc = 0
early_stop_cnt = 0
total_val = 0
# train model and check validation accuracy each epoch
for _epoch in range(max_epochs):
# train model
model.train()
for idx, (train_x, train_label) in enumerate(train_loader):
label_np = np.zeros((train_label.shape[0], 10))
sgd.zero_grad()
predict_y = model(train_x.float())
loss = cost(predict_y, train_label.long())
loss.backward()
sgd.step()
# check validation accuracy on validation set
correct = 0
_sum = 0
model.eval()
for idx, (test_x, test_label) in enumerate(test_loader):
predict_y = model(test_x.float()).detach()
predict_ys = np.argmax(predict_y, axis=-1)
label_np = test_label.numpy()
_ = predict_ys == test_label
correct += np.sum(_.numpy(), axis=-1)
_sum += _.shape[0]
acc = correct / _sum
# update the total validation
if average_validation[0] <= _epoch <= average_validation[1]:
total_val += acc
# update best validation accuracy if it was higher, otherwise increase early stop count
if acc > best_acc:
best_acc = acc
early_stop_cnt = 0
else:
early_stop_cnt += 1
# exit if validation gets worse over 10 runs and using early stopping
if early_stop_cnt >= early_stop_num and early_stop_flag:
return best_acc
# exit if using fixed epoch length
if _epoch >= average_validation[1] and not early_stop_flag:
return total_val / (average_validation[1] - average_validation[0] + 1)
```
%% Cell type:code id: tags:
``` python
batch_size = 32 # size of batch the inner NN is trained with
toy_size = 0.02 # total propeortion of training and test set we use
max_epochs = 100 # max number of epochs that is run if early stopping is not hit
early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
early_stop_flag = True # implement early stopping or not
average_validation = [15,25] # if not implementing early stopping, what epochs are we averaging over
num_iterations = 100 # how many iterations are we averaging over
# run using early stopping
best_accuracies = []
for baselines in range(num_iterations):
best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation)
best_accuracies.append(best_acc)
if baselines % 10 == 0:
print("{}\tBest accuracy: {:.2f}%".format(baselines, best_acc*100))
print("Average best accuracy: {:.2f}%\n".format(np.mean(best_accuracies)*100))
# run using average validation losses
early_stop_flag = False
best_accuracies = []
for baselines in range(num_iterations):
best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation)
best_accuracies.append(best_acc)
if baselines % 10 == 0:
print("{}\tAverage accuracy: {:.2f}%".format(baselines, best_acc*100))
print("Average average accuracy: {:.2f}%\n".format(np.mean(best_accuracies)*100))
```
%% Output
0 Best accuracy: 49.00%
10 Best accuracy: 86.50%
20 Best accuracy: 95.00%
30 Best accuracy: 54.00%
40 Best accuracy: 94.00%
50 Best accuracy: 93.50%
60 Best accuracy: 66.50%
70 Best accuracy: 94.50%
80 Best accuracy: 74.50%
90 Best accuracy: 74.00%
Average best accuracy: 79.58%
0 Average accuracy: 68.95%
10 Average accuracy: 69.95%
20 Average accuracy: 85.00%
30 Average accuracy: 93.32%
40 Average accuracy: 68.00%
50 Average accuracy: 85.36%
60 Average accuracy: 92.36%
70 Average accuracy: 56.95%
80 Average accuracy: 93.59%
90 Average accuracy: 64.91%
Average average accuracy: 78.90%
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment