diff --git a/autoaug/autoaugment_learners/AaLearner.py b/autoaug/autoaugment_learners/AaLearner.py index 638b92237e463f035792c3157b08a9100d9fa02c..c700bc5ec6d71a4bdb76364f540d1d26a4726dd8 100644 --- a/autoaug/autoaugment_learners/AaLearner.py +++ b/autoaug/autoaugment_learners/AaLearner.py @@ -291,6 +291,7 @@ class AaLearner: Does the loop which is seen in Figure 1 in the AutoAugment paper which is: + 1. <generate a random policy> 2. <see how good that policy is> 3. <save how good the policy is in a list/dictionary and @@ -310,15 +311,11 @@ class AaLearner: Args: train_dataset (torchvision.dataset.vision.VisionDataset): - test_dataset (torchvision.dataset.vision.VisionDataset): - child_network_architecture (Union[function, nn.Module]): - This can be both, for example, - ``LeNet`` - - and + test_dataset (torchvision.dataset.vision.VisionDataset): - ``LeNet()`` + child_network_architecture (Union[function, nn.Module]): + This can be both, for example, ``LeNet`` or ``LeNet()`` iterations (int): how many different policies do you want to test