diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py
index 1986368aff7f5d42e966f61e0cf17424d0f2fb7e..66bacb32e8896f486422bd4661c4c1f5699589ce 100644
--- a/MetaAugment/UCB1_JC_py.py
+++ b/MetaAugment/UCB1_JC_py.py
@@ -5,6 +5,7 @@
 import numpy as np
+from sklearn.covariance import log_likelihood
 import torch
 torch.manual_seed(0)
 import torch.nn as nn
@@ -13,6 +14,7 @@ import torch.optim as optim
 import torch.utils.data as data_utils
 import torchvision
 import torchvision.datasets as datasets
+import pickle
 
 from matplotlib import pyplot as plt
 from numpy import save, load
@@ -198,7 +200,7 @@ def sample_sub_policy(policies, policy, num_sub_policies):
 
 """Sample policy, open and apply above transformations"""
 
-def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet):
+def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name=None):
 
     # get number of policies and sub-policies
     num_policies = len(policies)
@@ -226,6 +228,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         # create transformations using above info
         transform = torchvision.transforms.Compose(
             [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),
+            torchvision.transforms.CenterCrop(28),
             torchvision.transforms.ToTensor()])
 
         # open data and apply these transformations
@@ -244,14 +247,24 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         elif ds == "CIFAR100":
             train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform)
             test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform)
+        elif ds == 'Other':
+            dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name, transform=transform)
+            len_train = int(0.8*len(dataset))
+            train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
+
+        print('train_dataset', len(train_dataset), 'test_dataset', len(test_dataset))
+
 
         # check sizes of images
         img_height = len(train_dataset[0][0][0])
         img_width = len(train_dataset[0][0][0][0])
         img_channels = len(train_dataset[0][0])
+
 
         # check output labels
-        if ds == "CIFAR10" or ds == "CIFAR100":
+        if ds == 'Other':
+            num_labels = len(dataset.class_to_idx)
+        elif ds == "CIFAR10" or ds == "CIFAR100":
             num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
         else:
             num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
@@ -264,8 +277,11 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
             model = LeNet(img_height, img_width, num_labels, img_channels)
         elif IsLeNet == "EasyNet":
             model = EasyNet(img_height, img_width, num_labels, img_channels)
-        else:
+        elif IsLeNet == 'SimpleNet':
             model = SimpleNet(img_height, img_width, num_labels, img_channels)
+        else:
+            model = pickle.load(open(f'datasets/childnetwork', "rb"))
+
 
         sgd = optim.SGD(model.parameters(), lr=1e-1)
         cost = nn.CrossEntropyLoss()
@@ -319,7 +335,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         best_q_value = max(q_values)
         best_q_values.append(best_q_value)
 
-        if (policy+1) % 10 == 0:
+        if (policy+1) % 1 == 0:
             print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))
 
         # update counts
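
A minimal usage sketch of the new ds == 'Other' branch, appended for orientation; it is not part of the patch. Every value below is a placeholder assumption, 'my_images' is a hypothetical dataset folder, and 'policies' is assumed to come from the policy generator defined elsewhere in UCB1_JC_py.py. The one requirement the patch actually imposes is that ./MetaAugment/datasets/<ds_name>/ follows torchvision's ImageFolder layout (one sub-directory per class), since the branch derives the 80/20 train/test split and the label count from datasets.ImageFolder.

    # Hypothetical invocation of the patched function; placeholder values throughout.
    from MetaAugment.UCB1_JC_py import run_UCB1

    run_UCB1(policies,               # sampled augmentation policies, built elsewhere
             batch_size=32,
             learning_rate=1e-1,
             ds='Other',             # new branch: load via datasets.ImageFolder
             toy_size=0.1,
             max_epochs=20,
             early_stop_num=10,
             iterations=100,
             IsLeNet='SimpleNet',    # any other string now unpickles datasets/childnetwork
             ds_name='my_images')    # resolves to ./MetaAugment/datasets/my_images/<class>/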