Skip to content
Snippets Groups Projects
Commit b764dcf6 authored by Mia Wang's avatar Mia Wang
Browse files

link to confirm page && request data from flask

parents 53cedb63 e454494b
No related branches found
No related tags found
No related merge requests found
Showing
with 78 additions and 44 deletions
# MetaRL # MetaRL
Documentation:
https://metaaugment.readthedocs.io/en/latest/
\ No newline at end of file
import MetaAugment.child_networks as cn
from pprint import pprint
import torchvision.datasets as datasets
import torchvision
from MetaAugment.autoaugment_learners.aa_learner import aa_learner
import pickle
# Measure a baseline accuracy for the bad_lenet child network by repeatedly
# evaluating an augmentation policy that does nothing.
train_dataset = datasets.MNIST(
    root='./MetaAugment/datasets/mnist/train',
    train=True,
    download=True,
    transform=None,
)
test_dataset = datasets.MNIST(
    root='./MetaAugment/datasets/mnist/test',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

child_network = cn.bad_lenet
aalearner = aa_learner(discrete_p_m=True)

# Equivalent to the identity transform: probability and magnitude are both zero.
null_policy = [(("Contrast", 0.0, 0.0), ("Contrast", 0.0, 0.0))]

# Truncate any results left over from a previous run.
with open('bad_lenet_baseline.txt', 'w') as file:
    file.write('')

for _ in range(100):
    acc = aalearner.test_autoaugment_policy(
        null_policy,
        child_network(),
        train_dataset,
        test_dataset,
        toy_flag=True,
        logging=False,
    )
    # Re-open in append mode each iteration so partial results persist
    # even if the loop is interrupted.
    with open('bad_lenet_baseline.txt', 'a') as file:
        file.write(f'{acc}\n')

pprint(aalearner.history)
\ No newline at end of file
...@@ -68,7 +68,4 @@ This section has moved here: [https://facebook.github.io/create-react-app/docs/d ...@@ -68,7 +68,4 @@ This section has moved here: [https://facebook.github.io/create-react-app/docs/d
### `npm run build` fails to minify ### `npm run build` fails to minify
This section has moved here: [https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify](https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify) This section has moved here: [https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify](https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify)
\ No newline at end of file
Documentation:
https://metaaugment.readthedocs.io/en/latest/
...@@ -179,13 +179,36 @@ def training(): ...@@ -179,13 +179,36 @@ def training():
if auto_aug_learner == 'UCB': if auto_aug_learner == 'UCB':
policies = UCB1_JC.generate_policies(num_policies, num_sub_policies) policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name) q_values, best_q_values = UCB1_JC.run_UCB1(
policies,
batch_size,
learning_rate,
ds,
toy_size,
max_epochs,
early_stop_num,
iterations,
IsLeNet,
ds_name
)
best_q_values = np.array(best_q_values) best_q_values = np.array(best_q_values)
elif auto_aug_learner == 'Evolutionary Learner': elif auto_aug_learner == 'Evolutionary Learner':
network = cn.evo_controller.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1) network = cn.evo_controller.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
child_network = aal.evo.LeNet() child_network = aal.evo.LeNet()
learner = aal.evo.evo_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network) learner = aal.evo.evo_learner(
network=network,
fun_num=num_funcs,
p_bins=1,
mag_bins=1,
sub_num_pol=1,
ds = ds,
ds_name=ds_name,
exclude_method=exclude_method,
child_network=child_network
)
learner.run_instance() learner.run_instance()
elif auto_aug_learner == 'Random Searcher': elif auto_aug_learner == 'Random Searcher':
pass pass
......
[(('ShearY', 0.2, 5), ('Rotate', 0.6, 6)),
(('TranslateX', 0.8, 3), ('Posterize', 0.1, 3)),
(('TranslateY', 0.0, 8), ('Equalize', 0.7, None)),
(('Equalize', 0.3, None), ('Contrast', 0.2, 0)),
(('ShearX', 0.4, 5), ('Contrast', 0.2, 8)),
(('TranslateX', 0.9, 3), ('Solarize', 0.4, 5)),
(('Color', 0.2, 4), ('Solarize', 0.6, 8)),
(('ShearX', 0.1, 8), ('Equalize', 0.4, None)),
(('Posterize', 0.7, 5), ('Solarize', 1.0, 4))][0.6056999564170837, 0.6329999566078186, 0.6171000003814697, 0.62909996509552]original small policys accuracies: [0.6236000061035156, 0.6187999844551086, 0.617900013923645]
\ No newline at end of file
No preview for this file type
[(('Color', 0.9, 3), ('Contrast', 0.8, 3)),
(('Sharpness', 0.9, 0), ('Solarize', 0.3, 7)),
(('Color', 0.0, 6), ('Solarize', 0.4, 3)),
(('Brightness', 0.1, 3), ('Brightness', 0.5, 9)),
(('Solarize', 0.9, 6), ('Rotate', 0.6, 1)),
(('Contrast', 0.7, 3), ('Posterize', 0.9, 4)),
(('Solarize', 0.6, 2), ('Contrast', 0.5, 6)),
(('TranslateX', 0.0, 4), ('AutoContrast', 0.3, None)),
(('Equalize', 0.0, None), ('Brightness', 0.8, 1))][0.7490999698638916, 0.8359999656677246, 0.8394999504089355]original small policys accuracies: [0.8380999565124512, 0.8376999497413635, 0.8376999497413635]
\ No newline at end of file
No preview for this file type
[(('ShearX', 1.0, 0), ('Color', 0.3, 2)),
(('AutoContrast', 0.0, None), ('Brightness', 0.7, 2)),
(('Invert', 0.1, None), ('Contrast', 0.1, 6)),
(('Solarize', 0.4, 2), ('Contrast', 0.9, 2)),
(('Equalize', 0.0, None), ('Contrast', 0.0, 2)),
(('Rotate', 0.4, 0), ('Posterize', 0.5, 9)),
(('Posterize', 0.7, 3), ('Invert', 0.1, None)),
(('Solarize', 0.6, 1), ('Contrast', 0.0, 0)),
(('Color', 0.2, 6), ('Posterize', 0.4, 7))][0.6222999691963196, 0.6868000030517578, 0.8374999761581421, 0.8370999693870544, 0.6934999823570251]original small policys accuracies: [0.8431999683380127, 0.8393999934196472, 0.8377999663352966]
\ No newline at end of file
...@@ -53,5 +53,6 @@ rerun_best_policy( ...@@ -53,5 +53,6 @@ rerun_best_policy(
train_dataset=train_dataset, train_dataset=train_dataset,
test_dataset=test_dataset, test_dataset=test_dataset,
child_network_architecture=child_network_architecture, child_network_architecture=child_network_architecture,
config=config,
repeat_num=5 repeat_num=5
) )
\ No newline at end of file
...@@ -52,5 +52,6 @@ rerun_best_policy( ...@@ -52,5 +52,6 @@ rerun_best_policy(
train_dataset=train_dataset, train_dataset=train_dataset,
test_dataset=test_dataset, test_dataset=test_dataset,
child_network_architecture=child_network_architecture, child_network_architecture=child_network_architecture,
config=config,
repeat_num=5 repeat_num=5
) )
\ No newline at end of file
...@@ -40,6 +40,7 @@ run_benchmark( ...@@ -40,6 +40,7 @@ run_benchmark(
child_network_architecture=child_network_architecture, child_network_architecture=child_network_architecture,
agent_arch=aal.gru_learner, agent_arch=aal.gru_learner,
config=config, config=config,
total_iter=144
) )
rerun_best_policy( rerun_best_policy(
...@@ -48,5 +49,6 @@ rerun_best_policy( ...@@ -48,5 +49,6 @@ rerun_best_policy(
train_dataset=train_dataset, train_dataset=train_dataset,
test_dataset=test_dataset, test_dataset=test_dataset,
child_network_architecture=child_network_architecture, child_network_architecture=child_network_architecture,
config=config,
repeat_num=5 repeat_num=5
) )
\ No newline at end of file
...@@ -48,5 +48,6 @@ rerun_best_policy( ...@@ -48,5 +48,6 @@ rerun_best_policy(
train_dataset=train_dataset, train_dataset=train_dataset,
test_dataset=test_dataset, test_dataset=test_dataset,
child_network_architecture=child_network_architecture, child_network_architecture=child_network_architecture,
config=config,
repeat_num=5 repeat_num=5
) )
\ No newline at end of file
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
import MetaAugment.child_networks as cn import MetaAugment.child_networks as cn
import MetaAugment.autoaugment_learners as aal import MetaAugment.autoaugment_learners as aal
from pprint import pprint import pprint
""" """
testing gru_learner and randomsearch_learner on testing gru_learner and randomsearch_learner on
...@@ -75,16 +75,21 @@ def get_mega_policy(history, n): ...@@ -75,16 +75,21 @@ def get_mega_policy(history, n):
assert len(history) >= n assert len(history) >= n
# agent.history is a list of (policy(list), val_accuracy(float)) tuples # agent.history is a list of (policy(list), val_accuracy(float)) tuples
sorted_history = sorted(history, key=lambda x:x[1]) # sort wrt acc sorted_history = sorted(history, key=lambda x:x[1], reverse=True) # sort wrt acc
best_history = sorted_history[:n] best_history = sorted_history[:n]
megapolicy = [] megapolicy = []
# we also want to keep track of how good the best policies were
# maybe if we add them all up, they'll become worse! Hopefully better tho
orig_accs = []
for policy,acc in best_history: for policy,acc in best_history:
for subpolicy in policy: for subpolicy in policy:
megapolicy.append(subpolicy) megapolicy.append(subpolicy)
orig_accs.append(acc)
return megapolicy return megapolicy, orig_accs
def rerun_best_policy( def rerun_best_policy(
...@@ -93,25 +98,30 @@ def rerun_best_policy( ...@@ -93,25 +98,30 @@ def rerun_best_policy(
train_dataset, train_dataset,
test_dataset, test_dataset,
child_network_architecture, child_network_architecture,
config,
repeat_num repeat_num
): ):
with open(agent_pickle, 'rb') as f: with open(agent_pickle, 'rb') as f:
agent = torch.load(f, map_location=device) agent = torch.load(f)
megapol = get_mega_policy(agent.history) megapol, orig_accs = get_mega_policy(agent.history,3)
print('mega policy to be tested:') print('mega policy to be tested:')
pprint(megapol) pprint.pprint(megapol)
print(orig_accs)
accs=[] accs=[]
for _ in range(repeat_num): for _ in range(repeat_num):
print(f'{_}/{repeat_num}') print(f'{_}/{repeat_num}')
temp_agent = aal.aa_learner(**config)
accs.append( accs.append(
agent.test_autoaugment_policy(megapol, temp_agent.test_autoaugment_policy(megapol,
child_network_architecture, child_network_architecture,
train_dataset, train_dataset,
test_dataset, test_dataset,
logging=False) logging=False)
) )
with open(accs_txt, 'w') as f: with open(accs_txt, 'w') as f:
f.write(pprint.pformat(megapol))
f.write(str(accs)) f.write(str(accs))
f.write(f'original small policys accuracies: {orig_accs}')
File moved
File moved
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment