Commit 7935e672 authored by Sun Jin Kim

add Mia's actor-critic (AC) learner in-progress code

parent f5a10e7e
@@ -14,7 +14,6 @@ from pprint import pprint
 # We will use this augmentation_space temporarily. Later on we will need to
 # make sure we are able to add other image functions if the users want.
-num_bins = 10
 augmentation_space = [
     # (function_name, do_we_need_to_specify_magnitude)
     ("ShearX", True),
@@ -34,8 +33,6 @@ augmentation_space = [
 ]
-# TODO: Right now the aa_learner is identical to randomsearch_learner. Change
-# this so that it can act as a superclass to all other augment learners
 class aa_learner:
     def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False):
         '''
@@ -135,6 +132,7 @@ class aa_learner:
         until a certain condition (either specified by the user or pre-specified) is met
         '''
+        # This is dummy code
         # test out 15 random policies
         for _ in range(15):
             policy = self.generate_new_policy()
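The hunk above notes that augmentation_space is temporary and should eventually accept user-supplied image functions. A minimal sketch of what that could look like, given that entries are (function_name, needs_magnitude) tuples; both register_augmentation and "MyCustomOp" are hypothetical names, not part of the repository:

def register_augmentation(space, function_name, needs_magnitude):
    # append a user-defined image function to the shared augmentation space
    space.append((function_name, needs_magnitude))

register_augmentation(augmentation_space, "MyCustomOp", True)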
# %%
import numpy as np
import matplotlib.pyplot as plt
from itertools import count
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.distributions import Categorical
from torch.utils.data import TensorDataset, DataLoader
from collections import namedtuple, deque
import math
import random
from MetaAugment.main import *
batch_size = 128
test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=torchvision.transforms.ToTensor())
train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print('test_loader', len(test_loader))
print('train_loader',len(train_loader))
def create_toy(train_dataset, test_dataset, batch_size, n_samples):
    # shuffle and take the first n_samples fraction of the training dataset
    shuffled_train_dataset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset)).tolist())
    indices_train = torch.arange(int(n_samples*len(train_dataset)))
    reduced_train_dataset = torch.utils.data.Subset(shuffled_train_dataset, indices_train)

    # shuffle and take the first n_samples fraction of the test dataset
    shuffled_test_dataset = torch.utils.data.Subset(test_dataset, torch.randperm(len(test_dataset)).tolist())
    indices_test = torch.arange(int(n_samples*len(test_dataset)))
    reduced_test_dataset = torch.utils.data.Subset(shuffled_test_dataset, indices_test)

    # push into DataLoader
    train_loader = torch.utils.data.DataLoader(reduced_train_dataset, batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(reduced_test_dataset, batch_size=batch_size)

    return train_loader, test_loader

# train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, 10)
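# %%
# quick usage sketch of create_toy (illustrative; assumes the MNIST datasets and batch_size above).
# n_samples is treated as a fraction of the dataset, so 0.01 keeps roughly 1% of the images:
# 600 of the 60,000 training images -> 5 batches of 128, and 100 of the 10,000 test images -> 1 batch
toy_train_loader, toy_test_loader = create_toy(train_dataset, test_dataset, batch_size, n_samples=0.01)
print(len(toy_train_loader), len(toy_test_loader))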
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y
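# %%
# illustrative sanity check for LeNet (not part of the training flow): two 5x5 conv + 2x2 pool
# stages reduce a 28x28 MNIST image to a 16-channel 4x4 feature map, i.e. 16 * 4 * 4 = 256,
# which is why fc1 takes 256 inputs; the forward pass ends with 10 values per image
dummy = torch.zeros(4, 1, 28, 28)
print(LeNet()(dummy).shape)  # torch.Size([4, 10])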
# %% [markdown]
# ## collect reward
# %%
def collect_reward(train_loader, test_loader, max_epochs=100, early_stop_num=10):
    child_network = LeNet()
    sgd = optim.SGD(child_network.parameters(), lr=1e-1)
    cost = nn.CrossEntropyLoss()

    best_acc = 0
    early_stop_cnt = 0

    # train child_network and check validation accuracy each epoch
    print('max_epochs', max_epochs)
    for _epoch in range(max_epochs):
        print('_epoch', _epoch)

        # train child_network
        child_network.train()
        for t, (train_x, train_label) in enumerate(train_loader):
            sgd.zero_grad()
            predict_y = child_network(train_x.float())
            loss = cost(predict_y, train_label.long())
            loss.backward()
            sgd.step()

        # check accuracy on the validation set
        correct = 0
        _sum = 0
        child_network.eval()
        for idx, (test_x, test_label) in enumerate(test_loader):
            predict_y = child_network(test_x.float()).detach()
            predict_ys = torch.argmax(predict_y, dim=-1)
            correct += (predict_ys == test_label).sum().item()
            _sum += test_label.shape[0]

        # update best validation accuracy if it improved, otherwise increase the early-stop count
        acc = correct / _sum
        if acc > best_acc:
            best_acc = acc
            early_stop_cnt = 0
        else:
            early_stop_cnt += 1

        # exit if validation accuracy has not improved for early_stop_num epochs
        if early_stop_cnt >= early_stop_num:
            break

        # if _epoch%30 == 0:
        #     print('child_network accuracy: ', best_acc)

    return best_acc
# %%
for t, (train_x, train_label) in enumerate(test_loader):
    print(train_x.shape)
    print(train_label)
    break
len(test_loader)
# %%
collect_reward(train_loader, test_loader)
# %% [markdown]
# ## Policy network
# %%
class Policy(nn.Module):
    """
    implements both actor and critic in one model
    """
    def __init__(self):
        super(Policy, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, stride=2)
        self.conv2 = nn.Conv2d(6, 12, 5, stride=2)
        self.maxpool = nn.MaxPool2d(4)

        # actor's layer: one logit per augmentation policy in action_list (4 candidates below)
        self.action_head = nn.Linear(12, 4)

        # critic's layer
        self.value_head = nn.Linear(12, 1)

        # action & reward buffer
        self.saved_actions = []
        self.rewards = []

    def forward(self, x):
        """
        forward of both actor and critic
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        # print('x', x.shape)

        # actor: chooses which action to take from state s_t
        # by returning the probability of each action
        # print('self.action_head(x)', self.action_head(x).shape)
        action_prob = F.softmax(self.action_head(x), dim=-1)
        # print('action_prob', action_prob.shape)

        # critic: evaluates being in the state s_t
        state_values = self.value_head(x)

        # return values for both actor and critic as a tuple of 2 values:
        # 1. a list with the probability of each action over the action space
        # 2. the value from state s_t
        return action_prob, state_values
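# %%
# illustrative sanity check for Policy: the two stride-2 convs plus the 4x4 max-pool collapse a
# 28x28 input to a 12-dimensional feature vector, from which the actor head gives one probability
# per augmentation in action_list and the critic head gives one state value per image
dummy = torch.zeros(8, 1, 28, 28)
action_prob, state_value = Policy()(dummy)
print(action_prob.shape, state_value.shape)  # torch.Size([8, 4]) torch.Size([8, 1])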
# %%
test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=torchvision.transforms.ToTensor())
train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
policy_model = Policy()
# for t, (x, y) in enumerate(train_loader):
# # print(x.shape)
# policy_model(x)
# %% [markdown]
# ## select action
# %%
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])
def select_action(train_loader, policy_model):
    probs_list = []
    value_list = []
    for t, (x, y) in enumerate(train_loader):
        probs_i, state_value_i = policy_model(x)
        probs_list += [probs_i]
        value_list += [state_value_i]

    probs = torch.mean(torch.cat(probs_list), dim=0)
    state_value = torch.mean(torch.cat(value_list))
    # print('probs_i', probs_i)
    # print('probs', probs)

    # create a categorical distribution over the list of probabilities of actions
    m = Categorical(probs)
    # print('m', m)

    # and sample an action using the distribution
    action = m.sample()
    # print('action', action)

    # save to action buffer
    policy_model.saved_actions.append(SavedAction(m.log_prob(action), state_value))

    # the action to take (an index into action_list)
    return action.item()
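# %%
# illustrative usage of select_action: random tensors stand in for an MNIST loader so this cell
# runs without downloading data; select_action samples one action index and buffers one
# (log_prob, state_value) pair on the policy model
fake_images = torch.rand(32, 1, 28, 28)
fake_labels = torch.randint(0, 10, (32,))
fake_loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=16)
pm = Policy()
print(select_action(fake_loader, pm), len(pm.saved_actions))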
# %%
torch.tensor([1, 2, 3])
# %% [markdown]
# ## take action
# %%
def take_action(action_idx):
    # Define actions (data augmentation policy) --- can be improved
    action_list = [
        torchvision.transforms.Compose([torchvision.transforms.RandomVerticalFlip(),
                                        torchvision.transforms.ToTensor()]),
        torchvision.transforms.Compose([torchvision.transforms.RandomHorizontalFlip(),
                                        torchvision.transforms.ToTensor()]),
        torchvision.transforms.Compose([torchvision.transforms.RandomGrayscale(),
                                        torchvision.transforms.ToTensor()]),
        torchvision.transforms.Compose([torchvision.transforms.RandomAffine(30),
                                        torchvision.transforms.ToTensor()])]

    # transform
    transform = action_list[action_idx]
    test_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=False, download=True, transform=transform)
    train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=transform)
    train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, n_samples=0.0002)
    return train_loader, test_loader
# %% [markdown]
# ## finish episode
# %%
policy_model = Policy()
optimizer = optim.Adam(policy_model.parameters(), lr=3e-2)
eps = np.finfo(np.float32).eps.item()
gamma = 0.9
def finish_episode():
    """
    Training code. Calculates actor and critic loss and performs backprop.
    """
    R = 0
    saved_actions = policy_model.saved_actions
    policy_losses = []  # list to save actor (policy) loss
    value_losses = []   # list to save critic (value) loss
    returns = []        # list to save the true values

    # calculate the true value using rewards returned from the environment
    for r in policy_model.rewards[::-1]:
        # calculate the discounted value
        R = r + gamma * R
        returns.insert(0, R)

    returns = torch.tensor(returns)
    # normalising a single-element tensor would give NaN (its std is undefined), so only
    # normalise when the episode contains more than one reward
    if len(returns) > 1:
        returns = (returns - returns.mean()) / (returns.std() + eps)

    for (log_prob, value), R in zip(saved_actions, returns):
        advantage = R - value.item()

        # calculate actor (policy) loss
        policy_losses.append(-log_prob * advantage)

        # calculate critic (value) loss using L1 smooth loss
        value_losses.append(F.smooth_l1_loss(value, torch.tensor([R])))

    # reset gradients
    optimizer.zero_grad()

    # sum up all the values of policy_losses and value_losses
    loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum()

    # perform backprop
    loss.backward()
    optimizer.step()

    # reset rewards and action buffer
    del policy_model.rewards[:]
    del policy_model.saved_actions[:]
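# %%
# tiny worked example (arbitrary values) of the backward return computation in finish_episode:
# with gamma = 0.9 and rewards [1.0, 0.0, 2.0], the loop builds the returns back-to-front:
#   R3 = 2.0, R2 = 0.0 + 0.9*2.0 = 1.8, R1 = 1.0 + 0.9*1.8 = 2.62
example_rewards = [1.0, 0.0, 2.0]
R_ex, example_returns = 0.0, []
for r in example_rewards[::-1]:
    R_ex = r + 0.9 * R_ex
    example_returns.insert(0, R_ex)
print(example_returns)  # approximately [2.62, 1.8, 2.0]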
# %% [markdown]
# ## run
# %%
running_reward = 10
episodes_num = 100
policy_model = Policy()

for i_episode in range(episodes_num):
    # initiate a new state: a fresh, un-augmented MNIST training set
    train_dataset = torchvision.datasets.MNIST(root='test_dataset/', train=True, download=True, transform=torchvision.transforms.ToTensor())
    train_loader_state = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # select action from policy based on the current state loader
    action_idx = select_action(train_loader_state, policy_model)
    print('>>> action_idx', action_idx)

    # take the action -> apply data augmentation
    train_loader, test_loader = take_action(action_idx)
    reward = collect_reward(train_loader, test_loader)
    print('>>> reward', reward)

    # if args.render:
    #     env.render()
    policy_model.rewards.append(reward)

    # perform backprop
    finish_episode()

    # log result
    if i_episode % 10 == 0:
        print('Episode {}\tLast reward (val accuracy): {:.2f}'.format(i_episode, reward))
# %%
This diff is collapsed.
@@ -15,7 +15,6 @@ from pprint import pprint
 # We will use this augmentation_space temporarily. Later on we will need to
 # make sure we are able to add other image functions if the users want.
-num_bins = 10
 augmentation_space = [
     # (function_name, do_we_need_to_specify_magnitude)
     ("ShearX", True),