diff --git a/MetaAugment/CP2_Max.py b/MetaAugment/CP2_Max.py index e928b7de2bd152aaa05b56c8fbb8321a8681edaa..1b236f940c12e7e09024bfc15eec4ca987a92297 100644 --- a/MetaAugment/CP2_Max.py +++ b/MetaAugment/CP2_Max.py @@ -22,6 +22,24 @@ np.random.seed(0) random.seed(0) +# augmentation_space = [ +# # (function_name, do_we_need_to_specify_magnitude) +# ("ShearX", True), +# ("ShearY", True), +# ("TranslateX", True), +# ("TranslateY", True), +# ("Rotate", True), +# ("Brightness", True), +# ("Color", True), +# ("Contrast", True), +# ("Sharpness", True), +# ("Posterize", True), +# ("Solarize", True), +# ("AutoContrast", False), +# ("Equalize", False), +# ("Invert", False), +# ] + class Learner(nn.Module): def __init__(self, num_transforms = 3): super().__init__() @@ -38,6 +56,7 @@ class Learner(nn.Module): self.fc3 = nn.Linear(84, 13) # self.sig = nn.Sigmoid() + def forward(self, x): y = self.conv1(x) y = self.relu1(y) @@ -60,7 +79,6 @@ class Learner(nn.Module): p_ret = 0.1 * torch.argmax(y[:, 3:].mean(dim = 0)) return (idx_ret, p_ret) - # return (torch.argmax(y[0:3]), y[torch.argmax(y[3:])]) class LeNet(nn.Module): def __init__(self): @@ -253,7 +271,7 @@ class Evolutionary_learner(): self.num_parents_mating = num_parents_mating self.initial_population = self.torch_ga.population_weights self.train_loader = train_loader - self.backup_model = sec_model + self.sec_model = sec_model assert num_solutions > num_parents_mating, 'Number of solutions must be larger than the number of parents mating!' @@ -269,7 +287,7 @@ class Evolutionary_learner(): return solution, solution_fitness, solution_idx def new_model(self): - copy_model = copy.deepcopy(self.backup_model) + copy_model = copy.deepcopy(self.sec_model) return copy_model