Skip to content
Snippets Groups Projects
Commit 8598a70a authored by Mia Wang's avatar Mia Wang
Browse files

connects to training page; change learner from checkbox to radio

parent 98be9246
No related branches found
No related tags found
No related merge requests found
0.19934545454545455,0.19519090909090908,0.19935454545454545,0.19381818181818183,0.18769999999999998,0.19858181818181822,0.19459090909090906,0.18030000000000002,0.17654545454545453,0.2042909090909091
\ No newline at end of file
0.7018545454545454,0.6530636363636364,0.6565090909090909,0.7029727272727273,0.6615000000000001,0.6610181818181818,0.6333545454545454,0.6617909090909091,0.6584636363636364,0.6933909090909091
\ No newline at end of file
......@@ -46,125 +46,96 @@ def get_form_data():
# form_data = request.files['ds_upload']
# print('@@@ form_data', form_data)
# form_data = request.form.get('test')
# print('@@@ this is form data', request.get_data())
form_data = request.form
print('@@@ this is form data', form_data)
# required input
# ds = form_data['select_dataset'] # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
# IsLeNet = form_data["select_network"] # using LeNet or EasyNet or SimpleNet ->> default
# auto_aug_learner = form_data["select_learner"] # augmentation methods to be excluded
# print('@@@ required user input:', 'ds', ds, 'IsLeNet:', IsLeNet, 'auto_aug_leanrer:',auto_aug_learner)
# # advanced input
# if 'batch_size' in form_data.keys():
# batch_size = form_data['batch_size'] # size of batch the inner NN is trained with
# else:
# batch_size = 1 # this is for demonstration purposes
# if 'learning_rate' in form_data.keys():
# learning_rate = form_data['learning_rate'] # fix learning rate
# else:
# learning_rate = 10-1
# if 'toy_size' in form_data.keys():
# toy_size = form_data['toy_size'] # total propeortion of training and test set we use
# else:
# toy_size = 1 # this is for demonstration purposes
# if 'iterations' in form_data.keys():
# iterations = form_data['iterations'] # total iterations, should be more than the number of policies
# else:
# iterations = 10
# exclude_method = form_data['select_action']
# num_funcs = 14 - len(exclude_method)
# print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations, 'exclude_method', exclude_method, 'num_funcs', num_funcs)
ds = form_data['select_dataset'] # pick dataset (MNIST, KMNIST, FashionMNIST, CIFAR10, CIFAR100)
IsLeNet = form_data["select_network"] # using LeNet or EasyNet or SimpleNet ->> default
auto_aug_learner = form_data["select_learner"] # augmentation methods to be excluded
print('@@@ required user input:', 'ds', ds, 'IsLeNet:', IsLeNet, 'auto_aug_leanrer:',auto_aug_learner)
# advanced input
if form_data['batch_size'] != 'undefined':
batch_size = form_data['batch_size'] # size of batch the inner NN is trained with
else:
batch_size = 1 # this is for demonstration purposes
if form_data['learning_rate'] != 'undefined':
learning_rate = form_data['learning_rate'] # fix learning rate
else:
learning_rate = 10-1
if form_data['toy_size'] != 'undefined':
toy_size = form_data['toy_size'] # total propeortion of training and test set we use
else:
toy_size = 1 # this is for demonstration purposes
if form_data['iterations'] != 'undefined':
iterations = form_data['iterations'] # total iterations, should be more than the number of policies
else:
iterations = 10
exclude_method = form_data['select_action']
print('@@@ advanced search: batch_size:', batch_size, 'learning_rate:', learning_rate, 'toy_size:', toy_size, 'iterations:', iterations, 'exclude_method', exclude_method)
# # default values
# max_epochs = 10 # max number of epochs that is run if early stopping is not hit
# early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
# num_policies = 5 # fix number of policies
# num_sub_policies = 5 # fix number of sub-policies in a policy
# default values
max_epochs = 10 # max number of epochs that is run if early stopping is not hit
early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
num_policies = 5 # fix number of policies
num_sub_policies = 5 # fix number of sub-policies in a policy
# # if user upload datasets and networks, save them in the database
# if ds == 'Other':
# ds_folder = request.files['ds_upload']
# print('!!!ds_folder', ds_folder)
# ds_name_zip = ds_folder.filename
# ds_name = ds_name_zip.split('.')[0]
# ds_folder.save('./datasets/'+ ds_name_zip)
# with zipfile.ZipFile('./datasets/'+ ds_name_zip, 'r') as zip_ref:
# zip_ref.extractall('./datasets/upload_dataset/')
# if not current_app.debug:
# os.remove(f'./datasets/{ds_name_zip}')
# else:
# ds_name = None
# # test if uploaded dataset meets the criteria
# for (dirpath, dirnames, filenames) in os.walk(f'./datasets/upload_dataset/{ds_name}/'):
# for dirname in dirnames:
# if dirname[0:6] != 'class_':
# return None # neet to change render to a 'failed dataset webpage'
# # save the user uploaded network
# if IsLeNet == 'Other':
# childnetwork = request.files['network_upload']
# childnetwork.save('./child_networks/'+childnetwork.filename)
# network_name = childnetwork.filename
# if user upload datasets and networks, save them in the database
if ds == 'Other':
ds_folder = request.files['ds_upload']
print('!!!ds_folder', ds_folder)
ds_name_zip = ds_folder.filename
ds_name = ds_name_zip.split('.')[0]
ds_folder.save('./datasets/'+ ds_name_zip)
with zipfile.ZipFile('./datasets/'+ ds_name_zip, 'r') as zip_ref:
zip_ref.extractall('./datasets/upload_dataset/')
if not current_app.debug:
os.remove(f'./datasets/{ds_name_zip}')
else:
ds_name_zip = None
ds_name = None
# test if uploaded dataset meets the criteria
for (dirpath, dirnames, filenames) in os.walk(f'./datasets/upload_dataset/{ds_name}/'):
for dirname in dirnames:
if dirname[0:6] != 'class_':
return None # neet to change render to a 'failed dataset webpage'
# save the user uploaded network
if IsLeNet == 'Other':
childnetwork = request.files['network_upload']
childnetwork.save('./child_networks/'+childnetwork.filename)
network_name = childnetwork.filename
else:
network_name = None
# # generate random policies at start
# current_app.config['AAL'] = auto_aug_learner
# current_app.config['NP'] = num_policies
# current_app.config['NSP'] = num_sub_policies
# current_app.config['BS'] = batch_size
# current_app.config['LR'] = learning_rate
# current_app.config['TS'] = toy_size
# current_app.config['ME'] = max_epochs
# current_app.config['ESN'] = early_stop_num
# current_app.config['IT'] = iterations
# current_app.config['ISLENET'] = IsLeNet
# current_app.config['DSN'] = ds_name
# current_app.config['ds'] = ds
print("@@@ user input has all stored in the app")
# print("@@@ user input has all stored in the app")
data = {'ds': ds, 'ds_name': ds_name_zip, 'IsLeNet': IsLeNet, 'network_name': network_name,
'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate,
'toy_size':toy_size, 'iterations':iterations, 'exclude_method': exclude_method, }
# data = {'ds': ds, 'ds_name': ds_name, 'IsLeNet': IsLeNet, 'ds_folder.filename': ds_name,
# 'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate,
# 'toy_size':toy_size, 'iterations':iterations, }
current_app.config['data'] = data
# print('@@@ all data sent', data)
return {'data': 'show training data'}
print('@@@ all data sent', current_app.config['data'])
return {'data': 'all stored'}
@app.route('/confirm', methods=['POST', 'GET'])
def confirm():
print('inside confirm')
# aa learner
auto_aug_learner = current_app.config.get('AAL')
# search space & problem setting
ds = current_app.config.get('ds')
ds_name = current_app.config.get('DSN')
exclude_method = current_app.config.get('exc_meth')
num_policies = current_app.config.get('NP')
num_sub_policies = current_app.config.get('NSP')
num_funcs = current_app.config.get('NUMFUN')
toy_size = current_app.config.get('TS')
# child network
IsLeNet = current_app.config.get('ISLENET')
# ========================================================================
@app.route('/confirm', methods=['POST', 'GET'])
def confirm():
print('inside confirm page')
data = current_app.config['data']
return data
# child network training hyperparameters
batch_size = current_app.config.get('BS')
early_stop_num = current_app.config.get('ESN')
iterations = current_app.config.get('IT')
learning_rate = current_app.config.get('LR')
max_epochs = current_app.config.get('ME')
data = {'ds': ds, 'ds_name': ds_name, 'IsLeNet': IsLeNet, 'ds_folder.filename': ds_name,
'auto_aug_learner':auto_aug_learner, 'batch_size': batch_size, 'learning_rate': learning_rate,
'toy_size':toy_size, 'iterations':iterations, }
return {'batch_size': '12'}
# ========================================================================
@app.route('/training', methods=['POST', 'GET'])
......@@ -192,24 +163,31 @@ def training():
learning_rate = current_app.config.get('LR')
max_epochs = current_app.config.get('ME')
# default values
max_epochs = 10 # max number of epochs that is run if early stopping is not hit
early_stop_num = 10 # max number of worse validation scores before early stopping is triggered
num_policies = 5 # fix number of policies
num_sub_policies = 5 # fix number of sub-policies in a policy
data = current_app.config.get('data')
if auto_aug_learner == 'UCB':
if data.auto_aug_learner == 'UCB':
policies = UCB1_JC.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = UCB1_JC.run_UCB1(
policies,
batch_size,
learning_rate,
ds,
toy_size,
data.batch_size,
data.learning_rate,
data.ds,
data.toy_size,
max_epochs,
early_stop_num,
iterations,
IsLeNet,
ds_name
data.iterations,
data.IsLeNet,
data.ds_name
)
best_q_values = np.array(best_q_values)
elif auto_aug_learner == 'Evolutionary Learner':
elif data.auto_aug_learner == 'Evolutionary Learner':
network = cn.evo_controller.evo_controller(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1)
child_network = aal.evo.LeNet()
......@@ -226,12 +204,12 @@ def training():
)
learner.run_instance()
elif auto_aug_learner == 'Random Searcher':
elif data.auto_aug_learner == 'Random Searcher':
pass
elif auto_aug_learner == 'Genetic Learner':
elif data.auto_aug_learner == 'Genetic Learner':
pass
return {'status': 'training'}
return {'status': 'training done!'}
......
......@@ -37,7 +37,8 @@ function App() {
<Routes>
<Route exact path="/" element={<Home/>}/>
<Route exact path="/confirm" element={<Confirm/>}/>
{/* <Route exact path="/Progress" element={<Training/>}/> */}
<Route exact path="/progress" element={<Progress/>}/>
<Route exact path="/result" element={<Result/>}/>
</Routes>
</BrowserRouter>
</div>
......
import React, { useState, useEffect } from "react";
import { Grid, List, ListItem, Avatar, ListItemAvatar, ListItemText, Card, CardContent, Typography, Button, TextField } from '@mui/material';
import { Grid, ListItem, ListItemAvatar, ListItemText, Card, CardContent, Typography, Button } from '@mui/material';
import CheckCircleOutlineRoundedIcon from '@mui/icons-material/CheckCircleOutlineRounded';
import TuneRoundedIcon from '@mui/icons-material/TuneRounded';
import {useNavigate, Route} from "react-router-dom";
export default function Confirm() {
const [batchSize, setBatchSize] = useState(0)
// // const [myData, setMyData] = useState([{}])
const [myData, setMyData] = useState([])
const [dataset, setDataset] = useState()
const [network, setNetwork] = useState()
useEffect(() => {
const res = fetch('/confirm').then(
response => response.json()
).then(data => setBatchSize(data.batch_size));
console.log("batchsize", batchSize)
// setBatchSize(res.batch_size)
// .then(data => {console.log('training', data);
// })
).then(data => {setMyData(data);
if (data.ds == 'Other'){setDataset(data.ds_name)} else {setDataset(data.ds)};
if (data.IsLeNet == 'Other'){setNetwork(data.network_name)} else {setNetwork(data.IsLeNet)};
});
}, []);
let navigate = useNavigate();
const onSubmit = async () => {
navigate('/progress', {replace:true});
};
return (
<div className="App" style={{padding:"60px"}}>
......@@ -39,7 +41,7 @@ export default function Confirm() {
<ListItemAvatar>
<TuneRoundedIcon color="primary" fontSize='large'/>
</ListItemAvatar>
<ListItemText primary="Batch size" secondary={batchSize} />
<ListItemText primary="Batch size" secondary={myData.batch_size} />
</ListItem>
</Grid>
<Grid xs={12} sm={6} item >
......@@ -47,7 +49,7 @@ export default function Confirm() {
<ListItemAvatar>
<CheckCircleOutlineRoundedIcon color="primary" fontSize='large'/>
</ListItemAvatar>
<ListItemText primary="Dataset selection" secondary="[Dataset]" />
<ListItemText primary="Dataset selection" secondary={dataset} />
</ListItem>
</Grid>
<Grid xs={12} sm={6} item>
......@@ -55,7 +57,7 @@ export default function Confirm() {
<ListItemAvatar>
<TuneRoundedIcon color="primary" fontSize='large'/>
</ListItemAvatar>
<ListItemText primary="Learning rate" secondary="[Learning rate]" />
<ListItemText primary="Learning rate" secondary={myData.learning_rate} />
</ListItem>
</Grid>
<Grid xs={12} sm={6} item>
......@@ -63,7 +65,7 @@ export default function Confirm() {
<ListItemAvatar>
<CheckCircleOutlineRoundedIcon color="primary" fontSize='large'/>
</ListItemAvatar>
<ListItemText primary="Network selection" secondary="[Network selection]" />
<ListItemText primary="Network selection" secondary={network} />
</ListItem>
</Grid>
<Grid xs={12} sm={6} item>
......@@ -71,7 +73,7 @@ export default function Confirm() {
<ListItemAvatar>
<TuneRoundedIcon color="primary" fontSize='large'/>
</ListItemAvatar>
<ListItemText primary="Dataset Proportion" secondary="[Dataset Proportion]" />
<ListItemText primary="Dataset Proportion" secondary={myData.toy_size} />
</ListItem>
</Grid>
<Grid xs={12} sm={6} item>
......@@ -79,7 +81,7 @@ export default function Confirm() {
<ListItemAvatar>
<CheckCircleOutlineRoundedIcon color="primary" fontSize='large'/>
</ListItemAvatar>
<ListItemText primary="Auto-augment learner selection" secondary="[Auto-augment learner selection]" />
<ListItemText primary="Auto-augment learner selection" secondary={myData.auto_aug_learner} />
</ListItem>
</Grid>
<Grid xs={12} sm={6} item>
......@@ -87,7 +89,7 @@ export default function Confirm() {
<ListItemAvatar>
<TuneRoundedIcon color="primary" fontSize='large'/>
</ListItemAvatar>
<ListItemText primary="Iterations" secondary="[Iterations]" />
<ListItemText primary="Iterations" secondary={myData.iterations} />
</ListItem>
</Grid>
</Grid>
......@@ -98,6 +100,7 @@ export default function Confirm() {
variant="contained"
color='success'
size='large'
onClick={onSubmit}
>
Confirm
</Button>
......
......@@ -6,16 +6,11 @@ import SendIcon from '@mui/icons-material/Send';
import { CardActions, Collapse, IconButton } from "@mui/material";
import ExpandMoreIcon from '@mui/icons-material/ExpandMore';
import { styled } from '@mui/material/styles';
// import {
// BrowserRouter as Router,
// Switch,
// Route,
// Redirect,
// } from "react-router-dom";
// import Confirm from './pages/Confirm'
import {useNavigate, Route} from "react-router-dom";
const ExpandMore = styled((props) => {
const { expand, ...other } = props;
return <IconButton {...other} />;
......@@ -47,8 +42,16 @@ export default function Home() {
formData.append("ds_upload", data.ds_upload[0]);
formData.append("network_upload", data.network_upload[0]);
formData.append("test", 'see');
formData.append("batch_size", data.batch_size)
formData.append("toy_size", data.toy_size)
formData.append("iterations", data.iterations)
formData.append("learning_rate", data.learning_rate)
formData.append("select_action", data.select_action)
formData.append("select_dataset", data.select_dataset)
formData.append("select_learner", data.select_learner)
formData.append("select_network", data.select_network)
console.log('>>> this is the user input in formData')
for (var key of formData.entries()) {
console.log(key[0] + ', ' + key[1])}
......@@ -57,7 +60,6 @@ export default function Home() {
method: 'POST',
body: formData
}).then((response) => response.json());
console.log('check if it is here')
navigate('/confirm', {replace:true});
//
......@@ -83,21 +85,6 @@ export default function Home() {
// console.log('errors', errors);
// console.log('handleSubmit', handleSubmit)
// handling learner selection
const handleLearnerSelect = (value) => {
const isPresent = selectLearner.indexOf(value);
if (isPresent !== -1) {
const remaining = selectLearner.filter((item) => item !== value);
setSelectLearner(remaining);
} else {
setSelectLearner((prevItems) => [...prevItems, value]);
}
};
useEffect(() => {
setValue('select_learner', selectLearner);
}, [selectLearner]);
// handling action selection
const handleActionSelect = (value) => {
......@@ -250,28 +237,25 @@ export default function Home() {
<FormLabel id="select_learner" align="left" variant="h6">
Please select the auto-augment learners you'd like to use (multiple learners can be selected)
</FormLabel>
<div>
{['UCB learner', 'Evolutionary learner', 'Random Searcher', 'GRU Learner'].map((option) => {
return (
<FormControlLabel
control={
<Controller
name='select_learner'
render={({}) => {
return (
<Checkbox
checked={selectLearner.includes(option)}
onChange={() => handleLearnerSelect(option)}/> );
}}
control={control}
rules={{ required: true }}
/>}
label={option}
key={option}
/>
);
})}
</div>
<Controller
name='select_learner'
control={control}
rules={{ required: true }}
render={({field: { onChange, value }}) => (
<RadioGroup
row
aria-labelledby="select_learner"
name="select_learner"
align="centre"
value={value ?? ""}
onChange={onChange}
>
<FormControlLabel value="UCB learner" control={<Radio />} label="UCB learner" />
<FormControlLabel value="Evolutionary learner" control={<Radio />} label="Evolutionary learner" />
<FormControlLabel value="Random Searcher" control={<Radio />} label="Random Searcher" />
<FormControlLabel value="GRU Learner" control={<Radio />} label="GRU Learner" />
</RadioGroup> )}
/>
{errors.select_learner && errors.select_learner.type === "required" &&
<Alert severity="error">
<AlertTitle>This field is required</AlertTitle>
......@@ -314,16 +298,16 @@ export default function Home() {
</Typography>
<Grid container spacing={1} style={{maxWidth:800, padding:"10px 10px"}}>
<Grid xs={12} sm={6} item>
<TextField type="number" {...register("batch_size", {valueAsNumber: true})} name="batch_size" placeholder="Batch Size" label="Batch Size" variant="outlined" fullWidth />
<TextField type="number" InputProps={{ inputProps: { min: 0} }} {...register("batch_size")} name="batch_size" placeholder="Batch Size" label="Batch Size" variant="outlined" fullWidth />
</Grid>
<Grid xs={12} sm={6} item>
<TextField type="number" {...register("learning_rate", {valueAsNumber: true})} name="learning_rate" placeholder="Learning Rate" label="Learning Rate" variant="outlined" fullWidth />
<TextField type="number" inputProps={{step: "0.000000001",min: 0}} {...register("learning_rate")} name="learning_rate" placeholder="Learning Rate" label="Learning Rate" variant="outlined" fullWidth />
</Grid>
<Grid xs={12} sm={6} item>
<TextField type="number" {...register("iterations", {valueAsNumber: true})} name="iterations" placeholder="Number of Iterations" label="Iterations" variant="outlined" fullWidth />
<TextField type="number" InputProps={{ inputProps: { min: 0} }} {...register("iterations")} name="iterations" placeholder="Number of Iterations" label="Iterations" variant="outlined" fullWidth />
</Grid>
<Grid xs={12} sm={6} item>
<TextField type="number" {...register("toy_size", {valueAsNumber: true})} name="toy_size" placeholder="Dataset Proportion" label="Dataset Proportion" variant="outlined" fullWidth />
<TextField type="number" inputProps={{step: "0.01", min: 0}} {...register("toy_size")} name="toy_size" placeholder="Dataset Proportion" label="Dataset Proportion" variant="outlined" fullWidth />
</Grid>
<FormLabel variant="h8" align='centre'>
* Dataset Proportion defines the percentage of original dataset our auto-augment learner will use to find the
......
......@@ -2,9 +2,20 @@ import React, { useState, useEffect } from "react";
import { Grid, LinearProgress, Card, CardContent, Typography, Button, TextField } from '@mui/material';
import CheckCircleOutlineRoundedIcon from '@mui/icons-material/CheckCircleOutlineRounded';
import TuneRoundedIcon from '@mui/icons-material/TuneRounded';
import {useNavigate, Route} from "react-router-dom";
export default function Training() {
useEffect(() => {
const res = fetch('/training').then(
response => response.json()
).then(data => console.log(data))
}, []);
return (
<div className="App" style={{padding:"60px"}}>
<Typography gutterBottom variant="h3" align="center" >
......@@ -14,7 +25,7 @@ export default function Training() {
<CardContent>
<Grid style={{padding:"50px"}}>
<Typography gutterBottom variant="subtitle1" align="center" >
Our auto-augment agents are working hard to generate your data augmentation policy ...
Our auto-augment learners are working hard to generate your data augmentation policy ...
</Typography>
<Grid style={{padding:"60px"}}>
<LinearProgress color="primary"/>
......
import React, { useState, useEffect } from "react";
import { Grid, List, ListItem, Avatar, ListItemAvatar, ListItemText, Card, CardContent, Typography, Button, CardMedia } from '@mui/material';
import output from './pytest.png'
import {useNavigate, Route} from "react-router-dom";
export default function Result() {
return (
<div className="App" style={{padding:"60px"}}>
<Typography gutterBottom variant="h3" align="center" >
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment