diff --git a/MetaAugment/UCB1_JC_py.py b/MetaAugment/UCB1_JC_py.py index f829ab709f6446742a8a2771332714b9bf578f48..959d55230a8ad0bfa7c4973c8b0857d93c67213f 100644 --- a/MetaAugment/UCB1_JC_py.py +++ b/MetaAugment/UCB1_JC_py.py @@ -102,7 +102,7 @@ def sample_sub_policy(policies, policy, num_sub_policies): """Sample policy, open and apply above transformations""" -def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet): +def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, early_stop_flag, average_validation, iterations, IsLeNet, ds_name=None): # get number of policies and sub-policies num_policies = len(policies) @@ -130,32 +130,40 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl # create transformations using above info transform = torchvision.transforms.Compose( [torchvision.transforms.RandomAffine(degrees=(degrees,degrees), shear=(shear,shear), scale=(scale,scale)), + torchvision.transforms.CenterCrop(28), # <--- need to remove after finishing testing torchvision.transforms.ToTensor()]) # open data and apply these transformations if ds == "MNIST": - train_dataset = datasets.MNIST(root='./MetaAugment/train', train=True, download=True, transform=transform) - test_dataset = datasets.MNIST(root='./MetaAugment/test', train=False, download=True, transform=transform) + train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=True, transform=transform) + test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=True, transform=transform) elif ds == "KMNIST": - train_dataset = datasets.KMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform) - test_dataset = datasets.KMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform) + train_dataset = datasets.KMNIST(root='./datasets/kmnist/train', train=True, download=True, transform=transform) + test_dataset = datasets.KMNIST(root='./datasets/kmnist/test', train=False, download=True, transform=transform) elif ds == "FashionMNIST": - train_dataset = datasets.FashionMNIST(root='./MetaAugment/train', train=True, download=True, transform=transform) - test_dataset = datasets.FashionMNIST(root='./MetaAugment/test', train=False, download=True, transform=transform) + train_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/train', train=True, download=True, transform=transform) + test_dataset = datasets.FashionMNIST(root='./datasets/fashionmnist/test', train=False, download=True, transform=transform) elif ds == "CIFAR10": - train_dataset = datasets.CIFAR10(root='./MetaAugment/train', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR10(root='./MetaAugment/test', train=False, download=True, transform=transform) + train_dataset = datasets.CIFAR10(root='./datasets/cifar10/train', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR10(root='./datasets/cifar10/test', train=False, download=True, transform=transform) elif ds == "CIFAR100": - train_dataset = datasets.CIFAR100(root='./MetaAugment/train', train=True, download=True, transform=transform) - test_dataset = datasets.CIFAR100(root='./MetaAugment/test', train=False, download=True, transform=transform) + train_dataset = datasets.CIFAR100(root='./datasets/cifar100/train', train=True, download=True, transform=transform) + test_dataset = datasets.CIFAR100(root='./datasets/cifar100/test', train=False, download=True, transform=transform) + elif ds == 'Other': + dataset = datasets.ImageFolder('./datasets/upload_dataset/'+ ds_name, transform=transform) + len_train = int(0.8*len(dataset)) + train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train]) # check sizes of images img_height = len(train_dataset[0][0][0]) img_width = len(train_dataset[0][0][0][0]) img_channels = len(train_dataset[0][0]) + # check output labels - if ds == "CIFAR10" or ds == "CIFAR100": + if ds == 'Other': + num_labels = len(dataset.class_to_idx) + elif ds == "CIFAR10" or ds == "CIFAR100": num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1) else: num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item() @@ -164,70 +172,26 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size) # create model - device = 'cuda' if torch.cuda.is_available() else 'cpu' + if torch.cuda.is_available(): + device='cuda' + else: + device='cpu' + if IsLeNet == "LeNet": model = LeNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device) elif IsLeNet == "EasyNet": model = EasyNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device) - else: + elif IsLeNet == 'SimpleNet': model = SimpleNet(img_height, img_width, num_labels, img_channels).to(device) # added .to(device) - sgd = optim.SGD(model.parameters(), lr=1e-1) + else: + model = pickle.load(open(f'datasets/childnetwork', "rb")) + + sgd = optim.SGD(model.parameters(), lr=learning_rate) cost = nn.CrossEntropyLoss() - # set variables for best validation accuracy and early stop count - best_acc = 0 - early_stop_cnt = 0 - total_val = 0 - - # train model and check validation accuracy each epoch - for _epoch in range(max_epochs): - - # train model - model.train() - for idx, (train_x, train_label) in enumerate(train_loader): - train_x, train_label = train_x.to(device), train_label.to(device) # new code - label_np = np.zeros((train_label.shape[0], num_labels)) - sgd.zero_grad() - predict_y = model(train_x.float()) - loss = cost(predict_y, train_label.long()) - loss.backward() - sgd.step() - - # check validation accuracy on validation set - correct = 0 - _sum = 0 - model.eval() - for idx, (test_x, test_label) in enumerate(test_loader): - test_x, test_label = test_x.to(device), test_label.to(device) # new code - predict_y = model(test_x.float()).detach() - #predict_ys = np.argmax(predict_y, axis=-1) - predict_ys = torch.argmax(predict_y, axis=-1) # changed np to torch - #label_np = test_label.numpy() - _ = predict_ys == test_label - #correct += np.sum(_.numpy(), axis=-1) - correct += np.sum(_.cpu().numpy(), axis=-1) # added .cpu() - _sum += _.shape[0] - - acc = correct / _sum - - if average_validation[0] <= _epoch <= average_validation[1]: - total_val += acc - - # update best validation accuracy if it was higher, otherwise increase early stop count - if acc > best_acc : - best_acc = acc - early_stop_cnt = 0 - else: - early_stop_cnt += 1 - - # exit if validation gets worse over 10 runs and using early stopping - if early_stop_cnt >= early_stop_num and early_stop_flag: - break - - # exit if using fixed epoch length - if _epoch >= average_validation[1] and not early_stop_flag: - best_acc = total_val / (average_validation[1] - average_validation[0] + 1) - break + best_acc = train_child_network(model, train_loader, test_loader, sgd, + cost, max_epochs, early_stop_num, early_stop_flag, + average_validation, logging=False, print_every_epoch=False) # update q_values if policy < num_policies: @@ -238,7 +202,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl best_q_value = max(q_values) best_q_values.append(best_q_value) - if (policy+1) % 10 == 0: + if (policy+1) % 5 == 0: print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(policy+1, list(np.around(np.array(q_values),2)), max(list(np.around(np.array(q_values),2))))) # update counts @@ -250,6 +214,7 @@ def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, earl for i in range(num_policies): q_plus_cnt[i] = q_values[i] + np.sqrt(2*np.log(total_count)/cnts[i]) + # yield q_values, best_q_values return q_values, best_q_values