In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data_utils
import torchvision
import torchvision.datasets as datasets
from tqdm import trange

In [None]:
"""Define internal NN module that trains on the dataset"""
class LeNet(nn.Module):
    """LeNet-style CNN image classifier.

    Two (5x5 conv -> ReLU -> 2x2 max-pool) stages followed by three fully
    connected layers (120 -> 84 -> num_labels).

    Args:
        img_height: height of the input images in pixels.
        img_width: width of the input images in pixels.
        num_labels: number of classes, i.e. the size of the output vector.
        img_channels: number of input channels.

    ``forward`` expects a batch shaped (N, img_channels, img_height, img_width)
    and returns raw, unnormalised logits of shape (N, num_labels), which is the
    input ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(img_channels, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        # Each unpadded 5x5 conv shrinks a side by 4 and each 2x2 max-pool halves
        # it rounding down. Floor division mirrors MaxPool2d exactly, so input
        # sizes with an odd intermediate side (e.g. 30) get the correct fc1 width.
        feat_height = ((img_height - 4) // 2 - 4) // 2
        feat_width = ((img_width - 4) // 2 - 4) // 2
        self.fc1 = nn.Linear(feat_height * feat_width * 16, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        # Output layer: deliberately not followed by a ReLU (see forward).
        self.fc3 = nn.Linear(84, num_labels)

    def forward(self, x):
        """Map a batch of images to class logits of shape (N, num_labels)."""
        y = self.pool1(self.relu1(self.conv1(x)))
        y = self.pool2(self.relu2(self.conv2(y)))
        y = torch.flatten(y, 1)
        y = self.relu3(self.fc1(y))
        y = self.relu4(self.fc2(y))
        # Return raw logits. A ReLU here would clamp every negative class score
        # to 0, block its gradient, and stall CrossEntropyLoss training.
        return self.fc3(y)

In [None]:
"""Define internal NN module that trains on the dataset"""
class EasyNet(nn.Module):
    """Two-layer MLP image classifier (flatten -> 2048 hidden units -> num_labels).

    Args:
        img_height: height of the input images in pixels.
        img_width: width of the input images in pixels.
        num_labels: number of classes, i.e. the size of the output vector.
        img_channels: number of input channels.

    ``forward`` flattens each sample (so it expects
    img_channels * img_height * img_width values per sample) and returns raw
    logits of shape (N, num_labels) for ``nn.CrossEntropyLoss``.
    """

    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        self.fc1 = nn.Linear(img_height*img_width*img_channels, 2048)
        self.relu1 = nn.ReLU()
        # Output layer: deliberately not followed by a ReLU (see forward).
        self.fc2 = nn.Linear(2048, num_labels)

    def forward(self, x):
        """Map a batch of images to class logits of shape (N, num_labels)."""
        # flatten (unlike view) also copes with non-contiguous inputs.
        y = torch.flatten(x, 1)
        y = self.relu1(self.fc1(y))
        # Raw logits: clamping them at 0 would zero the gradient of every
        # negative class score under CrossEntropyLoss.
        return self.fc2(y)

In [None]:
"""Define internal NN module that trains on the dataset"""
class SimpleNet(nn.Module):
    """Single linear layer classifier (multinomial logistic regression on raw pixels).

    Args:
        img_height: height of the input images in pixels.
        img_width: width of the input images in pixels.
        num_labels: number of classes, i.e. the size of the output vector.
        img_channels: number of input channels.

    ``forward`` flattens each sample and returns raw logits of shape
    (N, num_labels) for ``nn.CrossEntropyLoss``.
    """

    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        self.fc1 = nn.Linear(img_height*img_width*img_channels, num_labels)

    def forward(self, x):
        """Map a batch of images to class logits of shape (N, num_labels)."""
        # No output ReLU: it would pin negative class scores at 0 and stop
        # them from receiving gradient.
        return self.fc1(torch.flatten(x, 1))

In [None]:
"""Make toy dataset"""

def create_toy(train_dataset, test_dataset, batch_size, n_samples):
    """Wrap a fixed, seeded random fraction of each dataset in a DataLoader.

    Args:
        train_dataset: indexable dataset supporting ``len``.
        test_dataset: indexable dataset supporting ``len``.
        batch_size: batch size used by both loaders.
        n_samples: FRACTION (not a count) of each dataset to keep, e.g. 0.02.

    Returns:
        ``(train_loader, test_loader)``. The permutations use fixed seeds
        (100 for train, 1000 for test), so every call selects the same samples
        in the same order, and the loaders are built without ``shuffle``.
    """
    def leading_fraction(dataset, seed):
        # Reorder the whole dataset by a seeded permutation, then keep the
        # first int(n_samples * len) positions of that reordering.
        permutation = np.random.RandomState(seed=seed).permutation(len(dataset))
        reordered = data_utils.Subset(dataset, permutation)
        keep_count = int(n_samples*len(dataset))
        return data_utils.Subset(reordered, torch.arange(keep_count))

    subsets = [leading_fraction(train_dataset, 100), leading_fraction(test_dataset, 1000)]
    train_loader, test_loader = (torch.utils.data.DataLoader(s, batch_size=batch_size) for s in subsets)
    return train_loader, test_loader

In [None]:
def _load_dataset_pair(ds, transform):
    """Return ``(train_dataset, test_dataset)`` for a supported torchvision dataset name.

    Downloads into ./MetaAugment/train and ./MetaAugment/test if needed.

    Raises:
        ValueError: if ``ds`` is not one of the supported names.
    """
    dataset_classes = {
        "MNIST": datasets.MNIST,
        "KMNIST": datasets.KMNIST,
        "FashionMNIST": datasets.FashionMNIST,
        "CIFAR10": datasets.CIFAR10,
        "CIFAR100": datasets.CIFAR100,
    }
    if ds not in dataset_classes:
        raise ValueError("unknown dataset {!r}; expected one of {}".format(ds, sorted(dataset_classes)))
    dataset_cls = dataset_classes[ds]
    train_dataset = dataset_cls(root='./MetaAugment/train', train=True, download=True, transform=transform)
    test_dataset = dataset_cls(root='./MetaAugment/test', train=False, download=True, transform=transform)
    return train_dataset, test_dataset


def _evaluate_accuracy(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Switches the model to eval mode and disables autograd while scoring.

    Raises:
        ValueError: if ``loader`` yields no samples (accuracy would be undefined).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for test_x, test_label in loader:
            predictions = model(test_x.float()).argmax(dim=-1)
            correct += (predictions == test_label).sum().item()
            total += test_label.shape[0]
    if total == 0:
        raise ValueError("validation loader is empty; increase toy_size")
    return correct / total


def run_baseline(batch_size=32, learning_rate=1e-1, ds="MNIST", toy_size=0.02, max_epochs=100,
                 early_stop_num=10, early_stop_flag=True, average_validation=(15, 25), IsLeNet="LeNet"):
    """Train a freshly initialised classifier on a toy subset of ``ds`` and report validation accuracy.

    Args:
        batch_size: batch size for the training and validation loaders.
        learning_rate: SGD learning rate.
        ds: "MNIST", "KMNIST", "FashionMNIST", "CIFAR10" or "CIFAR100".
        toy_size: fraction of the train and test sets to use (passed to ``create_toy``).
        max_epochs: hard cap on the number of training epochs.
        early_stop_num: with early stopping, stop after this many consecutive
            epochs without a new best validation accuracy.
        early_stop_flag: if true, return the best validation accuracy seen;
            if false, return the mean validation accuracy over the 0-based epochs
            ``average_validation[0]..average_validation[1]`` (inclusive).
        average_validation: ``(first_epoch, last_epoch)`` averaging window for
            the fixed-length mode.
        IsLeNet: "LeNet", "EasyNet", or any other value for SimpleNet.

    Returns:
        A validation accuracy in [0, 1]. If ``max_epochs`` runs out first, the
        best accuracy so far (early stopping) or the mean over the window
        epochs that were reached (fixed-length) is returned instead of None.

    Raises:
        ValueError: for an unknown ``ds``, an empty validation subset, or when
            fixed-length mode never reaches the averaging window.
    """
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    train_dataset, test_dataset = _load_dataset_pair(ds, transform)

    # ToTensor yields images laid out as (channels, height, width).
    img_channels, img_height, img_width = train_dataset[0][0].shape

    # targets is a list for CIFAR and a tensor for the MNIST family; normalise
    # to a tensor. Labels are assumed to form a contiguous integer range.
    targets = torch.as_tensor(train_dataset.targets)
    num_labels = int(targets.max() - targets.min()) + 1

    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)

    if IsLeNet == "LeNet":
        model = LeNet(img_height, img_width, num_labels, img_channels)
    elif IsLeNet == "EasyNet":
        model = EasyNet(img_height, img_width, num_labels, img_channels)
    else:
        model = SimpleNet(img_height, img_width, num_labels, img_channels)
    sgd = optim.SGD(model.parameters(), lr=learning_rate)
    cost = nn.CrossEntropyLoss()

    first_avg_epoch, last_avg_epoch = average_validation[0], average_validation[1]
    best_acc = 0
    early_stop_cnt = 0
    total_val = 0
    n_averaged = 0

    for epoch in range(max_epochs):
        model.train()
        for train_x, train_label in train_loader:
            sgd.zero_grad()
            loss = cost(model(train_x.float()), train_label.long())
            loss.backward()
            sgd.step()

        acc = _evaluate_accuracy(model, test_loader)

        if first_avg_epoch <= epoch <= last_avg_epoch:
            total_val += acc
            n_averaged += 1

        if acc > best_acc:
            best_acc = acc
            early_stop_cnt = 0
        else:
            early_stop_cnt += 1

        # Early stopping: no new best for `early_stop_num` consecutive epochs.
        if early_stop_flag and early_stop_cnt >= early_stop_num:
            return best_acc

        # Fixed-length mode: stop as soon as the averaging window is complete.
        if not early_stop_flag and epoch >= last_avg_epoch:
            break

    # Reached when the loop finished or the fixed-length window completed;
    # previously this path fell off the end and returned None.
    if early_stop_flag:
        return best_acc
    if n_averaged == 0:
        raise ValueError("no epoch fell inside average_validation={}; increase max_epochs".format(average_validation))
    return total_val / n_averaged

In [None]:
SEED = 0                      # seeds torch so weight initialisation is reproducible across re-runs of this cell
batch_size = 32               # size of batch the inner NN is trained with
learning_rate = 1e-1          # fixed learning rate
ds = "CIFAR100"               # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
toy_size = 0.02               # proportion of the training and test sets we use
max_epochs = 100              # max number of epochs that is run if early stopping is not hit
early_stop_num = 10           # max number of non-improving validation scores before early stopping is triggered
early_stop_flag = True        # implement early stopping or not
average_validation = [15,25]  # if not implementing early stopping, what epochs are we averaging over
num_iterations = 5            # how many iterations are we averaging over
IsLeNet = "LeNet"             # using LeNet or EasyNet or SimpleNet


def repeat_baseline(use_early_stop, label):
    """Call ``run_baseline`` ``num_iterations`` times with the settings above.

    Prints every 10th per-run accuracy and the mean over all runs, prefixed
    with ``label`` ("Best" or "Average"). Returns the list of per-run accuracies.
    """
    accuracies = []
    for iteration in trange(num_iterations):
        acc = run_baseline(batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num,
                           use_early_stop, average_validation, IsLeNet)
        accuracies.append(acc)
        if iteration % 10 == 0:
            print("{}\t{} accuracy: {:.2f}%".format(iteration, label, acc*100))
    print("Average {} accuracy: {:.2f}%\n".format(label.lower(), np.mean(accuracies)*100))
    return accuracies


torch.manual_seed(SEED)

# run using early stopping (report the best validation accuracy per run)
early_stop_accuracies = repeat_baseline(early_stop_flag, "Best")

# run with a fixed epoch count (report the mean validation accuracy over the window)
early_stop_flag = False
fixed_epoch_accuracies = repeat_baseline(early_stop_flag, "Average")

# kept so later cells that read `best_accuracies` still see the last run's results
best_accuracies = fixed_epoch_accuracies

  0%|          | 0/5 [00:00<?, ?it/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./MetaAugment/train/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./MetaAugment/train/cifar-100-python.tar.gz to ./MetaAugment/train
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./MetaAugment/test/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./MetaAugment/test/cifar-100-python.tar.gz to ./MetaAugment/test


 20%|██        | 1/5 [00:19<01:16, 19.24s/it]

0	Best accuracy: 1.00%
Files already downloaded and verified
Files already downloaded and verified


 40%|████      | 2/5 [00:35<00:51, 17.21s/it]

Files already downloaded and verified
Files already downloaded and verified


 60%|██████    | 3/5 [00:43<00:26, 13.33s/it]

Files already downloaded and verified
Files already downloaded and verified


 80%|████████  | 4/5 [01:12<00:19, 19.50s/it]

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 5/5 [01:18<00:00, 15.67s/it]


Average best accuracy: 4.00%



  0%|          | 0/5 [00:00<?, ?it/s]

Files already downloaded and verified
Files already downloaded and verified


 20%|██        | 1/5 [00:11<00:44, 11.15s/it]

0	Average accuracy: 1.86%
Files already downloaded and verified
Files already downloaded and verified


 40%|████      | 2/5 [00:22<00:33, 11.04s/it]

Files already downloaded and verified
Files already downloaded and verified


 60%|██████    | 3/5 [00:33<00:21, 10.98s/it]

Files already downloaded and verified
Files already downloaded and verified


 80%|████████  | 4/5 [00:44<00:11, 11.05s/it]

Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 5/5 [00:55<00:00, 11.06s/it]

Average average accuracy: 1.97%




