From d646941f33f933a623b4b842e24bf191b23b6190 Mon Sep 17 00:00:00 2001 From: Sun Jin Kim <sk2521@ic.ac.uk> Date: Sat, 16 Apr 2022 15:17:30 +0100 Subject: [PATCH] refactor UCB1_JC_py.py --- MetaAugment/UCB1_JC_py.py | 102 +------------------------------------- MetaAugment/main.py | 2 +- 2 files changed, 3 insertions(+), 101 deletions(-) diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py index 252a2551..9c16382a 100644 --- a/MetaAugment/UCB1_JC_py.py +++ b/MetaAugment/UCB1_JC_py.py @@ -20,106 +20,8 @@ from matplotlib import pyplot as plt from numpy import save, load from tqdm import trange - -# In[2]: - - -"""Define internal NN module that trains on the dataset""" -class LeNet(nn.Module): - def __init__(self, img_height, img_width, num_labels, img_channels): - super().__init__() - self.conv1 = nn.Conv2d(img_channels, 6, 5) - self.relu1 = nn.ReLU() - self.pool1 = nn.MaxPool2d(2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.relu2 = nn.ReLU() - self.pool2 = nn.MaxPool2d(2) - self.fc1 = nn.Linear(int((((img_height-4)/2-4)/2)*(((img_width-4)/2-4)/2)*16), 120) - self.relu3 = nn.ReLU() - self.fc2 = nn.Linear(120, 84) - self.relu4 = nn.ReLU() - self.fc3 = nn.Linear(84, num_labels) - self.relu5 = nn.ReLU() - - def forward(self, x): - y = self.conv1(x) - y = self.relu1(y) - y = self.pool1(y) - y = self.conv2(y) - y = self.relu2(y) - y = self.pool2(y) - y = y.view(y.shape[0], -1) - y = self.fc1(y) - y = self.relu3(y) - y = self.fc2(y) - y = self.relu4(y) - y = self.fc3(y) - y = self.relu5(y) - return y - - -# In[3]: - - -"""Define internal NN module that trains on the dataset""" -class EasyNet(nn.Module): - def __init__(self, img_height, img_width, num_labels, img_channels): - super().__init__() - self.fc1 = nn.Linear(img_height*img_width*img_channels, 2048) - self.relu1 = nn.ReLU() - self.fc2 = nn.Linear(2048, num_labels) - self.relu2 = nn.ReLU() - - def forward(self, x): - y = x.view(x.shape[0], -1) - y = self.fc1(y) - y = self.relu1(y) - y = self.fc2(y) - y = self.relu2(y) - return y - - -# In[4]: - - -"""Define internal NN module that trains on the dataset""" -class SimpleNet(nn.Module): - def __init__(self, img_height, img_width, num_labels, img_channels): - super().__init__() - self.fc1 = nn.Linear(img_height*img_width*img_channels, num_labels) - self.relu1 = nn.ReLU() - - def forward(self, x): - y = x.view(x.shape[0], -1) - y = self.fc1(y) - y = self.relu1(y) - return y - - -# In[5]: - - -"""Make toy dataset""" - -def create_toy(train_dataset, test_dataset, batch_size, n_samples): - - # shuffle and take first n_samples %age of training dataset - shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset)) - shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train) - indices_train = torch.arange(int(n_samples*len(train_dataset))) - reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train) - - # shuffle and take first n_samples %age of test dataset - shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset)) - shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test) - indices_test = torch.arange(int(n_samples*len(test_dataset))) - reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test) - - # push into DataLoader - train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size) - test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size) - - return train_loader, test_loader +from MetaAugment.child_networks import * +from MetaAugment.main import create_toy # In[6]: diff --git a/MetaAugment/main.py b/MetaAugment/main.py index 0c5cdeae..9a43c222 100644 --- a/MetaAugment/main.py +++ b/MetaAugment/main.py @@ -19,7 +19,7 @@ def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100): reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train) # shuffle and take first n_samples %age of test dataset - shuffle_order_test = np.random.RandomState(seed=seed).permutation(len(test_dataset)) + shuffle_order_test = np.random.RandomState(seed=10*seed).permutation(len(test_dataset)) shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test) big = 1 # how much bigger is the test set -- GitLab