Commit 63a7084c authored by Sun Jin Kim

John: Add EasyNet

parent 3a200057
%% Cell type:code id: tags:
``` python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
import torchvision.datasets as datasets
```
%% Cell type:code id: tags:
``` python
"""Define internal NN module that trains on the dataset"""
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)   # 16 channels * 4 * 4 spatial = 256 features for 28x28 input
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()           # note: ReLU is also applied to the output logits

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)       # flatten to (batch, 256)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y
```
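%% Cell type:markdown id: tags:
Where does the 256 in `fc1` come from? Tracing a 28x28 MNIST image through the stack: conv1 (5x5 kernel) gives 24x24, pooling halves it to 12x12, conv2 gives 8x8, pooling gives 4x4, and 16 channels x 4 x 4 = 256 flattened features. A minimal sketch to confirm the shapes on a dummy batch:
%% Cell type:code id: tags:
``` python
# Sketch: trace a dummy MNIST-shaped batch through LeNet's conv/pool stack
# to confirm the 256-feature input expected by fc1.
dummy = torch.zeros(1, 1, 28, 28)       # (batch, channels, height, width)
net = LeNet()
y = net.pool1(net.relu1(net.conv1(dummy)))
y = net.pool2(net.relu2(net.conv2(y)))
print(y.shape)                          # expected: torch.Size([1, 16, 4, 4])
print(y.view(y.shape[0], -1).shape)     # expected: torch.Size([1, 256])
```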
%% Cell type:code id: tags:
``` python
"""Define internal NN module that trains on the dataset"""
class EasyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 2048)  # 784 = 28*28 flattened MNIST pixels
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(2048, 10)
        self.relu2 = nn.ReLU()           # note: ReLU is also applied to the output logits

    def forward(self, x):
        y = x.view(x.shape[0], -1)       # flatten to (batch, 784)
        y = self.fc1(y)
        y = self.relu1(y)
        y = self.fc2(y)
        y = self.relu2(y)
        return y
```
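%% Cell type:markdown id: tags:
EasyNet trades LeNet's convolutional feature extractor for a single wide hidden layer, so it is structurally simpler but has far more weights: 784*2048 + 2048 = 1,607,680 parameters in `fc1` alone, against roughly 44k in all of LeNet. A quick sketch to compare the two:
%% Cell type:code id: tags:
``` python
# Sketch: compare trainable parameter counts of the two baseline models.
def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("LeNet:  ", count_params(LeNet()))     # 44,426
print("EasyNet:", count_params(EasyNet()))   # 1,628,170
```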
%% Cell type:code id: tags:
``` python
"""Make toy dataset"""
def create_toy(train_dataset, test_dataset, batch_size, n_samples):
    # shuffle, then keep the first n_samples fraction of the training dataset
    shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset))
    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)
    indices_train = torch.arange(int(n_samples*len(train_dataset)))
    reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)

    # shuffle, then keep the first n_samples fraction of the test dataset
    shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset))
    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
    indices_test = torch.arange(int(n_samples*len(test_dataset)))
    reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)

    # push into DataLoaders
    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)
    return train_loader, test_loader
```
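%% Cell type:markdown id: tags:
`n_samples` is a fraction, not a count: with `n_samples=0.02` the 60,000-image MNIST training set shrinks to 1,200 images and the 10,000-image test set to 200. A minimal usage sketch, assuming the same `./MetaAugment` download roots used by `run_baseline` below:
%% Cell type:code id: tags:
``` python
# Sketch: build toy loaders from a 2% subsample of MNIST.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_ds = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
test_ds = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)

toy_train, toy_test = create_toy(train_ds, test_ds, batch_size=32, n_samples=0.02)
print(len(toy_train.dataset))   # 1200 = 2% of 60,000
print(len(toy_test.dataset))    # 200  = 2% of 10,000
```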
%% Cell type:code id: tags:
``` python
def run_baseline(batch_size=32, toy_size=0.02, max_epochs=100, early_stop_num=10, early_stop_flag=True, average_validation=[15,25], IsLeNet=True):

    # create transformations using above info
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()])

    # open data and apply these transformations
    train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)

    # create toy dataset from above uploaded data
    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)

    # create model
    if IsLeNet:
        model = LeNet()
    else:
        model = EasyNet()
    sgd = optim.SGD(model.parameters(), lr=1e-1)
    cost = nn.CrossEntropyLoss()

    # set variables for best validation accuracy and early stop count
    best_acc = 0
    early_stop_cnt = 0
    total_val = 0

    # train model and check validation accuracy each epoch
    for _epoch in range(max_epochs):

        # train model
        model.train()
        for idx, (train_x, train_label) in enumerate(train_loader):
            sgd.zero_grad()
            predict_y = model(train_x.float())
            loss = cost(predict_y, train_label.long())
            loss.backward()
            sgd.step()

        # check validation accuracy on validation set
        correct = 0
        _sum = 0
        model.eval()
        for idx, (test_x, test_label) in enumerate(test_loader):
            predict_y = model(test_x.float()).detach()
            predict_ys = np.argmax(predict_y, axis=-1)
            matches = predict_ys == test_label
            correct += np.sum(matches.numpy(), axis=-1)
            _sum += matches.shape[0]
        acc = correct / _sum

        # update the total validation
        if average_validation[0] <= _epoch <= average_validation[1]:
            total_val += acc

        # update best validation accuracy if it was higher, otherwise increase early stop count
        if acc > best_acc:
            best_acc = acc
            early_stop_cnt = 0
        else:
            early_stop_cnt += 1

        # exit if validation gets worse over 10 runs and using early stopping
        if early_stop_cnt >= early_stop_num and early_stop_flag:
            return best_acc

        # exit if using fixed epoch length
        if _epoch >= average_validation[1] and not early_stop_flag:
            return total_val / (average_validation[1] - average_validation[0] + 1)

    # fall back to the best accuracy seen if max_epochs is reached without triggering either exit
    return best_acc
```
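%% Cell type:markdown id: tags:
Before the full 100-iteration sweep below, a single call is enough to sanity-check the pipeline; passing `IsLeNet=False` selects EasyNet instead of LeNet. A minimal sketch:
%% Cell type:code id: tags:
``` python
# Sketch: one baseline run with EasyNet on the default 2% MNIST subsample,
# exiting via early stopping.
acc = run_baseline(IsLeNet=False)
print("Single-run accuracy: {:.2f}%".format(acc * 100))
```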
%% Cell type:code id: tags:
``` python
batch_size = 32                # size of batch the inner NN is trained with
toy_size = 0.05                # total proportion of training and test set we use
max_epochs = 100               # max number of epochs that is run if early stopping is not hit
early_stop_num = 10            # max number of worse validation scores before early stopping is triggered
early_stop_flag = True         # implement early stopping or not
average_validation = [15,25]   # if not implementing early stopping, what epochs are we averaging over
num_iterations = 100           # how many iterations are we averaging over
IsLeNet = True                 # using LeNet or EasyNet

# run using early stopping
best_accuracies = []
for baselines in range(num_iterations):
    best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, IsLeNet)
    best_accuracies.append(best_acc)
    if baselines % 10 == 0:
        print("{}\tBest accuracy: {:.2f}%".format(baselines, best_acc*100))
print("Average best accuracy: {:.2f}%\n".format(np.mean(best_accuracies)*100))

# run using average validation accuracies
early_stop_flag = False
best_accuracies = []
for baselines in range(num_iterations):
    best_acc = run_baseline(batch_size, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, IsLeNet)
    best_accuracies.append(best_acc)
    if baselines % 10 == 0:
        print("{}\tAverage accuracy: {:.2f}%".format(baselines, best_acc*100))
print("Average average accuracy: {:.2f}%\n".format(np.mean(best_accuracies)*100))
```
%% Output
0	Best accuracy: 95.60%
10	Best accuracy: 85.40%
20	Best accuracy: 86.40%
30	Best accuracy: 95.40%
40	Best accuracy: 97.00%
50	Best accuracy: 80.40%
60	Best accuracy: 95.60%
70	Best accuracy: 96.40%
80	Best accuracy: 86.20%
90	Best accuracy: 95.40%
Average best accuracy: 84.65%
0	Average accuracy: 78.45%
10	Average accuracy: 58.02%
20	Average accuracy: 38.60%
30	Average accuracy: 65.15%
40	Average accuracy: 77.22%
50	Average accuracy: 79.09%
60	Average accuracy: 95.55%
70	Average accuracy: 86.33%
80	Average accuracy: 85.98%
90	Average accuracy: 78.20%
Average average accuracy: 83.31%
%% Cell type:code id: tags:
``` python
```