Skip to content
Snippets Groups Projects
Commit b5622f7e authored by Ramsay King, Maxim's avatar Ramsay King, Maxim
Browse files

Replace UCB1_JC_py.py

parent 3b1b5371
No related branches found
No related tags found
No related merge requests found
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import numpy as np import numpy as np
from sklearn.covariance import log_likelihood
import torch import torch
torch.manual_seed(0) torch.manual_seed(0)
import torch.nn as nn import torch.nn as nn
...@@ -13,6 +14,7 @@ import torch.optim as optim ...@@ -13,6 +14,7 @@ import torch.optim as optim
import torch.utils.data as data_utils import torch.utils.data as data_utils
import torchvision import torchvision
import torchvision.datasets as datasets import torchvision.datasets as datasets
import pickle
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from numpy import save, load from numpy import save, load
...@@ -198,7 +200,7 @@ def sample_sub_policy(policies, policy, num_sub_policies): ...@@ -198,7 +200,7 @@ def sample_sub_policy(policies, policy, num_sub_policies):
"""Sample policy, open and apply above transformations""" """Sample policy, open and apply above transformations"""
def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet): def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name=None):
# get number of policies and sub-policies # get number of policies and sub-policies
num_policies = len(policies) num_policies = len(policies)
...@@ -226,6 +228,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl ...@@ -226,6 +228,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
# create transformations using above info # create transformations using above info
transform = torchvision.transforms.Compose( transform = torchvision.transforms.Compose(
[torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)), [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),
torchvision.transforms.CenterCrop(28),
torchvision.transforms.ToTensor()]) torchvision.transforms.ToTensor()])
# open data and apply these transformations # open data and apply these transformations
...@@ -244,14 +247,24 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl ...@@ -244,14 +247,24 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
elif ds == "CIFAR100": elif ds == "CIFAR100":
train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform) train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform) test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform)
elif ds == 'Other':
dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name, transform=transform)
len_train = int(0.8*len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
print('train_dataset', len(train_dataset), 'test_dataset', len(test_dataset))
# check sizes of images # check sizes of images
img_height = len(train_dataset[0][0][0]) img_height = len(train_dataset[0][0][0])
img_width = len(train_dataset[0][0][0][0]) img_width = len(train_dataset[0][0][0][0])
img_channels = len(train_dataset[0][0]) img_channels = len(train_dataset[0][0])
# check output labels # check output labels
if ds == "CIFAR10" or ds == "CIFAR100": if ds == 'Other':
num_labels = len(dataset.class_to_idx)
elif ds == "CIFAR10" or ds == "CIFAR100":
num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1) num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
else: else:
num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item() num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
...@@ -264,8 +277,11 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl ...@@ -264,8 +277,11 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
model = LeNet(img_height, img_width, num_labels, img_channels) model = LeNet(img_height, img_width, num_labels, img_channels)
elif IsLeNet == "EasyNet": elif IsLeNet == "EasyNet":
model = EasyNet(img_height, img_width, num_labels, img_channels) model = EasyNet(img_height, img_width, num_labels, img_channels)
else: elif IsLeNet == 'SimpleNet':
model = SimpleNet(img_height, img_width, num_labels, img_channels) model = SimpleNet(img_height, img_width, num_labels, img_channels)
else:
model = pickle.load(open(f'datasets/childnetwork', "rb"))
sgd = optim.SGD(model.parameters(), lr=1e-1) sgd = optim.SGD(model.parameters(), lr=1e-1)
cost = nn.CrossEntropyLoss() cost = nn.CrossEntropyLoss()
...@@ -319,7 +335,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl ...@@ -319,7 +335,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
best_q_value = max(q_values) best_q_value = max(q_values)
best_q_values.append(best_q_value) best_q_values.append(best_q_value)
if (policy+1) % 10 == 0: if (policy+1) % 1 == 0:
print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2))))) print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))
# update counts # update counts
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment