Skip to content
Snippets Groups Projects
Commit b6beb405 authored by Mia Wang's avatar Mia Wang
Browse files

connected with learners

parent ddf4ca0a
No related branches found
No related tags found
No related merge requests found
Showing with 155 additions and 82 deletions
No preview for this file type
......@@ -462,3 +462,22 @@ class AaLearner:
megapol += pol[0]
return megapol
def get_n_best_policies(self, number_policies=5):
"""
returns the n best policies
Args:
number_policies (int): Number of (sub)policies to return
Returns:
list of best n policies
"""
number_policies = max(number_policies, len(self.history))
inter_pol = sorted(self.history, key=lambda x: x[1], reverse = True)[:number_policies]
return inter_pol[n]
......@@ -153,8 +153,7 @@ class UcbLearner(RsLearner):
train_dataset,
test_dataset,
child_network_architecture,
iterations=15,
print_every_epoch=False):
iterations=15,):
"""continue the UCB algorithm for ``iterations`` number of turns
"""
......@@ -173,7 +172,6 @@ class UcbLearner(RsLearner):
train_dataset,
test_dataset,
logging=False,
print_every_epoch=print_every_epoch
)
# update q_values (average accuracy)
self.avg_accs[this_policy_idx] = acc
......@@ -188,7 +186,6 @@ class UcbLearner(RsLearner):
train_dataset,
test_dataset,
logging=False,
print_every_epoch=print_every_epoch
)
# update q_values (average accuracy)
self.avg_accs[this_policy_idx] = (self.avg_accs[this_policy_idx]*self.cnts[this_policy_idx] + acc) / (self.cnts[this_policy_idx] + 1)
......@@ -246,6 +243,30 @@ class UcbLearner(RsLearner):
return megapol
def get_n_best_policies(self, number_policies=5):
"""
returns the n best policies
Args:
number_policies (int): Number of (sub)policies to return
Returns:
list of best n policies
"""
temp_avg_accs = [x if x is not None else 0 for x in self.avg_accs]
temp_history = list(zip(self.policies, temp_avg_accs))
number_policies = max(number_policies, len(temp_history))
inter_pol = sorted(temp_history, key=lambda x: x[1], reverse = True)[:number_policies]
return inter_pol[n]
......
......@@ -6,7 +6,7 @@ import torch
torch.manual_seed(0)
import temp_util.wapp_util as wapp_util
import react_backend.wapp_util as wapp_util
bp = Blueprint("progress", __name__)
......
......@@ -5,7 +5,7 @@ import torch
torch.manual_seed(0)
import temp_util.wapp_util as wapp_util
import react_backend.wapp_util as wapp_util
bp = Blueprint("training", __name__)
......
This diff is collapsed.
from dataclasses import dataclass
from flask import Flask, request, current_app, send_file, send_from_directory, redirect, url_for, session
from flask_cors import CORS, cross_origin
import os
import zipfile
import torch
from numpy import save, load
# import temp_util.wapp_util as wapp_util
import time
from numpy import int0, save, load
from react_backend.wapp_util import parse_users_learner_spec
import pprint
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import os
import sys
sys.path.insert(0, os.path.abspath('..'))
......@@ -39,22 +39,22 @@ def get_form_data():
print('@@@ required user input:', 'ds', ds, 'IsLeNet:', IsLeNet, 'auto_aug_leanrer:',auto_aug_learner)
# advanced input
if form_data['batch_size'] != 'undefined':
batch_size = form_data['batch_size'] # size of batch the inner NN is trained with
if form_data['batch_size'] not in ['undefined', ""]:
batch_size = int(form_data['batch_size'] ) # size of batch the inner NN is trained with
else:
batch_size = 1 # this is for demonstration purposes
if form_data['learning_rate'] != 'undefined':
learning_rate = form_data['learning_rate'] # fix learning rate
batch_size = 16 # this is for demonstration purposes
if form_data['learning_rate'] not in ['undefined', ""]:
learning_rate = float(form_data['learning_rate']) # fix learning rate
else:
learning_rate = 10-1
if form_data['toy_size'] != 'undefined':
toy_size = form_data['toy_size'] # total propeortion of training and test set we use
learning_rate = 1e-2
if form_data['toy_size'] not in ['undefined', ""]:
toy_size = float(form_data['toy_size']) # total propeortion of training and test set we use
else:
toy_size = 1 # this is for demonstration purposes
if form_data['iterations'] != 'undefined':
iterations = form_data['iterations'] # total iterations, should be more than the number of policies
toy_size = 0.01 # this is for demonstration purposes
if form_data['iterations'] not in ['undefined', ""]:
iterations = int(form_data['iterations']) # total iterations, should be more than the number of policies
else:
iterations = 10
iterations = 2
exclude_method = form_data['select_action']
print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations, 'exclude_method', exclude_method)
......@@ -156,15 +156,54 @@ def training():
num_sub_policies = 5 # fix number of sub-policies in a policy
data = current_app.config.get('data')
# fake training
print('pretend it is training')
time.sleep(1)
print('epoch: 1')
time.sleep(1)
print('epoch: 2')
time.sleep(1)
print('epoch: 3')
print('it has finished training')
# parse the settings given by the user to obtain tools we need
train_dataset, test_dataset, child_archi, agent = parse_users_learner_spec(
max_epochs=max_epochs,
early_stop_num=early_stop_num,
num_policies=num_policies,
num_sub_policies=num_sub_policies,
**data
)
# train the autoaugment learner for number of `iterations`
agent.learn(
train_dataset=train_dataset,
test_dataset=test_dataset,
child_network_architecture=child_archi,
iterations=data['iterations']
)
print('the history of all the policies the agent has tested:')
pprint.pprint(agent.history)
# get acc graph and best acc graph
acc_list = [acc for (policy,acc) in agent.history]
best_acc_list = []
best_til_now = 0
for acc in acc_list:
if acc>best_til_now:
best_til_now=acc
best_acc_list.append(best_til_now)
# plot both here
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.plot(acc_list)
ax.plot(best_acc_list)
ax.set_xlabel('Number of Iterations')
ax.set_ylabel('Accuracy')
ax.set_title('Auto-augmentation Learner Performance Curve')
with open("./react_frontend/src/pages/output.png", 'wb') as f:
fig.savefig(f)
print("best policies:")
best_policy = agent.get_mega_policy(number_policies=4)
print(best_policy)
with open("./react_backend/policy.txt", 'w') as f:
# save the best_policy in pretty_print string format
f.write(pprint.pformat(best_policy, indent=4))
print('')
return {'status': 'Training is done!'}
......
......@@ -36,12 +36,15 @@ def parse_ds_cn_arch(ds, ds_name, IsLeNet):
len_train = int(0.8*len(dataset))
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [len_train, len(dataset)-len_train])
# check sizes of images
img_height = len(train_dataset[0][0][0])
img_width = len(train_dataset[0][0][0][0])
img_channels = len(train_dataset[0][0])
my_image = train_dataset[0][0]
my_transform = torchvision.transforms.ToTensor()
my_image = my_transform(my_image)
# check sizes of images
img_height = len(my_image[0])
img_width = len(my_image[0][0])
img_channels = len(my_image)
# check output labels
if ds == 'Other':
num_labels = len(dataset.class_to_idx)
......
0.7018545454545454,0.6530636363636364,0.6565090909090909,0.7029727272727273,0.6615000000000001,0.6610181818181818,0.6333545454545454,0.6617909090909091,0.6584636363636364,0.6933909090909091
\ No newline at end of file
[ (('Brightness', 0.4, 4), ('Color', 0.0, 8)),
(('Posterize', 0.1, 0), ('Rotate', 0.3, 3)),
(('Equalize', 0.8, None), ('TranslateX', 0.7, 4)),
(('Equalize', 0.1, None), ('Posterize', 0.4, 1)),
(('AutoContrast', 0.3, None), ('Solarize', 0.2, 3)),
(('Sharpness', 1.0, 9), ('ShearY', 0.9, 9)),
(('Posterize', 0.6, 0), ('Color', 1.0, 0)),
(('TranslateX', 0.9, 0), ('Solarize', 0.1, 2)),
(('Rotate', 1.0, 9), ('Equalize', 0.3, None)),
(('ShearX', 0.3, 5), ('TranslateX', 0.0, 7))]
\ No newline at end of file
......@@ -26,7 +26,6 @@ def parse_users_learner_spec(
auto_aug_learner,
# search space settings
exclude_method,
num_funcs,
num_policies,
num_sub_policies,
# child network settings
......@@ -35,7 +34,9 @@ def parse_users_learner_spec(
early_stop_num,
iterations,
learning_rate,
max_epochs
max_epochs,
# dummy variable which does nothing
network_name,
):
train_dataset, test_dataset, child_archi = parse_ds_cn_arch(
ds,
......@@ -111,8 +112,4 @@ def parse_users_learner_spec(
early_stop_num=early_stop_num,
)
agent.learn(train_dataset,
test_dataset,
child_network_architecture=child_archi,
iterations=iterations)
\ No newline at end of file
return train_dataset, test_dataset, child_archi, agent
\ No newline at end of file
......@@ -8,17 +8,21 @@ export default function Confirm() {
const [myData, setMyData] = useState([])
const [dataset, setDataset] = useState()
const [network, setNetwork] = useState()
const [yes, setYes] = useState()
// console.log('already in confirm react')
console.log('already in confirm react')
useEffect(() => {
const res = fetch('/home').then(
response => response.json()
).then(data => {setMyData(data);
if (data.ds == 'Other'){setDataset(data.ds_name)} else {setDataset(data.ds)};
if (data.IsLeNet == 'Other'){setNetwork(data.network_name)} else {setNetwork(data.IsLeNet)};
setYes('hey');
console.log('setYes', yes);
});
}, []);
let navigate = useNavigate();
const onSubmit = async () => {
navigate('/progress', {replace:true});
......@@ -27,7 +31,7 @@ export default function Confirm() {
return (
<div className="App" style={{padding:"60px"}}>
<Typography gutterBottom variant="h3" align="center" >
Data Auto-Augmentation
Data Auto-Augmentation {yes}
</Typography>
<Grid>
<Card style={{ maxWidth: 900, padding: "10px 5px", margin: "0 auto" }}>
......
......@@ -27,7 +27,6 @@ const ExpandMore = styled((props) => {
export default function Home() {
const [selectAction, setSelectAction] = useState([]);
const [validation, setValidation] = useState([]);
// for form submission
const {register, control, handleSubmit, setValue, watch, formState: { errors, dirtyFields}} = useForm();
......@@ -66,31 +65,8 @@ export default function Home() {
response => response.json()
).then(data => {
if ('error' in data){navigate('/error', data)} else {navigate('/confirm', {replace:true})}
});
//
///////// testing
// .then((response)=> {
// responseClone = response.clone(); // 2
// return response.json();
// })
// .then(function (data) {
// console.log('data from flask', data)
// }, function (rejectionReason) { // 3
// console.log('Error parsing JSON from response:', rejectionReason, responseClone); // 4
// responseClone.text() // 5
// .then(function (bodyText) {
// console.log('Received the following instead of valid JSON:', bodyText); // 6
// });
// });
});
};
// body: JSON.stringify(data)
// console.log('errors', errors);
// console.log('handleSubmit', handleSubmit)
// handling action selection
const handleActionSelect = (value) => {
......@@ -328,10 +304,10 @@ export default function Home() {
<TextField type="number" inputProps={{step: "0.000000001",min: 0}} {...register("learning_rate")} name="learning_rate" placeholder="Learning Rate" label="Learning Rate" variant="outlined" fullWidth />
</Grid>
<Grid xs={12} sm={6} item>
<TextField type="number" InputProps={{step:"1", inputProps: { min: 0, max: 1} }} {...register("iterations")} name="iterations" placeholder="Number of Iterations" label="Iterations" variant="outlined" fullWidth />
<TextField type="number" InputProps={{step:"1", inputProps: { min: 0} }} {...register("iterations")} name="iterations" placeholder="Number of Iterations" label="Iterations" variant="outlined" fullWidth />
</Grid>
<Grid xs={12} sm={6} item>
<TextField type="number" inputProps={{step: "0.01", min: 0, max: 1}} {...register("toy_size")} name="toy_size" placeholder="Dataset Proportion" label="Dataset Proportion" variant="outlined" fullWidth />
<TextField type="number" inputProps={{step: "0.0001", min: 0, max: 1}} {...register("toy_size")} name="toy_size" placeholder="Dataset Proportion" label="Dataset Proportion" variant="outlined" fullWidth />
</Grid>
<FormLabel variant="h8" align='centre'>
* Dataset Proportion defines the percentage of original dataset our auto-augment learner will use to find the
......
import React, { useState, useEffect } from "react";
import { Grid, List, ListItem, Avatar, ListItemAvatar, ListItemText, Card, CardContent, Typography, Button, CardMedia } from '@mui/material';
import output from './pytest.png'
import output from './output.png'
import {useNavigate, Route} from "react-router-dom";
import axios from 'axios'
import fileDownload from 'js-file-download'
......@@ -28,10 +28,10 @@ export default function Result() {
Here are the results from our auto-augment agent:
</Typography>
<Grid style={{padding:"30px"}} container spacing={4} alignItems="center">
<Grid xs={7} item>
<Grid xs={6} item>
<img src={output} alt='output' />
</Grid>
<Grid xs={5} item>
<Grid xs={6} item>
<Typography>
write something here to explain the meaning of the graph to the user
</Typography>
......@@ -41,6 +41,7 @@ export default function Result() {
<Typography gutterBottom variant='h6' align='center'>
You can download the augentation policy here
</Typography>
<Button
type="submit"
variant="contained"
......@@ -50,9 +51,13 @@ export default function Result() {
>
Download
</Button>
<Typography>
Please follow our documentation to apply this policy to your dataset.
</Typography>
<Grid style={{padding:'10px'}}>
<Typography>
Please follow our documentation to apply this policy to your dataset.
</Typography>
</Grid>
</CardContent>
</Card>
......
react_frontend/src/pages/output.png

29.4 KiB

react_frontend/src/pages/pytest.png

7.93 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment