Skip to content
Snippets Groups Projects
Commit 3c53b3fb authored by Mia Wang's avatar Mia Wang
Browse files

adjusted ucb1_jc_py.py imports after the merge

parent 61517705
No related branches found
No related tags found
No related merge requests found
......@@ -20,8 +20,8 @@ from matplotlib import pyplot as plt
from numpy import save, load
from tqdm import trange
from MetaAugment.child_networks import *
from MetaAugment.main import create_toy, train_child_network
from .child_networks import *
from .main import create_toy, train_child_network
# In[6]:
......@@ -102,7 +102,7 @@ def sample_sub_policy(policies, policy, num_sub_policies):
"""Sample policy, open and apply above transformations"""
def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet, ds_name=None):
def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet):
# get number of policies and sub-policies
num_policies = len(policies)
......@@ -130,40 +130,32 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
# create transformations using above info
transform = torchvision.transforms.Compose(
[torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),
torchvision.transforms.CenterCrop(28), # <--- need to remove after finishing testing
torchvision.transforms.ToTensor()])
# open data and apply these transformations
if ds == "MNIST":
train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', train=False, download=True, transform=transform)
train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
elif ds == "KMNIST":
train_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/train', train=True, download=True, transform=transform)
test_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/test', train=False, download=True, transform=transform)
train_dataset = datasets.KMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
test_dataset = datasets.KMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
elif ds == "FashionMNIST":
train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform)
train_dataset = datasets.FashionMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
elif ds == "CIFAR10":
train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False, download=True, transform=transform)
train_dataset = datasets.CIFAR10(root='./MetaAugment/train', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./MetaAugment/test', train=False, download=True, transform=transform)
elif ds == "CIFAR100":
train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False, download=True, transform=transform)
elif ds == 'Other':
dataset = datasets.ImageFolder('./MetaAugment/datasets/upload_dataset/'+ ds_name, transform=transform)
len_train = int(0.8*len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
train_dataset = datasets.CIFAR100(root='./MetaAugment/train', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR100(root='./MetaAugment/test', train=False, download=True, transform=transform)
# check sizes of images
img_height = len(train_dataset[0][0][0])
img_width = len(train_dataset[0][0][0][0])
img_channels = len(train_dataset[0][0])
# check output labels
if ds == 'Other':
num_labels = len(dataset.class_to_idx)
elif ds == "CIFAR10" or ds == "CIFAR100":
if ds == "CIFAR10" or ds == "CIFAR100":
num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
else:
num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
......@@ -172,22 +164,70 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
# create model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if IsLeNet == "LeNet":
model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
elif IsLeNet == "EasyNet":
model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
elif IsLeNet == 'SimpleNet':
model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
else:
model = pickle.load(open(f'datasets/childnetwork', "rb"))
model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
sgd = optim.SGD(model.parameters(), lr=1e-1)
cost = nn.CrossEntropyLoss()
best_acc = train_child_network(model, train_loader, test_loader, sgd,
cost, max_epochs, early_stop_num, early_stop_flag,
average_validation, logging=False, print_every_epoch=False)
# set variables for best validation accuracy and early stop count
best_acc = 0
early_stop_cnt = 0
total_val = 0
# train model and check validation accuracy each epoch
for _epoch in range(max_epochs):
# train model
model.train()
for idx, (train_x, train_label) in enumerate(train_loader):
train_x, train_label = train_x.to(device), train_label.to(device) # new code
label_np = np.zeros((train_label.shape[0], num_labels))
sgd.zero_grad()
predict_y = model(train_x.float())
loss = cost(predict_y, train_label.long())
loss.backward()
sgd.step()
# check validation accuracy on validation set
correct = 0
_sum = 0
model.eval()
for idx, (test_x, test_label) in enumerate(test_loader):
test_x, test_label = test_x.to(device), test_label.to(device) # new code
predict_y = model(test_x.float()).detach()
#predict_ys = np.argmax(predict_y, axis=-1)
predict_ys = torch.argmax(predict_y, axis=-1) # changed np to torch
#label_np = test_label.numpy()
_ = predict_ys == test_label
#correct += np.sum(_.numpy(), axis=-1)
correct += np.sum(_.cpu().numpy(), axis=-1) # added .cpu()
_sum += _.shape[0]
acc = correct / _sum
if average_validation[0] <= _epoch <= average_validation[1]:
total_val += acc
# update best validation accuracy if it was higher, otherwise increase early stop count
if acc > best_acc :
best_acc = acc
early_stop_cnt = 0
else:
early_stop_cnt += 1
# exit if validation gets worse over 10 runs and using early stopping
if early_stop_cnt >= early_stop_num and early_stop_flag:
break
# exit if using fixed epoch length
if _epoch >= average_validation[1] and not early_stop_flag:
best_acc = total_val / (average_validation[1] - average_validation[0] + 1)
break
# update q_values
if policy < num_policies:
......@@ -198,7 +238,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
best_q_value = max(q_values)
best_q_values.append(best_q_value)
if (policy+1) % 5 == 0:
if (policy+1) % 10 == 0:
print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))
# update counts
......@@ -210,7 +250,6 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
for i in range(num_policies):
q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])
# yield q_values, best_q_values
return q_values, best_q_values
......
0% — Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment