Skip to content
Snippets Groups Projects
Commit 61295621 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

minor changes

parent 2e64210a
No related branches found
No related tags found
No related merge requests found
......@@ -7,12 +7,11 @@ import pygad.torchga as torchga
import copy
import torch
from MetaAugment.controller_networks.evo_controller import evo_controller
from MetaAugment.autoaugment_learners.aa_learner import aa_learner, augmentation_space
import MetaAugment.child_networks as cn
from .aa_learner import aa_learner, augmentation_space
class evo_learner():
class evo_learner(aa_learner):
def __init__(self,
sp_num=1,
......
......@@ -12,7 +12,9 @@ import torchvision.datasets as datasets
import MetaAugment.autoaugment_learners as aal
import MetaAugment.controller_networks as cont_n
import MetaAugment.child_networks as cn
print('@@@ import successful')
from MetaAugment.main import create_toy
import pickle
def parse_users_learner_spec(
auto_aug_learner,
......@@ -30,6 +32,11 @@ def parse_users_learner_spec(
learning_rate,
max_epochs
):
"""
The website receives user inputs on what they want the aa_learner
to be. We take those hyperparameters and return an aa_learner
"""
if auto_aug_learner == 'UCB':
policies = aal.ucb_learner.generate_policies(num_policies, num_sub_policies)
q_values, best_q_values = aal.ucb_learner.run_UCB1(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment