# NOTE(review): "Newer" / "Older" pagination links from the web scrape removed —
# they were not part of the original source file.
"""
CONTAINS THE FUNTIONS THAT THE WEBAPP CAN USE TO INTERACT WITH
THE LIBRARY
"""
import numpy as np
import torch
import torchvision
import torchvision.datasets as datasets
# import agents and their functions
import MetaAugment.autoaugment_learners as aal
import MetaAugment.controller_networks as cont_n
import MetaAugment.child_networks as cn
ds,
ds_name,
exclude_method,
num_funcs,
num_policies,
num_sub_policies,
toy_size,
IsLeNet,
batch_size,
early_stop_num,
iterations,
learning_rate,
max_epochs
):
"""
The website receives user inputs on what they want the aa_learner
to be. We take those hyperparameters and return an aa_learner
"""
# NOTE(review): the original source for this span was lost during extraction —
# only the editor's line numbers (43-139) survived. The missing code presumably
# included the binding/selection of `auto_aug_learner`, the `create_toy` helper
# (or its import), and an `import pickle`, all of which are referenced below.
# Recover this span from the upstream repository before running.
# NOTE(review): this is the tail of a function whose `def` line is missing from
# this view; `auto_aug_learner` is presumably one of its parameters — confirm
# against the original signature. Leading indentation was stripped by the
# extraction; the code below is kept byte-identical, so the original nesting
# must be restored before this will parse.
# Dispatch on the learner type the webapp user selected.
if auto_aug_learner == 'UCB':
# UCB1 bandit: sample `num_policies` candidate policies (each with
# `num_sub_policies` sub-policies), then let UCB1 allocate trials among them.
policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
# Returns per-policy q-values and the best-q-value-so-far trace.
q_values, best_q_values = aal.ucb_learner.run_UCB1(
policies,
batch_size,
learning_rate,
ds,
toy_size,
max_epochs,
early_stop_num,
iterations,
IsLeNet,
ds_name
)
# Convert the trace to an ndarray (presumably for plotting/serving — confirm caller).
best_q_values = np.array(best_q_values)
elif auto_aug_learner == 'Evolutionary Learner':
# Evolutionary search: a controller network proposes augmentations which are
# scored by training a LeNet child network.
# NOTE(review): `num_funcs` is used here but p/m bins and sub-policy count are
# hard-coded to 1, ignoring `num_policies`/`num_sub_policies` — confirm intended.
network = cont_n.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
child_network = cn.LeNet()
learner = aal.evo_learner(
network=network,
fun_num=num_funcs,
p_bins=1,
mag_bins=1,
sub_num_pol=1,
ds = ds,
ds_name=ds_name,
exclude_method=exclude_method,
child_network=child_network
)
learner.run_instance()
elif auto_aug_learner == 'Random Searcher':
# As opposed to when ucb==True, `ds` and `IsLenet` are processed outside of the agent.
# This system makes more sense for the user who is not using the webapp and is instead
# using the library within their code.
download = True
# Resolve the dataset name to a torchvision dataset; train split has no
# transform here, test split gets ToTensor (train transforms are presumably
# applied later by the learner — confirm).
if ds == "MNIST":
train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train', train=True, download=download)
test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test', train=False,
download=download, transform=torchvision.transforms.ToTensor())
elif ds == "KMNIST":
train_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/train', train=True, download=download)
test_dataset = datasets.KMNIST(root='./MetaAugment/datasets/kmnist/test', train=False,
download=download, transform=torchvision.transforms.ToTensor())
elif ds == "FashionMNIST":
train_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/train', train=True, download=download)
test_dataset = datasets.FashionMNIST(root='./MetaAugment/datasets/fashionmnist/test', train=False,
download=download, transform=torchvision.transforms.ToTensor())
elif ds == "CIFAR10":
train_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/train', train=True, download=download)
test_dataset = datasets.CIFAR10(root='./MetaAugment/datasets/cifar10/test', train=False,
download=download, transform=torchvision.transforms.ToTensor())
elif ds == "CIFAR100":
train_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/train', train=True, download=download)
test_dataset = datasets.CIFAR100(root='./MetaAugment/datasets/cifar100/test', train=False,
download=download, transform=torchvision.transforms.ToTensor())
elif ds == 'Other':
# User-uploaded dataset: read an image folder named by `ds_name` and do an
# 80/20 random train/test split.
dataset = datasets.ImageFolder('./MetaAugment/datasets/'+ ds_name)
len_train = int(0.8*len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
# check sizes of images
# Infer (height, width, channels) from the first sample via nested len() —
# assumes sample[0] is a channel-first array-like; TODO confirm for the
# untransformed PIL images the branches above produce.
img_height = len(train_dataset[0][0][0])
img_width = len(train_dataset[0][0][0][0])
img_channels = len(train_dataset[0][0])
# check output labels
if ds == 'Other':
num_labels = len(dataset.class_to_idx)
elif ds == "CIFAR10" or ds == "CIFAR100":
# CIFAR branch omits .item() — presumably because CIFAR `targets` are plain
# ints while the MNIST-family `targets` are tensors; confirm.
num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
else:
num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()
# create toy dataset from above uploaded data
# NOTE(review): `create_toy` is not defined or imported in the visible lines —
# presumably defined in the lost span above.
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)
# create model
if IsLeNet == "LeNet":
model = cn.LeNet(img_height, img_width, num_labels, img_channels)
elif IsLeNet == "EasyNet":
model = cn.EasyNet(img_height, img_width, num_labels, img_channels)
elif IsLeNet == 'SimpleNet':
model = cn.SimpleNet(img_height, img_width, num_labels, img_channels)
else:
# Fall back to a user-supplied pickled child network.
# NOTE(review): `pickle` is not imported in the visible imports (NameError
# risk), and unpickling user uploads is unsafe on untrusted input — review.
model = pickle.load(open(f'datasets/childnetwork', "rb"))
# use an aa_learner. in this case, a rs learner
agent = aal.randomsearch_learner(batch_size=batch_size,
toy_flag=True,
learning_rate=learning_rate,
toy_size=toy_size,
max_epochs=max_epochs,
early_stop_num=early_stop_num,
)
# NOTE(review): learn() is given the raw datasets, not the loaders built by
# create_toy — confirm the loaders are intentionally unused here.
agent.learn(train_dataset,
test_dataset,
child_network_architecture=model,
iterations=iterations)
elif auto_aug_learner == 'Genetic Learner':
# Not implemented yet — selecting this option is a silent no-op.
pass