Skip to content
Snippets Groups Projects
Commit 21fc9bce authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Demo code on main.py

parent 7f8fc18d
No related branches found
No related tags found
No related merge requests found
......@@ -96,11 +96,11 @@ def train_model(transform_idx, p):
# create toy dataset from above uploaded data
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)
# train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
# test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)
print("Size of training dataset:\t", len(reduced_train_dataset))
print("Size of testing dataset:\t", len(reduced_test_dataset), "\n")
# print("Size of training dataset:\t", len(reduced_train_dataset))
# print("Size of testing dataset:\t", len(reduced_test_dataset), "\n")
child_network = child_networks.lenet()
sgd = optim.SGD(child_network.parameters(), lr=1e-1)
......
No preview for this file type
......@@ -77,28 +77,24 @@ def train_child_network(child_network, train_loader, test_loader, sgd,
return best_acc
# This is sort of how our AA_Learner class should look like:
class AA_Learner:
def __init__(self, controller):
self.controller = controller
def learn(self, train_dataset, test_dataset, child_network, toy_flag):
'''
Deos what is seen in Figure 1 in the AutoAugment paper.
if __name__=='__main__':
import MetaAugment.child_networks as cn
'res' stands for resolution of the discretisation of the search space. It could be
a tuple, with first entry regarding probability, second regarding magnitude
'''
good_policy_found = False
batch_size = 32
n_samples = 0.005
while not good_policy_found:
policy = self.controller.pop_policy()
train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False,
transform=torchvision.transforms.ToTensor())
test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
transform=torchvision.transforms.ToTensor())
train_loader, test_loader = create_toy(train_dataset, test_dataset,
batch_size=32, n_samples=0.005)
# create toy dataset from above uploaded data
train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 0.01)
reward = train_child_network(child_network, train_loader, test_loader, sgd, cost, epoch)
child_network = cn.lenet()
sgd = optim.SGD(child_network.parameters(), lr=1e-1)
cost = nn.CrossEntropyLoss()
epoch = 20
self.controller.update(reward, policy)
return good_policy
\ No newline at end of file
best_acc = train_child_network(child_network, train_loader, test_loader, sgd, cost, max_epochs=100)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment