Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
M
MetaRL
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Wang, Mia
MetaRL
Commits
3eadcc84
Commit
3eadcc84
authored
2 years ago
by
Sun Jin Kim
Browse files
Options
Downloads
Plain Diff
Merge branch 'master' of gitlab.doc.ic.ac.uk:yw21218/metarl
parents
98e34909
8e89ad3a
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
MetaAugment/UCB1_JC.py
+361
-0
361 additions, 0 deletions
MetaAugment/UCB1_JC.py
with
361 additions
and
0 deletions
MetaAugment/UCB1_JC.py
0 → 100644
+
361
−
0
View file @
3eadcc84
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import
numpy
as
np
import
torch
torch
.
manual_seed
(
0
)
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
torch.utils.data
as
data_utils
import
torchvision
import
torchvision.datasets
as
datasets
from
matplotlib
import
pyplot
as
plt
from
numpy
import
save
,
load
from
tqdm
import
trange
# In[2]:
"""
Define internal NN module that trains on the dataset
"""
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/ReLU/pool stages followed by three FC layers.

    Args:
        img_height: input image height in pixels.
        img_width: input image width in pixels.
        num_labels: number of output classes.
        img_channels: number of input channels (1 for MNIST-family, 3 for CIFAR).
    """

    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(img_channels, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        # Spatial size after two (conv 5x5 -> pool 2x2) stages: ((d - 4) / 2 - 4) / 2.
        flat_h = ((img_height - 4) / 2 - 4) / 2
        flat_w = ((img_width - 4) / 2 - 4) / 2
        self.fc1 = nn.Linear(int(flat_h * flat_w * 16), 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, num_labels)
        # NOTE(review): a ReLU on the final logits zeroes out negative class
        # scores before CrossEntropyLoss -- kept to preserve existing
        # behaviour, but worth revisiting.
        self.relu5 = nn.ReLU()

    def forward(self, x):
        """Return class scores of shape (batch, num_labels) for input images x."""
        out = self.pool1(self.relu1(self.conv1(x)))
        out = self.pool2(self.relu2(self.conv2(out)))
        out = out.view(out.shape[0], -1)  # flatten per-sample feature maps
        out = self.relu3(self.fc1(out))
        out = self.relu4(self.fc2(out))
        return self.relu5(self.fc3(out))
# In[3]:
"""
Define internal NN module that trains on the dataset
"""
class EasyNet(nn.Module):
    """Two-layer MLP baseline with a 2048-unit hidden layer.

    Args:
        img_height: input image height in pixels.
        img_width: input image width in pixels.
        num_labels: number of output classes.
        img_channels: number of input channels.
    """

    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        in_features = img_height * img_width * img_channels
        self.fc1 = nn.Linear(in_features, 2048)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(2048, num_labels)
        # NOTE(review): final ReLU clamps logits at zero before
        # CrossEntropyLoss; kept as-is for behavioural parity.
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Return class scores of shape (batch, num_labels) for input images x."""
        flat = x.view(x.shape[0], -1)  # flatten each sample to a vector
        hidden = self.relu1(self.fc1(flat))
        return self.relu2(self.fc2(hidden))
# In[4]:
"""
Define internal NN module that trains on the dataset
"""
class SimpleNet(nn.Module):
    """Minimal baseline: a single linear map from flattened pixels to class scores.

    Args:
        img_height: input image height in pixels.
        img_width: input image width in pixels.
        num_labels: number of output classes.
        img_channels: number of input channels.
    """

    def __init__(self, img_height, img_width, num_labels, img_channels):
        super().__init__()
        self.fc1 = nn.Linear(img_height * img_width * img_channels, num_labels)
        # NOTE(review): ReLU applied to the output logits; preserved for
        # behavioural parity with the other inner models.
        self.relu1 = nn.ReLU()

    def forward(self, x):
        """Return class scores of shape (batch, num_labels) for input images x."""
        flat = x.view(x.shape[0], -1)  # flatten each sample to a vector
        return self.relu1(self.fc1(flat))
# In[5]:
"""
Make toy dataset
"""
def create_toy(train_dataset, test_dataset, batch_size, n_samples):
    """Build small DataLoaders over a fixed fraction of each dataset.

    Each dataset is shuffled with a fixed seed (100 for train, 1000 for test)
    so repeated runs pick the same subset, truncated to the first ``n_samples``
    fraction, and wrapped in a DataLoader.

    Args:
        train_dataset: full training dataset.
        test_dataset: full test dataset.
        batch_size: batch size used by both loaders.
        n_samples: fraction (0-1] of each dataset to keep.

    Returns:
        Tuple ``(train_loader, test_loader)`` of DataLoaders over the subsets.
    """
    def _toy_subset(dataset, seed):
        # Deterministic shuffle, then keep the first n_samples fraction.
        order = np.random.RandomState(seed=seed).permutation(len(dataset))
        shuffled = torch.utils.data.Subset(dataset, order)
        keep = torch.arange(int(n_samples * len(dataset)))
        return data_utils.Subset(shuffled, keep)

    train_loader = torch.utils.data.DataLoader(
        _toy_subset(train_dataset, 100), batch_size=batch_size)
    test_loader = torch.utils.data.DataLoader(
        _toy_subset(test_dataset, 1000), batch_size=batch_size)
    return train_loader, test_loader
# In[6]:
"""
Randomly generate 10 policies
"""
"""
Each policy has 5 sub-policies
"""
"""
For each sub-policy, pick 2 transformations, 2 probabilities and 2 magnitudes
"""
def generate_policies(num_policies, num_sub_policies):
    """Randomly generate augmentation policies.

    Each policy holds ``num_sub_policies`` sub-policies; each sub-policy row
    is ``[op1, op2, prob1, prob2, mag1, mag2]`` where ops are 0=rotate,
    1=shear, 2=scale (op1 != op2), probabilities lie on a 0.1 grid in [0, 1],
    and magnitudes are multiples of 5 in [-20, 20] for rotate/shear or scale
    factors in [0.5, 1.4].

    Args:
        num_policies: number of policies to generate.
        num_sub_policies: number of sub-policies per policy.

    Returns:
        ndarray of shape (num_policies, num_sub_policies, 6).
    """
    policies = np.zeros([num_policies, num_sub_policies, 6])

    for p in range(num_policies):
        for s in range(num_sub_policies):
            row = policies[p, s]  # view: writes go through to `policies`
            # Pick two distinct transformations for this sub-policy.
            row[0] = np.random.randint(0, 3)
            row[1] = np.random.randint(0, 3)
            while row[0] == row[1]:
                row[1] = np.random.randint(0, 3)
            # Application probabilities on a 0.1 grid.
            row[2] = np.random.randint(0, 11) / 10
            row[3] = np.random.randint(0, 11) / 10
            # Magnitudes depend on the transformation type.
            for t in range(2):
                if row[t] <= 1:
                    # rotate/shear: multiples of 5 in [-20, 20]
                    row[t + 4] = np.random.randint(-4, 5) * 5
                else:
                    # scale: factor in [0.5, 1.4]
                    row[t + 4] = np.random.randint(5, 15) / 10
    return policies
# In[7]:
"""
Pick policy and sub-policy
"""
"""
Each row of data should have a different sub-policy but for now, this will do
"""
def sample_sub_policy(policies, policy, num_sub_policies):
    """Sample one sub-policy of ``policy`` and resolve its transform values.

    A sub-policy row is ``[op1, op2, prob1, prob2, mag1, mag2]`` with op codes
    0=rotate, 1=shear, 2=scale. Each of the two ops fires independently with
    its probability; a transform that does not fire keeps its identity value
    (0 degrees, 0 shear, scale 1).

    Args:
        policies: array of shape (num_policies, num_sub_policies, 6).
        policy: index of the policy to sample from.
        num_sub_policies: number of sub-policies per policy.

    Returns:
        Tuple ``(degrees, shear, scale)`` for torchvision's RandomAffine.
    """
    chosen = np.random.randint(0, num_sub_policies)
    row = policies[policy, chosen]

    # Identity defaults, keyed by op code: 0=rotate, 1=shear, 2=scale.
    applied = {0: 0, 1: 0, 2: 1}

    # For each op code, check the two slots in order (slot 0 takes priority,
    # mirroring the original if/elif chains); draw one uniform per slot hit
    # so the RNG consumption order matches the original implementation.
    for op_code in (0, 1, 2):
        if row[0] == op_code:
            if np.random.uniform() < row[2]:
                applied[op_code] = row[4]
        elif row[1] == op_code:
            if np.random.uniform() < row[3]:
                applied[op_code] = row[5]

    return applied[0], applied[1], applied[2]
# In[8]:
"""
Sample policy, open and apply above transformations
"""
def run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet):
    """Run the UCB1 bandit over augmentation policies.

    Each iteration picks a policy (initially each policy once in order, then
    by maximum q-value plus exploration bonus), samples one of its
    sub-policies, trains a fresh inner model on an augmented toy subset of
    ``ds``, and uses the best validation accuracy of that run as the policy's
    reward.

    Args:
        policies: array of shape (num_policies, num_sub_policies, 6) as
            produced by generate_policies().
        batch_size: batch size for the inner model's loaders.
        learning_rate: SGD learning rate for the inner model.
        ds: dataset name ("MNIST", "KMNIST", "FashionMNIST", "CIFAR10",
            "CIFAR100").
        toy_size: fraction of the dataset used for train/validation.
        max_epochs: epoch cap per inner training run.
        early_stop_num: epochs without improvement before early stopping.
        iterations: total number of bandit iterations.
        IsLeNet: inner model choice ("LeNet", "EasyNet", anything else falls
            back to SimpleNet).

    Returns:
        Tuple ``(q_values, best_q_values)``: per-policy mean rewards after the
        final iteration, and the running best q-value after each iteration.

    Raises:
        KeyError: if ``ds`` is not one of the supported dataset names
            (previously an unknown name surfaced later as a NameError).
    """
    num_policies = len(policies)
    num_sub_policies = len(policies[0])

    # Bandit statistics: per-policy mean reward, pull counts, UCB scores.
    q_values = [0] * num_policies
    cnts = [0] * num_policies
    q_plus_cnt = [0] * num_policies
    total_count = 0
    best_q_values = []

    # Dataset name -> torchvision constructor (replaces a long if/elif chain
    # and fails fast on an unsupported name).
    dataset_cls = {
        "MNIST": datasets.MNIST,
        "KMNIST": datasets.KMNIST,
        "FashionMNIST": datasets.FashionMNIST,
        "CIFAR10": datasets.CIFAR10,
        "CIFAR100": datasets.CIFAR100,
    }[ds]

    for policy in trange(iterations):

        # Initial sweep tries each policy once in order; afterwards pick the
        # policy with the best q + exploration bonus.
        if policy >= num_policies:
            this_policy = np.argmax(q_plus_cnt)
        else:
            this_policy = policy

        # Sample one sub-policy and build the matching augmentation pipeline.
        degrees, shear, scale = sample_sub_policy(policies, this_policy, num_sub_policies)
        transform = torchvision.transforms.Compose([
            torchvision.transforms.RandomAffine(degrees=(degrees, degrees),
                                                shear=(shear, shear),
                                                scale=(scale, scale)),
            torchvision.transforms.ToTensor(),
        ])

        # Reload the dataset each iteration so the sampled transform applies.
        train_dataset = dataset_cls(root='./MetaAugment/train', train=True,
                                    download=True, transform=transform)
        test_dataset = dataset_cls(root='./MetaAugment/test', train=False,
                                   download=True, transform=transform)

        # Infer image geometry from the first transformed sample (C, H, W).
        img_height = len(train_dataset[0][0][0])
        img_width = len(train_dataset[0][0][0][0])
        img_channels = len(train_dataset[0][0])

        # Number of classes: CIFAR targets are plain ints, while the
        # MNIST-family targets are tensors that need .item().
        if ds == "CIFAR10" or ds == "CIFAR100":
            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1)
        else:
            num_labels = (max(train_dataset.targets) - min(train_dataset.targets) + 1).item()

        # Shrink to a toy subset for a fast inner training run.
        train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size, toy_size)

        # Fresh inner model per iteration so rewards are comparable.
        if IsLeNet == "LeNet":
            model = LeNet(img_height, img_width, num_labels, img_channels)
        elif IsLeNet == "EasyNet":
            model = EasyNet(img_height, img_width, num_labels, img_channels)
        else:
            model = SimpleNet(img_height, img_width, num_labels, img_channels)
        # BUG FIX: the learning_rate argument was previously ignored in
        # favour of a hard-coded lr=1e-1.
        sgd = optim.SGD(model.parameters(), lr=learning_rate)
        cost = nn.CrossEntropyLoss()

        best_acc = 0
        early_stop_cnt = 0

        # Train, tracking the best validation accuracy with early stopping.
        for _epoch in range(max_epochs):
            model.train()
            for train_x, train_label in train_loader:
                sgd.zero_grad()
                predict_y = model(train_x.float())
                loss = cost(predict_y, train_label.long())
                loss.backward()
                sgd.step()

            # Validation accuracy on the toy test split.
            correct = 0
            _sum = 0
            model.eval()
            for test_x, test_label in test_loader:
                predict_y = model(test_x.float()).detach()
                predict_ys = np.argmax(predict_y, axis=-1)
                matches = predict_ys == test_label
                correct += np.sum(matches.numpy(), axis=-1)
                _sum += matches.shape[0]

            acc = correct / _sum
            if acc > best_acc:
                best_acc = acc
                early_stop_cnt = 0
            else:
                early_stop_cnt += 1
            # Exit if validation has not improved for early_stop_num epochs.
            if early_stop_cnt >= early_stop_num:
                break

        # Update the running mean reward for the chosen policy. During the
        # initial sweep the count is 0, so the reward is stored directly.
        if policy < num_policies:
            q_values[this_policy] += best_acc
        else:
            q_values[this_policy] = (q_values[this_policy] * cnts[this_policy] + best_acc) / (cnts[this_policy] + 1)

        best_q_value = max(q_values)
        best_q_values.append(best_q_value)

        if (policy + 1) % 10 == 0:
            rounded = list(np.around(np.array(q_values), 2))
            print("Iteration: {},\tQ-Values: {}, Best Policy: {}".format(
                policy + 1, rounded, max(rounded)))

        # Update counts and, once every policy has been tried, the UCB
        # exploration bonuses.
        cnts[this_policy] += 1
        total_count += 1
        if policy >= num_policies - 1:
            for i in range(num_policies):
                q_plus_cnt[i] = q_values[i] + np.sqrt(2 * np.log(total_count) / cnts[i])

    return q_values, best_q_values
# # In[9]:
# batch_size = 32 # size of batch the inner NN is trained with
# learning_rate = 1e-1 # fix learning rate
# ds = "MNIST" # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
# toy_size = 0.02 # total propeortion of training and test set we use
# max_epochs = 100 # max number of epochs that is run if early stopping is not hit
# early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
# num_policies = 5 # fix number of policies
# num_sub_policies = 5 # fix number of sub-policies in a policy
# iterations = 100 # total iterations, should be more than the number of policies
# IsLeNet = "SimpleNet" # using LeNet or EasyNet or SimpleNet
# # generate random policies at start
# policies = generate_policies(num_policies, num_sub_policies)
# q_values, best_q_values = run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet)
# plt.plot(best_q_values)
# best_q_values = np.array(best_q_values)
# save('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), best_q_values)
# #best_q_values = load('best_q_values_{}_{}percent_{}.npy'.format(IsLeNet, int(toy_size*100), ds), allow_pickle=True)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment