diff --git a/MetaAugment/autoaugment_learners/README.md b/MetaAugment/autoaugment_learners/README.md new file mode 100644 index 0000000000000000000000000000000000000000..4650f6d62c4536e4f192ecc2e177c090ecac6468 --- /dev/null +++ b/MetaAugment/autoaugment_learners/README.md @@ -0,0 +1,3 @@ +write `import MetaAugment.autoaugment_learners as aa` +and `aa_learner = aa.randomsearch_learner()` +to use \ No newline at end of file diff --git a/MetaAugment/autoaugment_learners/__init__.py b/MetaAugment/autoaugment_learners/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1c7b4de8c7c374fe4eee3502f7086dceb22a7d9b --- /dev/null +++ b/MetaAugment/autoaugment_learners/__init__.py @@ -0,0 +1 @@ +from .randomsearch_learner import * \ No newline at end of file diff --git a/MetaAugment/autoaugment_learners/aa_learner.py b/MetaAugment/autoaugment_learners/aa_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce33875507001a9f60261ea5ffb09efb6e2b31b --- /dev/null +++ b/MetaAugment/autoaugment_learners/aa_learner.py @@ -0,0 +1,29 @@ +# DUMMY PSEUDOCODE! +# +# this might become the superclass for all other autoaugment_learners +# This is sort of how our AA_Learner class should look like: + +class aa_learner: + def __init__(self, controller): + self.controller = controller + + def learn(self, train_dataset, test_dataset, child_network, res, toy_flag): + ''' + Does what is seen in Figure 1 in the AutoAugment paper. + + 'res' stands for resolution of the discretisation of the search space. It could be + a tuple, with first entry regarding probability, second regarding magnitude + ''' + good_policy_found = False + + while not good_policy_found: + policy = self.controller.pop_policy() + + train_loader, test_loader = create_toy(train_dataset, test_dataset, + batch_size=32, n_samples=0.005) + + reward = train_child_network(child_network, train_loader, test_loader, sgd, cost, epoch) + + self.controller.update(reward, policy) + + return good_policy \ No newline at end of file diff --git a/MetaAugment/autoaugment_learners/autoaugment.py b/MetaAugment/autoaugment_learners/autoaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..dd7e04917b4f3c68a58705ef068ff0339464ba54 --- /dev/null +++ b/MetaAugment/autoaugment_learners/autoaugment.py @@ -0,0 +1,405 @@ +import math +import torch + +from enum import Enum +from torch import Tensor +from typing import List, Tuple, Optional, Dict + +from . import functional as F, InterpolationMode + +__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"] + + +def _apply_op(img: Tensor, op_name: str, magnitude: float, + interpolation: InterpolationMode, fill: Optional[List[float]]): + if op_name == "ShearX": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], + interpolation=interpolation, fill=fill) + elif op_name == "ShearY": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], + interpolation=interpolation, fill=fill) + elif op_name == "TranslateX": + img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0, + interpolation=interpolation, shear=[0.0, 0.0], fill=fill) + elif op_name == "TranslateY": + img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0, + interpolation=interpolation, shear=[0.0, 0.0], fill=fill) + elif op_name == "Rotate": + img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill) + elif op_name == "Brightness": + img = F.adjust_brightness(img, 1.0 + magnitude) + elif op_name == "Color": + img = F.adjust_saturation(img, 1.0 + magnitude) + elif op_name == "Contrast": + img = F.adjust_contrast(img, 1.0 + magnitude) + elif op_name == "Sharpness": + img = F.adjust_sharpness(img, 1.0 + magnitude) + elif op_name == "Posterize": + img = F.posterize(img, int(magnitude)) + elif op_name == "Solarize": + img = F.solarize(img, magnitude) + elif op_name == "AutoContrast": + img = F.autocontrast(img) + elif op_name == "Equalize": + img = F.equalize(img) + elif op_name == "Invert": + img = F.invert(img) + elif op_name == "Identity": + pass + else: + raise ValueError("The provided operator {} is not recognized.".format(op_name)) + return img + + +class AutoAugmentPolicy(Enum): + """AutoAugment policies learned on different datasets. + Available policies are IMAGENET, CIFAR10 and SVHN. + """ + IMAGENET = "imagenet" + CIFAR10 = "cifar10" + SVHN = "svhn" + + +# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class +class AutoAugment(torch.nn.Module): + r"""AutoAugment data augmentation method based on + `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + policy (AutoAugmentPolicy): Desired policy enum defined by + :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__( + self, + policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None + ) -> None: + super().__init__() + self.policy = policy + self.interpolation = interpolation + self.fill = fill + self.policies = self._get_policies(policy) + + def _get_policies( + self, + policy: AutoAugmentPolicy + ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: + if policy == AutoAugmentPolicy.IMAGENET: + return [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + return [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + return [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + else: + raise ValueError("The provided policy {} is not recognized.".format(policy)) + + def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + "Invert": (torch.tensor(0.0), False), + } + + @staticmethod + def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: + """Get parameters for autoaugment transformation + + Returns: + params required by the autoaugment transformation + """ + policy_id = int(torch.randint(transform_num, (1,)).item()) + probs = torch.rand((2,)) + signs = torch.randint(2, (2,)) + + return policy_id, probs, signs + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: AutoAugmented image. + """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F.get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + transform_id, probs, signs = self.get_params(len(self.policies)) + + for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]): + if probs[i] <= p: + op_meta = self._augmentation_space(10, F.get_image_size(img)) + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0 + if signed and signs[i] == 0: + magnitude *= -1.0 + img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + return img + + def __repr__(self) -> str: + return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) + + +class RandAugment(torch.nn.Module): + r"""RandAugment data augmentation method based on + `"RandAugment: Practical automated data augmentation with a reduced search space" + <https://arxiv.org/abs/1909.13719>`_. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + num_ops (int): Number of augmentation transformations to apply sequentially. + magnitude (int): Magnitude for all the transformations. + num_magnitude_bins (int): The number of different magnitude values. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None) -> None: + super().__init__() + self.num_ops = num_ops + self.magnitude = magnitude + self.num_magnitude_bins = num_magnitude_bins + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "Identity": (torch.tensor(0.0), False), + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + } + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. + """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F.get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + for _ in range(self.num_ops): + op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img)) + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0 + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + return img + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_ops={num_ops}' + s += ', magnitude={magnitude}' + s += ', num_magnitude_bins={num_magnitude_bins}' + s += ', interpolation={interpolation}' + s += ', fill={fill}' + s += ')' + return s.format(**self.__dict__) + + +class TrivialAugmentWide(torch.nn.Module): + r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in + `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + num_magnitude_bins (int): The number of different magnitude values. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None) -> None: + super().__init__() + self.num_magnitude_bins = num_magnitude_bins + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "Identity": (torch.tensor(0.0), False), + "ShearX": (torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.99, num_bins), True), + "Color": (torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + } + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. + """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F.get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + op_meta = self._augmentation_space(self.num_magnitude_bins) + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ + if magnitudes.ndim > 0 else 0.0 + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + + return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_magnitude_bins={num_magnitude_bins}' + s += ', interpolation={interpolation}' + s += ', fill={fill}' + s += ')' + return s.format(**self.__dict__) diff --git a/MetaAugment/autoaugment_learners/randomsearch_learner.py b/MetaAugment/autoaugment_learners/randomsearch_learner.py new file mode 100644 index 0000000000000000000000000000000000000000..966ebeb06369f8e179faa91bedb16f5df61c7802 --- /dev/null +++ b/MetaAugment/autoaugment_learners/randomsearch_learner.py @@ -0,0 +1,14 @@ +import MetaAugment.autoaugment_learners as aa + + +class RandomSearch_Learner(aa): + def __init__(self): + super().__init__() + + + + + +def randomsearch_learner(): + model = RandomSearch_Learner() + return model \ No newline at end of file diff --git a/MetaAugment/child_networks/README.md b/MetaAugment/child_networks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c8910de02fa20355e033853ead77084d535767b7 --- /dev/null +++ b/MetaAugment/child_networks/README.md @@ -0,0 +1,3 @@ +write `import MetaAugment.child_networks as child_networks` +and `child_network = child_networks.lenet()` +to use \ No newline at end of file diff --git a/MetaAugment/child_networks/__init__.py b/MetaAugment/child_networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..88d93647b235ce75d9af9f7c77441134d3035f1e --- /dev/null +++ b/MetaAugment/child_networks/__init__.py @@ -0,0 +1,2 @@ +from .lenet import * +from .bad_lenet import * \ No newline at end of file diff --git a/MetaAugment/child_networks/bad_lenet.py b/MetaAugment/child_networks/bad_lenet.py new file mode 100644 index 0000000000000000000000000000000000000000..296192cb7746ae3de1aaf0e9954ad9c60dcaae78 --- /dev/null +++ b/MetaAugment/child_networks/bad_lenet.py @@ -0,0 +1,39 @@ +import torch.nn as nn + + +class Bad_LeNet(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 6, 5) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(2) + self.fc1 = nn.Linear(256, 120) + self.relu3 = nn.ReLU() + self.fc2 = nn.Linear(120, 84) + self.relu4 = nn.ReLU() + self.fc3 = nn.Linear(84, 10) + self.relu5 = nn.ReLU() + + def forward(self, x): + y = self.conv1(x) + y = self.relu1(y) + y = self.pool1(y) + y = self.conv2(y) + y = self.relu2(y) + y = self.pool2(y) + y = y.view(y.shape[0], -1) + y = self.fc1(y) + y = self.relu3(y) + y = self.fc2(y) + y = self.relu4(y) + y = self.fc3(y) + y = self.relu5(y) + return y + + +def bad_lenet(): + model = Bad_LeNet() + return model \ No newline at end of file diff --git a/MetaAugment/child_networks/lenet.py b/MetaAugment/child_networks/lenet.py new file mode 100644 index 0000000000000000000000000000000000000000..5546bfa76f3529f074f024dc1a8b81307d27eec0 --- /dev/null +++ b/MetaAugment/child_networks/lenet.py @@ -0,0 +1,39 @@ +import torch.nn as nn + + +class LeNet(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 6, 5) + self.relu1 = nn.ReLU() + self.pool1 = nn.MaxPool2d(2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.relu2 = nn.ReLU() + self.pool2 = nn.MaxPool2d(2) + self.fc1 = nn.Linear(256, 120) + self.relu3 = nn.ReLU() + self.fc2 = nn.Linear(120, 84) + self.relu4 = nn.ReLU() + self.fc3 = nn.Linear(84, 10) + self.relu5 = nn.ReLU() + + def forward(self, x): + y = self.conv1(x) + y = self.relu1(y) + y = self.pool1(y) + y = self.conv2(y) + y = self.relu2(y) + y = self.pool2(y) + y = y.view(y.shape[0], -1) + y = self.fc1(y) + y = self.relu3(y) + y = self.fc2(y) + y = self.relu4(y) + y = self.fc3(y) + y = self.relu5(y) + return y + + +def lenet(): + model = LeNet() + return model \ No newline at end of file diff --git a/MetaAugment/saved_models/mnist_0.9823.pkl b/MetaAugment/saved_models/mnist_0.9823.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c762f2750da9807132221b447aef89c0211b36bd Binary files /dev/null and b/MetaAugment/saved_models/mnist_0.9823.pkl differ