    # The functions in this module make it easy to apply many different image
    # transformations to a dataset, in the format used by the AutoAugment paper:
    # a policy consists of N subpolicies, with 2 operations per subpolicy.
    # Usage is simple; see the demo code at the bottom of this file.
    
    
    import math
    import torch
    
    from enum import Enum
    from torch import Tensor
    from typing import List, Tuple, Optional, Dict
    
    
    from torchvision.transforms import functional as F, InterpolationMode
    
    
    __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
    
    
    def _apply_op(img: Tensor, op_name: str, magnitude: float,
                  interpolation: InterpolationMode, fill: Optional[List[float]]):
        """Apply the single augmentation operation ``op_name`` to ``img``.

        Args:
            img (PIL Image or Tensor): Image to be transformed.
            op_name (str): One of ``"ShearX"``, ``"ShearY"``, ``"TranslateX"``, ``"TranslateY"``,
                ``"Rotate"``, ``"Brightness"``, ``"Color"``, ``"Contrast"``, ``"Sharpness"``,
                ``"Posterize"``, ``"Solarize"``, ``"AutoContrast"``, ``"Equalize"``, ``"Invert"``
                or ``"Identity"``.
            magnitude (float): Strength of the operation; its meaning depends on ``op_name``:
                a shear factor for ShearX/ShearY, a pixel offset (truncated to int) for
                TranslateX/TranslateY, an angle in degrees for Rotate, an offset added to 1.0
                for Brightness/Color/Contrast/Sharpness, a bit count (truncated to int) for
                Posterize and a threshold for Solarize. The remaining ops ignore it.
            interpolation (InterpolationMode): Interpolation used by the geometric ops.
            fill (list of float, optional): Pixel fill value for the area outside the
                transformed image (used by the geometric ops only).

        Returns:
            PIL Image or Tensor: The transformed image.

        Raises:
            ValueError: If ``op_name`` is not one of the names listed above.
        """
        if op_name == "ShearX":
            # The policy magnitude is a shear *factor* (the off-diagonal term of the affine
            # matrix), while F.affine expects a shear *angle* in degrees. The angle matching
            # shear factor m is atan(m); converting m directly with degrees() would treat it
            # as radians. Upstream torchvision applies the same conversion.
            img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0,
                           shear=[math.degrees(math.atan(magnitude)), 0.0],
                           interpolation=interpolation, fill=fill)
        elif op_name == "ShearY":
            # Same factor -> angle conversion as ShearX, applied to the vertical shear.
            img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0,
                           shear=[0.0, math.degrees(math.atan(magnitude))],
                           interpolation=interpolation, fill=fill)
        elif op_name == "TranslateX":
            img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0,
                           interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
        elif op_name == "TranslateY":
            img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0,
                           interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
        elif op_name == "Rotate":
            img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
        elif op_name == "Brightness":
            img = F.adjust_brightness(img, 1.0 + magnitude)
        elif op_name == "Color":
            img = F.adjust_saturation(img, 1.0 + magnitude)
        elif op_name == "Contrast":
            img = F.adjust_contrast(img, 1.0 + magnitude)
        elif op_name == "Sharpness":
            img = F.adjust_sharpness(img, 1.0 + magnitude)
        elif op_name == "Posterize":
            img = F.posterize(img, int(magnitude))
        elif op_name == "Solarize":
            img = F.solarize(img, magnitude)
        elif op_name == "AutoContrast":
            img = F.autocontrast(img)
        elif op_name == "Equalize":
            img = F.equalize(img)
        elif op_name == "Invert":
            img = F.invert(img)
        elif op_name == "Identity":
            pass
        else:
            raise ValueError("The provided operator {} is not recognized.".format(op_name))
        return img
    
    
    class AutoAugmentPolicy(Enum):
        """Datasets for which a learned AutoAugment policy is bundled.

        Members are ``IMAGENET``, ``CIFAR10`` and ``SVHN``; each value is the
        lowercase dataset name.
        """
        IMAGENET = "imagenet"  # policy learned on ImageNet
        CIFAR10 = "cifar10"  # policy learned on CIFAR-10
        SVHN = "svhn"  # policy learned on SVHN
    
    
    # FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
    class AutoAugment(torch.nn.Module):
        r"""AutoAugment data augmentation method based on
        `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
        If the image is torch Tensor, it should be of type torch.uint8, and it is expected
        to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
        If img is PIL Image, it is expected to be in mode "L" or "RGB".

        The active policy is stored in ``self.subpolicies``: a list of subpolicies, each a pair of
        ``(op_name, probability, magnitude_id)`` tuples, where ``magnitude_id`` indexes a 10-bin
        magnitude table (``None`` for ops without a magnitude). The attribute is read on every
        call to ``forward``, so it can be replaced after construction to evaluate a custom policy.

        Args:
            policy (AutoAugmentPolicy): Desired policy enum defined by
                :class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
            interpolation (InterpolationMode): Desired interpolation enum defined by
                :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
                If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            fill (sequence or number, optional): Pixel fill value for the area outside the transformed
                image. If given a number, the value is used for all bands respectively.
        """

        def __init__(
            self,
            policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
            interpolation: InterpolationMode = InterpolationMode.NEAREST,
            fill: Optional[List[float]] = None
        ) -> None:
            super().__init__()
            self.policy = policy
            self.interpolation = interpolation
            self.fill = fill

            # List of subpolicies used by forward(); may be overwritten by the caller.
            self.subpolicies = self._get_subpolicies(policy)

        def _get_subpolicies(
            self,
            policy: AutoAugmentPolicy
        ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
            """Return the learned subpolicies for ``policy``.

            Each subpolicy is a pair of ``(op_name, probability, magnitude_id)`` tuples.

            Raises:
                ValueError: If ``policy`` is not a known :class:`AutoAugmentPolicy`.
            """
            if policy == AutoAugmentPolicy.IMAGENET:
                return [
                    (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
                    (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                    (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                    (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
                    (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                    (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
                    (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
                    (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
                    (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
                    (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
                    (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
                    (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
                    (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
                    (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                    (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                    (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
                    (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
                    (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
                    (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
                    (("Color", 0.4, 0), ("Equalize", 0.6, None)),
                    (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
                    (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
                    (("Invert", 0.6, None), ("Equalize", 1.0, None)),
                    (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
                    (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
                ]
            elif policy == AutoAugmentPolicy.CIFAR10:
                return [
                    (("Invert", 0.1, None), ("Contrast", 0.2, 6)),
                    (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
                    (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
                    (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
                    (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
                    (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
                    (("Color", 0.4, 3), ("Brightness", 0.6, 7)),
                    (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
                    (("Equalize", 0.6, None), ("Equalize", 0.5, None)),
                    (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
                    (("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
                    (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
                    (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
                    (("Brightness", 0.9, 6), ("Color", 0.2, 8)),
                    (("Solarize", 0.5, 2), ("Invert", 0.0, None)),
                    (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
                    (("Equalize", 0.2, None), ("Equalize", 0.6, None)),
                    (("Color", 0.9, 9), ("Equalize", 0.6, None)),
                    (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
                    (("Brightness", 0.1, 3), ("Color", 0.7, 0)),
                    (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
                    (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
                    (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
                    (("Equalize", 0.8, None), ("Invert", 0.1, None)),
                    (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
                ]
            elif policy == AutoAugmentPolicy.SVHN:
                return [
                    (("ShearX", 0.9, 4), ("Invert", 0.2, None)),
                    (("ShearY", 0.9, 8), ("Invert", 0.7, None)),
                    (("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
                    (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                    (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                    (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
                    (("ShearY", 0.9, 8), ("Invert", 0.4, None)),
                    (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
                    (("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
                    (("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
                    (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
                    (("ShearY", 0.8, 8), ("Invert", 0.7, None)),
                    (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
                    (("Invert", 0.9, None), ("Equalize", 0.6, None)),
                    (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
                    (("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
                    (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
                    (("Invert", 0.6, None), ("Rotate", 0.8, 4)),
                    (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
                    (("ShearX", 0.1, 6), ("Invert", 0.6, None)),
                    (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
                    (("ShearY", 0.8, 4), ("Invert", 0.8, None)),
                    (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
                    (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
                    (("ShearX", 0.7, 2), ("Invert", 0.1, None)),
                ]
            else:
                raise ValueError("The provided policy {} is not recognized.".format(policy))

        def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
            """Return ``op_name -> (magnitude bins, signed)`` for an image of ``image_size``.

            ``image_size`` is indexed as ``[width, height]`` for the translation ranges. Ops that
            take no magnitude map to a 0-d placeholder tensor.
            """
            return {
                # op_name: (magnitudes, signed)
                "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
                "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
                "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
                "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
                "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
                "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
                "Color": (torch.linspace(0.0, 0.9, num_bins), True),
                "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
                "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
                "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
                "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
                "AutoContrast": (torch.tensor(0.0), False),
                "Equalize": (torch.tensor(0.0), False),
                "Invert": (torch.tensor(0.0), False),
            }

        @staticmethod
        def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
            """Get parameters for autoaugment transformation

            Args:
                transform_num (int): Number of subpolicies to choose from.

            Returns:
                params required by the autoaugment transformation: the index of the chosen
                subpolicy, two uniform [0, 1) draws compared against each operation's
                probability, and two 0/1 draws deciding the sign of signed magnitudes.
            """
            policy_id = int(torch.randint(transform_num, (1,)).item())
            probs = torch.rand((2,))
            signs = torch.randint(2, (2,))

            return policy_id, probs, signs

        def forward(self, img: Tensor) -> Tensor:
            """Apply one randomly selected subpolicy to ``img``.

            A subpolicy is drawn uniformly from ``self.subpolicies``; each of its two operations
            is then applied with its own probability, and signed operations are negated with
            probability 1/2.

            Args:
                img (PIL Image or Tensor): Image to be transformed.

            Returns:
                PIL Image or Tensor: AutoAugmented image.
            """
            fill = self.fill
            if isinstance(img, Tensor):
                # Tensor inputs need one float fill value per channel.
                if isinstance(fill, (int, float)):
                    fill = [float(fill)] * F.get_image_num_channels(img)
                elif fill is not None:
                    fill = [float(f) for f in fill]

            transform_id, probs, signs = self.get_params(len(self.subpolicies))

            for i, (op_name, p, magnitude_id) in enumerate(self.subpolicies[transform_id]):
                if probs[i] <= p:
                    op_meta = self._augmentation_space(10, F.get_image_size(img))
                    magnitudes, signed = op_meta[op_name]
                    magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
                    if signed and signs[i] == 0:
                        magnitude *= -1.0
                    img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

            return img

        def __repr__(self) -> str:
            return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
    
    
    class RandAugment(torch.nn.Module):
        r"""RandAugment data augmentation method based on
        `"RandAugment: Practical automated data augmentation with a reduced search space"
        <https://arxiv.org/abs/1909.13719>`_.
        If the image is torch Tensor, it should be of type torch.uint8, and it is expected
        to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
        If img is PIL Image, it is expected to be in mode "L" or "RGB".

        Every call applies ``num_ops`` operations in sequence. Each one is drawn uniformly at
        random (independently, so repeats are possible) and uses the same magnitude bin;
        signed operations are negated with probability 1/2.

        Args:
            num_ops (int): Number of augmentation transformations to apply sequentially. Default is 2.
            magnitude (int): Index of the magnitude bin used for all the transformations; must be
                smaller than ``num_magnitude_bins``. Default is 9.
            num_magnitude_bins (int): The number of different magnitude values. Default is 31.
            interpolation (InterpolationMode): Desired interpolation enum defined by
                :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
                If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            fill (sequence or number, optional): Pixel fill value for the area outside the transformed
                image. If given a number, the value is used for all bands respectively.
        """

        def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31,
                     interpolation: InterpolationMode = InterpolationMode.NEAREST,
                     fill: Optional[List[float]] = None) -> None:
            super().__init__()
            # Plain attributes; __repr__ reports all of them.
            self.num_ops = num_ops
            self.magnitude = magnitude
            self.num_magnitude_bins = num_magnitude_bins
            self.interpolation = interpolation
            self.fill = fill

        def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
            """Build the table ``op_name -> (magnitude bins, signed)`` for an image of ``image_size``.

            Key insertion order matters: ``forward`` selects an op by its position in the table.
            Ops that ignore their magnitude get a 0-d placeholder tensor.
            """
            max_shift_x = 150.0 / 331.0 * image_size[0]
            max_shift_y = 150.0 / 331.0 * image_size[1]
            posterize_bits = 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int()

            space: Dict[str, Tuple[Tensor, bool]] = {}
            space["Identity"] = (torch.tensor(0.0), False)
            for shear_op in ("ShearX", "ShearY"):
                space[shear_op] = (torch.linspace(0.0, 0.3, num_bins), True)
            space["TranslateX"] = (torch.linspace(0.0, max_shift_x, num_bins), True)
            space["TranslateY"] = (torch.linspace(0.0, max_shift_y, num_bins), True)
            space["Rotate"] = (torch.linspace(0.0, 30.0, num_bins), True)
            for enhance_op in ("Brightness", "Color", "Contrast", "Sharpness"):
                space[enhance_op] = (torch.linspace(0.0, 0.9, num_bins), True)
            space["Posterize"] = (posterize_bits, False)
            space["Solarize"] = (torch.linspace(255.0, 0.0, num_bins), False)
            for plain_op in ("AutoContrast", "Equalize"):
                space[plain_op] = (torch.tensor(0.0), False)
            return space

        def forward(self, img: Tensor) -> Tensor:
            """Apply ``num_ops`` randomly chosen operations to ``img``.

            Args:
                img (PIL Image or Tensor): Image to be transformed.

            Returns:
                PIL Image or Tensor: Transformed image.
            """
            fill = self.fill
            if isinstance(img, Tensor) and fill is not None:
                # Tensor inputs need one float fill value per channel.
                if isinstance(fill, (int, float)):
                    fill = [float(fill)] * F.get_image_num_channels(img)
                else:
                    fill = [float(value) for value in fill]

            for _ in range(self.num_ops):
                space = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
                names = list(space.keys())
                chosen = names[int(torch.randint(len(space), (1,)).item())]
                bins, signed = space[chosen]

                magnitude = 0.0
                if bins.ndim > 0:
                    magnitude = float(bins[self.magnitude].item())
                # Short-circuit keeps the sign draw (and RNG consumption) to signed ops only.
                if signed and torch.randint(2, (1,)):
                    magnitude *= -1.0

                img = _apply_op(img, chosen, magnitude, interpolation=self.interpolation, fill=fill)

            return img

        def __repr__(self) -> str:
            return (
                f"{self.__class__.__name__}("
                f"num_ops={self.num_ops}"
                f", magnitude={self.magnitude}"
                f", num_magnitude_bins={self.num_magnitude_bins}"
                f", interpolation={self.interpolation}"
                f", fill={self.fill}"
                ")"
            )
    
    
    class TrivialAugmentWide(torch.nn.Module):
        r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
        `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`_.
        If the image is torch Tensor, it should be of type torch.uint8, and it is expected
        to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
        If img is PIL Image, it is expected to be in mode "L" or "RGB".

        Every call applies exactly one operation, drawn uniformly at random, at a magnitude bin
        that is also drawn uniformly at random; signed operations are negated with probability 1/2.

        Args:
            num_magnitude_bins (int): The number of different magnitude values. Default is 31.
            interpolation (InterpolationMode): Desired interpolation enum defined by
                :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
                If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
            fill (sequence or number, optional): Pixel fill value for the area outside the transformed
                image. If given a number, the value is used for all bands respectively.
        """

        def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST,
                     fill: Optional[List[float]] = None) -> None:
            super().__init__()
            # Plain attributes; __repr__ reports all of them.
            self.num_magnitude_bins = num_magnitude_bins
            self.interpolation = interpolation
            self.fill = fill

        def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
            """Build the "wide" table ``op_name -> (magnitude bins, signed)``.

            Key insertion order matters: ``forward`` selects an op by its position in the table.
            Ops that ignore their magnitude get a 0-d placeholder tensor.
            """
            space: Dict[str, Tuple[Tensor, bool]] = {"Identity": (torch.tensor(0.0), False)}
            for shear_op in ("ShearX", "ShearY"):
                space[shear_op] = (torch.linspace(0.0, 0.99, num_bins), True)
            for shift_op in ("TranslateX", "TranslateY"):
                space[shift_op] = (torch.linspace(0.0, 32.0, num_bins), True)
            space["Rotate"] = (torch.linspace(0.0, 135.0, num_bins), True)
            for enhance_op in ("Brightness", "Color", "Contrast", "Sharpness"):
                space[enhance_op] = (torch.linspace(0.0, 0.99, num_bins), True)
            space["Posterize"] = (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False)
            space["Solarize"] = (torch.linspace(255.0, 0.0, num_bins), False)
            for plain_op in ("AutoContrast", "Equalize"):
                space[plain_op] = (torch.tensor(0.0), False)
            return space

        def forward(self, img: Tensor) -> Tensor:
            """Apply one randomly chosen operation at a random magnitude to ``img``.

            Args:
                img (PIL Image or Tensor): Image to be transformed.

            Returns:
                PIL Image or Tensor: Transformed image.
            """
            fill = self.fill
            if isinstance(img, Tensor) and fill is not None:
                # Tensor inputs need one float fill value per channel.
                if isinstance(fill, (int, float)):
                    fill = [float(fill)] * F.get_image_num_channels(img)
                else:
                    fill = [float(value) for value in fill]

            space = self._augmentation_space(self.num_magnitude_bins)
            op_name = list(space)[int(torch.randint(len(space), (1,)).item())]
            bins, signed = space[op_name]

            magnitude = 0.0
            if bins.ndim > 0:
                # The magnitude bin is re-drawn uniformly on every call.
                bin_index = torch.randint(len(bins), (1,), dtype=torch.long)
                magnitude = float(bins[bin_index].item())
            # Short-circuit keeps the sign draw (and RNG consumption) to signed ops only.
            if signed and torch.randint(2, (1,)):
                magnitude *= -1.0

            return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

        def __repr__(self) -> str:
            return (
                f"{self.__class__.__name__}("
                f"num_magnitude_bins={self.num_magnitude_bins}"
                f", interpolation={self.interpolation}"
                f", fill={self.fill}"
                ")"
            )
    
        import matplotlib.pyplot as plt
    
        from MetaAugment.main import *
        import MetaAugment.child_networks as cn
        import torchvision.transforms as transforms
    
    
        # If you get rid of this nextimport, the whole thing doesn't work... By the way this import also 
        # exists on the top of this document.
        # I think this is because "import torchvision.transforms as transforms" overrides the import at 
        # the top of this file and does some funny stuff... Anyways we need to call this import again to get
        # rid of the bug.
    
        from torchvision.transforms import functional as F, InterpolationMode
    
                (("Invert", 0.8, None), ("Contrast", 0.2, 6)),
                (("Rotate", 0.7, 2), ("Invert", 0.8, None)),
                (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
    
                (("AutoContrast", 0.5, None), ("Equalize", 0.9, None))
                ]
    
    
        # The one that i hand crafted. You'll see that this one usually reaches a much
        # higher poerformance
    
                (("ShearY", 0.8, 4), ("Rotate", 0.5, 6)),
                (("TranslateY", 0.7, 4), ("TranslateX", 0.8, 6)),
                (("Rotate", 0.5, 3), ("ShearY", 0.8, 5)),
                (("ShearX", 0.5, 6), ("TranslateY", 0.7, 3)),
                (("Rotate", 0.5, 3), ("TranslateX", 0.5, 5))
    
    
        train_dataset = datasets.MNIST(root='./datasets/mnist/train', train=True, download=False, 
                                    transform=None)
        test_dataset = datasets.MNIST(root='./datasets/mnist/test', train=False, download=False,
                                    transform=torchvision.transforms.ToTensor())
    
    
    
        def test_autoaugment_policy(subpolicies, train_dataset, test_dataset):
    
            aa_transform = AutoAugment()
    
            aa_transform.subpolicies = subpolicies
    
            train_transform = transforms.Compose([
    
            train_dataset.transform = train_transform
    
    
            # create toy dataset from above uploaded data
    
            train_loader, test_loader = create_toy(train_dataset, test_dataset, batch_size=32, n_samples=0.1)
    
    
            child_network = cn.lenet()
            sgd = optim.SGD(child_network.parameters(), lr=1e-1)
    
            cost = nn.CrossEntropyLoss()
    
            best_acc, acc_log = train_child_network(child_network, train_loader, test_loader,
                                                        sgd, cost, max_epochs=100, logging=True)
    
            return best_acc, acc_log
    
    
    
        _, acc_log1 = test_autoaugment_policy(subpolicies1, train_dataset, test_dataset)
        _, acc_log2 = test_autoaugment_policy(subpolicies2, train_dataset, test_dataset)
    
    
        plt.plot(acc_log1, label='subpolicies1')
        plt.plot(acc_log2, label='subpolicies2')
        plt.xlabel('epochs')
        plt.ylabel('accuracy')
        plt.legend()
        plt.show()