diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py
index f49889521253e052334500180e8f46b0b48f9d40..f829ab709f6446742a8a2771332714b9bf578f48 100644
--- a/MetaAugment/UCB1_JC_py.py
+++ b/MetaAugment/UCB1_JC_py.py
@@ -20,8 +20,8 @@ from matplotlib import pyplot as plt
 from numpy import save, load
 from tqdm import trange
 
-from MetaAugment.child_networks import *
-from MetaAugment.main import create_toy, train_child_network
+from .child_networks import *
+from .main import create_toy, train_child_network
 
 
 # In[6]:
@@ -102,7 +102,7 @@ def sample_sub_policy(policies, policy, num_sub_policies):
 """Sample policy, open and apply above transformations"""
 
 
-def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet, ds_name=None):
+def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet):
 
     # get number of policies and sub-policies
     num_policies = len(policies)
@@ -130,40 +130,32 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         # create transformations using above info
         transform = torchvision.transforms.Compose(
             [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)),
-            torchvision.transforms.CenterCrop(28), # <--- need to remove after finishing testing
             torchvision.transforms.ToTensor()])
 
         # open data and apply these transformations
         if ds == "MNIST":
-            train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
         elif ds == "KMNIST":
-            train_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.KMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.KMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
         elif ds == "FashionMNIST":
-            train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.FashionMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.FashionMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform)
         elif ds == "CIFAR10":
-            train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False, download=True, transform=transform)
+            train_dataset = datasets.CIFAR10(root='./MetaAugment/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.CIFAR10(root='./MetaAugment/test', train=False, download=True, transform=transform)
         elif ds == "CIFAR100":
-            train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=True, transform=transform)
-            test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False, download=True, transform=transform)
-        elif ds == 'Other':
-            dataset = datasets.ImageFolder('./MetaAugment/datasets/upload_dataset/'+ ds_name, transform=transform)
-            len_train = int(0.8*len(dataset))
-            train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
+            train_dataset = datasets.CIFAR100(root='./MetaAugment/train', train=True, download=True, transform=transform)
+            test_dataset = datasets.CIFAR100(root='./MetaAugment/test', train=False, download=True, transform=transform)
 
         # check sizes of images
         img_height = len(train_dataset[0][0][0])
         img_width = len(train_dataset[0][0][0][0])
         img_channels = len(train_dataset[0][0])
 
 
-        # check output labels
-        if ds == 'Other':
-            num_labels = len(dataset.class_to_idx)
-        elif ds == "CIFAR10" or ds == "CIFAR100":
+        if ds == "CIFAR10" or ds == "CIFAR100":
             num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
         else:
             num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
@@ -172,22 +164,70 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl
         train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
 
         # create model
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
         if IsLeNet == "LeNet":
             model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
         elif IsLeNet == "EasyNet":
             model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
-        elif IsLeNet == 'SimpleNet':
-            model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
         else:
-            model = pickle.load(open(f'datasets/childnetwork', "rb"))
-
+            model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device)
         sgd = optim.SGD(model.parameters(), lr=1e-1)
         cost = nn.CrossEntropyLoss()
 
-        best_acc = train_child_network(model, train_loader, test_loader, sgd,
-                                       cost, max_epochs, early_stop_num, early_stop_flag,
-                                       average_validation, logging=False, print_every_epoch=False)
+        # set variables for best validation accuracy and early stop count
+        best_acc = 0
+        early_stop_cnt = 0
+        total_val = 0
+
+        # train model and check validation accuracy each epoch
+        for _epoch in range(max_epochs):
+
+            # train model
+            model.train()
+            for idx, (train_x, train_label) in enumerate(train_loader):
+                train_x, train_label = train_x.to(device), train_label.to(device) # new code
+                label_np = np.zeros((train_label.shape[0], num_labels))
+                sgd.zero_grad()
+                predict_y = model(train_x.float())
+                loss = cost(predict_y, train_label.long())
+                loss.backward()
+                sgd.step()
+
+            # check validation accuracy on validation set
+            correct = 0
+            _sum = 0
+            model.eval()
+            for idx, (test_x, test_label) in enumerate(test_loader):
+                test_x, test_label = test_x.to(device), test_label.to(device) # new code
+                predict_y = model(test_x.float()).detach()
+                #predict_ys = np.argmax(predict_y, axis=-1)
+                predict_ys = torch.argmax(predict_y, axis=-1) # changed np to torch
+                #label_np = test_label.numpy()
+                _ = predict_ys == test_label
+                #correct += np.sum(_.numpy(), axis=-1)
+                correct += np.sum(_.cpu().numpy(), axis=-1) # added .cpu()
+                _sum += _.shape[0]
+
+            acc = correct / _sum
+
+            if average_validation[0] <= _epoch <= average_validation[1]:
+                total_val += acc
+
+            # update best validation accuracy if it was higher, otherwise increase early stop count
+            if acc > best_acc :
+                best_acc = acc
+                early_stop_cnt = 0
+            else:
+                early_stop_cnt += 1
+
+            # exit if validation gets worse over 10 runs and using early stopping
+            if early_stop_cnt >= early_stop_num and early_stop_flag:
+                break
+
+            # exit if using fixed epoch length
+            if _epoch >= average_validation[1] and not early_stop_flag:
+                best_acc = total_val / (average_validation[1] - average_validation[0] + 1)
+                break
 
         # update q_values
         if policy < num_policies:
@@ -198,7 +238,7 @@
         best_q_value = max(q_values)
         best_q_values.append(best_q_value)
 
-        if (policy+1) % 5 == 0:
+        if (policy+1) % 10 == 0:
             print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2)))))
 
         # update counts
@@ -210,7 +250,6 @@
         for i in range(num_policies):
             q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])
 
-    # yield q_values, best_q_values
     return q_values, best_q_values
 
 
diff --git a/MetaAugment/__init__.py b/MetaAugment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
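For reference, a hypothetical call matching the new run_UCB1 signature; all argument values here are illustrative, not taken from the diff:

    q_values, best_q_values = run_UCB1(
        policies, batch_size=32, learning_rate=1e-1, ds="MNIST", toy_size=0.02,
        max_epochs=100, early_stop_num=10, early_stop_flag=True,
        average_validation=[15, 25], iterations=100, IsLeNet="SimpleNet")

And a minimal, self-contained sketch of the UCB1 bandit loop that the q_plus_cnt line implements. Only the exploration-bonus formula appears verbatim in the diff; the reward() callable and the incremental-mean Q-value update are assumptions made for illustration:

    import numpy as np

    def ucb1(num_policies, iterations, reward):
        q_values = [0.0] * num_policies   # running mean reward per policy (arm)
        cnts = [0] * num_policies         # number of times each policy was tried
        best_q_values = []

        for it in range(iterations):
            if it < num_policies:
                policy = it               # play every arm once before comparing
            else:
                total_count = sum(cnts)
                # exploitation term plus sqrt(2 ln N / n_i) exploration bonus,
                # matching q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i])
                q_plus_cnt = [q + np.sqrt(2 * np.log(total_count) / c)
                              for q, c in zip(q_values, cnts)]
                policy = int(np.argmax(q_plus_cnt))

            acc = reward(policy)          # stand-in for one child-network training run
            # incremental mean update of the chosen arm (assumed; not shown in the diff)
            q_values[policy] = (q_values[policy] * cnts[policy] + acc) / (cnts[policy] + 1)
            cnts[policy] += 1
            best_q_values.append(max(q_values))

        return q_values, best_q_values

    # e.g. ucb1(5, 100, lambda p: np.random.rand()) exercises the loop with a dummy reward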