diff --git a/backend/MetaAugment/Evo_learner.py b/backend/MetaAugment/Evo_learner.py index 2f97fbd0f0d4a6c70896bbbce9383e11b2cd5ad0..71170f437b7674a15e8941d0a15c6990595ab55d 100644 --- a/backend/MetaAugment/Evo_learner.py +++ b/backend/MetaAugment/Evo_learner.py @@ -25,6 +25,7 @@ from torch import Tensor + class Learner(nn.Module): def __init__(self, fun_num=14, p_bins=11, m_bins=10, sub_num_pol=5): self.fun_num = fun_num @@ -46,6 +47,7 @@ class Learner(nn.Module): self.fc3 = nn.Linear(84, self.sub_num_pol * 2 * (self.fun_num + self.p_bins + self.m_bins)) def forward(self, x): + x = x[:, 0:1, :, :] y = self.conv1(x) y = self.relu1(y) y = self.pool1(y) @@ -61,12 +63,29 @@ class Learner(nn.Module): return y +class LeNet(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(784, 2048) + self.relu1 = nn.ReLU() + self.fc2 = nn.Linear(2048, 10) + self.relu2 = nn.ReLU() + + def forward(self, x): + x = x.reshape((-1, 784)) + y = self.fc1(x) + y = self.relu1(y) + y = self.fc2(y) + y = self.relu2(y) + return y + class Evolutionary_learner(): def __init__(self, network, num_solutions = 10, num_generations = 5, num_parents_mating = 5, batch_size=32, child_network = None, p_bins = 11, mag_bins = 10, sub_num_pol=5, fun_num = 14, exclude_method=[], augmentation_space = None, ds=None, ds_name=None): - self.auto_aug_agent = Learner(fun_num=fun_num, p_bins=p_bins, m_bins=mag_bins, sub_num_pol=sub_num_pol) - self.torch_ga = torchga.TorchGA(model=network, num_solutions=num_solutions) + self.auto_aug_agent = network + self.torch_ga = torchga.TorchGA(model=self.auto_aug_agent, num_solutions=num_solutions) + self.num_generations = num_generations self.num_parents_mating = num_parents_mating self.initial_population = self.torch_ga.population_weights @@ -295,8 +314,8 @@ class Evolutionary_learner(): for idx, (test_x, label_x) in enumerate(self.train_loader): full_policy = self.get_policy_cov(test_x) - fit_val = ((test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0])/ - + test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset)[0]) / 2 + fit_val = ((test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset, self.child_network)[0])/ + + test_autoaugment_policy(full_policy, self.train_dataset, self.test_dataset, self.child_network)[0]) / 2 return fit_val @@ -382,8 +401,12 @@ def train_child_network(child_network, train_loader, test_loader, sgd, correct = 0 _sum = 0 child_network.eval() + print("here0") with torch.no_grad(): + print("here1") + print("len test_loader: ", len(test_loader)) for idx, (test_x, test_label) in enumerate(test_loader): + print("here2") test_x = test_x.to(device=device, dtype=test_x.dtype) test_label = test_label.to(device=device, dtype=test_label.dtype) @@ -394,6 +417,7 @@ def train_child_network(child_network, train_loader, test_loader, sgd, correct += torch.sum(_, axis=-1) _sum += _.shape[0] + print("SUM: ", _sum) acc = correct / _sum @@ -836,3 +860,4 @@ class TrivialAugmentWide(torch.nn.Module): + diff --git a/backend/auto_augmentation/progress.py b/backend/auto_augmentation/progress.py index c33d45ca275dc305bbeca58624819d92401608d3..a83b2bb77020f71226d363e8a4a3f3176a5c660b 100644 --- a/backend/auto_augmentation/progress.py +++ b/backend/auto_augmentation/progress.py @@ -81,7 +81,9 @@ def response(): policies = UCB1_JC.generate_policies(num_policies, num_sub_policies) q_values, best_q_values = UCB1_JC.run_UCB1(policies, batch_size, learning_rate, ds, toy_size, max_epochs, early_stop_num, iterations, IsLeNet, ds_name) elif auto_aug_leanrer == 'Evolutionary Learner': - learner = Evo.Evolutionary_learner(network=Evo.Learner(), fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method) + network = Evo.Learner(fun_num=num_funcs, p_bins=1, m_bins=1, sub_num_pol=1) + child_network = Evo.LeNet() + learner = Evo.Evolutionary_learner(network=network, fun_num=num_funcs, p_bins=1, mag_bins=1, sub_num_pol=1, ds = ds, ds_name=ds_name, exclude_method=exclude_method, child_network=child_network) learner.run_instance() elif auto_aug_leanrer == 'Random Searcher': pass @@ -131,4 +133,4 @@ def response(): # return redirect(url_for('uploaded_file', filename=filename)) # return ''' -# ''' \ No newline at end of file +# '''