Skip to content
Snippets Groups Projects
Commit d646941f authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

refactor UCB1_JC_py.py

parent d6f15d90
No related branches found
No related tags found
No related merge requests found
......@@ -20,106 +20,8 @@ from matplotlib import pyplot as plt
from numpy import save, load
from tqdm import trange
# In[2]:
"""Define internal NN module that trains on the dataset"""
class LeNet(nn.Module):
def __init__(self, img_height, img_width, num_labels, img_channels):
super().__init__()
self.conv1 = nn.Conv2d(img_channels, 6, 5)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(2)
self.fc1 = nn.Linear(int((((img_height-4)/2-4)/2)*(((img_width-4)/2-4)/2)*16), 120)
self.relu3 = nn.ReLU()
self.fc2 = nn.Linear(120, 84)
self.relu4 = nn.ReLU()
self.fc3 = nn.Linear(84, num_labels)
self.relu5 = nn.ReLU()
def forward(self, x):
y = self.conv1(x)
y = self.relu1(y)
y = self.pool1(y)
y = self.conv2(y)
y = self.relu2(y)
y = self.pool2(y)
y = y.view(y.shape[0], -1)
y = self.fc1(y)
y = self.relu3(y)
y = self.fc2(y)
y = self.relu4(y)
y = self.fc3(y)
y = self.relu5(y)
return y
# In[3]:
"""Define internal NN module that trains on the dataset"""
class EasyNet(nn.Module):
def __init__(self, img_height, img_width, num_labels, img_channels):
super().__init__()
self.fc1 = nn.Linear(img_height*img_width*img_channels, 2048)
self.relu1 = nn.ReLU()
self.fc2 = nn.Linear(2048, num_labels)
self.relu2 = nn.ReLU()
def forward(self, x):
y = x.view(x.shape[0], -1)
y = self.fc1(y)
y = self.relu1(y)
y = self.fc2(y)
y = self.relu2(y)
return y
# In[4]:
"""Define internal NN module that trains on the dataset"""
class SimpleNet(nn.Module):
def __init__(self, img_height, img_width, num_labels, img_channels):
super().__init__()
self.fc1 = nn.Linear(img_height*img_width*img_channels, num_labels)
self.relu1 = nn.ReLU()
def forward(self, x):
y = x.view(x.shape[0], -1)
y = self.fc1(y)
y = self.relu1(y)
return y
# In[5]:
"""Make toy dataset"""
def create_toy(train_dataset, test_dataset, batch_size, n_samples):
# shuffle and take first n_samples %age of training dataset
shuffle_order_train = np.random.RandomState(seed=100).permutation(len(train_dataset))
shuffled_train_dataset = torch.utils.data.Subset(train_dataset, shuffle_order_train)
indices_train = torch.arange(int(n_samples*len(train_dataset)))
reduced_train_dataset = data_utils.Subset(shuffled_train_dataset, indices_train)
# shuffle and take first n_samples %age of test dataset
shuffle_order_test = np.random.RandomState(seed=1000).permutation(len(test_dataset))
shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
indices_test = torch.arange(int(n_samples*len(test_dataset)))
reduced_test_dataset = data_utils.Subset(shuffled_test_dataset, indices_test)
# push into DataLoader
train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)
return train_loader, test_loader
from MetaAugment.child_networks import *
from MetaAugment.main import create_toy
# In[6]:
......
......@@ -19,7 +19,7 @@ def create_toy(train_dataset, test_dataset, batch_size, n_samples, seed=100):
reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train)
# shuffle and take first n_samples %age of test dataset
shuffle_order_test = np.random.RandomState(seed=seed).permutation(len(test_dataset))
shuffle_order_test = np.random.RandomState(seed=10*seed).permutation(len(test_dataset))
shuffled_test_dataset = torch.utils.data.Subset(test_dataset, shuffle_order_test)
big = 1 # how much bigger is the test set
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment