Skip to content
Snippets Groups Projects
Commit 98e34909 authored by Sun Jin Kim's avatar Sun Jin Kim
Browse files

Add documentation and tutorial for how-to for aa_learner

parent 408964a7
No related branches found
No related tags found
No related merge requests found
Showing
with 244 additions and 29 deletions
File moved
File moved
File moved
# The parent class for all other autoaugment learners
import torch
import torch.nn as nn
......@@ -34,6 +34,9 @@ augmentation_space = [
class aa_learner:
"""
The parent class for all aa_learner's
"""
def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False):
"""
Args:
......
......@@ -31,7 +31,11 @@ augmentation_space = [
]
class randomsearch_learner(aa_learner):
def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=False):
"""
Tests randomly sampled policies from the search space specified by the AutoAugment
paper. Acts as a baseline for other aa_learner's.
"""
def __init__(self, sp_num=5, fun_num=14, p_bins=11, m_bins=10, discrete_p_m=True):
super().__init__(sp_num, fun_num, p_bins, m_bins, discrete_p_m)
......
AutoAugment learners
--------------------
aa_learner: the parent class
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autosummary::
:toctree: generated
MetaAugment.autoaugment_learners.aa_learner
gru_learner: (almost the same as) the agent used in AutoAugment paper
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autosummary::
:toctree: generated
MetaAugment.autoaugment_learners.gru_learner
randomsearch_learner: the hard baseline to beat
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autosummary::
:toctree: generated
MetaAugment.autoaugment_learners.randomsearch_learner
\ No newline at end of file
MetaAugment.autoaugment\_learners.aa\_learner
=============================================
.. currentmodule:: MetaAugment.autoaugment_learners
.. autoclass:: aa_learner
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~aa_learner.__init__
~aa_learner.demo_plot
~aa_learner.generate_new_policy
~aa_learner.learn
~aa_learner.test_autoaugment_policy
~aa_learner.translate_operation_tensor
\ No newline at end of file
MetaAugment.autoaugment\_learners.gru\_learner
==============================================
.. currentmodule:: MetaAugment.autoaugment_learners
.. autoclass:: gru_learner
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~gru_learner.__init__
~gru_learner.demo_plot
~gru_learner.generate_new_policy
~gru_learner.learn
~gru_learner.test_autoaugment_policy
~gru_learner.translate_operation_tensor
\ No newline at end of file
MetaAugment.autoaugment\_learners.randomsearch\_learner
=======================================================
.. currentmodule:: MetaAugment.autoaugment_learners
.. autoclass:: randomsearch_learner
.. automethod:: __init__
.. rubric:: Methods
.. autosummary::
~randomsearch_learner.__init__
~randomsearch_learner.demo_plot
~randomsearch_learner.generate_new_continuous_operation
~randomsearch_learner.generate_new_discrete_operation
~randomsearch_learner.generate_new_policy
~randomsearch_learner.learn
~randomsearch_learner.test_autoaugment_policy
~randomsearch_learner.translate_operation_tensor
\ No newline at end of file
......@@ -20,9 +20,9 @@ sys.path.insert(0, os.path.abspath('../..'))
# -- Project information -----------------------------------------------------
project = 'metarl'
copyright = '2022, metarl_team'
author = 'metarl_team'
project = 'MetaAugment'
copyright = '2022, metaaug_team'
author = 'metaaug_team'
# The full version, including alpha/beta/rc tags
release = '0.0'
......@@ -42,7 +42,7 @@ extensions = [
]
# turn on sphinx.ext.autosummary
autosummary_generate = False
autosummary_generate = True
# turn on sphinx.ext.coverage
coverage_show_missing_items = True
......@@ -61,7 +61,7 @@ exclude_patterns = []
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'alabaster'
html_theme = 'furo'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
......
Welcome to metarl's documentation!
MetaAugment Documentation
==================================
.. toctree::
:maxdepth: 3
:caption: Contents:
usage/installation
.. autoclass:: MetaAugment.autoaugment_learners.aa_learner.aa_learner
.. autoclass:: MetaAugment.autoaugment_learners.randomsearch_learner.randomsearch_learner
.. autoclass:: MetaAugment.autoaugment_learners.gru_learner.gru_learner
.. automodule:: MetaAugment.controller_networks
:members:
.. automodule:: MetaAugment.child_networks
:members:
.. toctree::
:maxdepth: 3
:caption: Table of Contents Tree:
usage/tutorial_for_team
MetaAugment_library/autoaugment_learners/aa_learners
..
I've commented this out for now
Indices and tables
==================
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
explain how to install MetaAugment here
\ No newline at end of file
How-to guides
---------------------------------
This is a page dedicated to demonstrating functionalities of :class:`aa_learner`.
It is a how-to guide. (Using the terminology of https://documentation.divio.com/structure/)
###################################################
Using an AutoAutgment learner to find a good policy
###################################################
This section can also be read as a ``.py`` file in ``./tutorials/how_use_aalearner.py``.
.. code-block::
import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
import torchvision.datasets as datasets
import torchvision
Defining the problem setting:
.. code-block::
train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train',
train=True, download=True, transform=None)
test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test',
train=False, download=True, transform=torchvision.transforms.ToTensor())
child_network = cn.lenet
.. note::
It is important not to type
.. code-block::
child_network = cn.lenet()
We need the ``child_network`` variable to be a ``type``, not a ``nn.Module``
because the ``child_network`` will be called multiple times to initialize a
``nn.Module`` of its architecture multiple times: once every time we need to
train a different network to evaluate a different policy.
Using the random search learner to evaluate randomly generated policies:
.. code-block::
rs_agent = aal.randomsearch_learner()
rs_agent.learn(train_dataset, test_dataset, child_network, toy_flag=True)
Viewing the results:
``.history`` is a list containing all the policies tested and the respective
accuracies obtained when trained using them.
.. code-block::
print(rs_agent.history)
\ No newline at end of file
# You can run this in the main directory by typing:
# python -m tutorials.how_use_aalearner
import MetaAugment.autoaugment_learners as aal
import MetaAugment.child_networks as cn
import torchvision.datasets as datasets
import torchvision
# Defining our problem setting:
# In other words, specifying the dataset and the child network
train_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/train',
train=True, download=True, transform=None)
test_dataset = datasets.MNIST(root='./MetaAugment/datasets/mnist/test',
train=False, download=True, transform=torchvision.transforms.ToTensor())
child_network = cn.lenet
# NOTE: It is important not to type:
# child_network = cn.lenet()
# We need the ``child_network`` variable to be a ``type``, not a ``nn.Module``
# because the ``child_network`` will be called multiple times to initialize a
# ``nn.Module`` of its architecture multiple times: once every time we need to
# train a different network to evaluate a different policy.
# Using the random search learner to evaluate randomly generated policies
rs_agent = aal.randomsearch_learner()
rs_agent.learn(train_dataset, test_dataset, child_network, toy_flag=True)
# Viewing the results
# ``.history`` is a list containing all the policies tested and the respective
# accuracies obtained when trained using them
print(rs_agent.history)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment