Improve hub documentation (#14862)
authorAiling Zhang <ailzhang@fb.com>
Fri, 7 Dec 2018 22:56:56 +0000 (14:56 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 7 Dec 2018 22:59:01 +0000 (14:59 -0800)
Summary:
Added a few examples and explains to how publish/load models.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14862

Differential Revision: D13384790

Pulled By: ailzhang

fbshipit-source-id: 008166e84e59dcb62c0be38a87982579524fb20e

docs/source/hub.rst
torch/hub.py

index 2966d0d..fd252b0 100644 (file)
@@ -1,6 +1,101 @@
 torch.hub
 ===================================
+Pytorch Hub is a pre-trained model repository designed to facilitate research reproducibility.
+
+Publishing models
+-----------------
+
+Pytorch Hub supports publishing pre-trained models(model definitions and pre-trained weights)
+to a github repository by adding a simple ``hubconf.py`` file;
+
+``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a python function with
+the following signature.
+
+::
+
+    def entrypoint_name(pretrained=False, *args, **kwargs):
+        ...
+
+How to implement an entrypoint?
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Here is a code snipet from pytorch/vision repository, which specifies an entrypoint
+for ``resnet18`` model. You can see a full script in
+`pytorch/vision repo <https://github.com/pytorch/vision/blob/master/hubconf.py>`_
+
+::
+
+    dependencies = ['torch', 'math']
+
+    def resnet18(pretrained=False, *args, **kwargs):
+        """
+        Resnet18 model
+        pretrained (bool): a recommended kwargs for all entrypoints
+        args & kwargs are arguments for the function
+        """
+        ######## Call the model in the repo ###############
+        from torchvision.models.resnet import resnet18 as _resnet18
+        model = _resnet18(*args, **kwargs)
+        ######## End of call ##############################
+        # The following logic is REQUIRED
+        if pretrained:
+            # For weights saved in local repo
+                       # model.load_state_dict(<path_to_saved_file>)
+
+                       # For weights saved elsewhere
+                       checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
+            model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
+        return model
+
+- ``dependencies`` variable is a **list** of package names required to to run the model.
+- Pretrained weights can either be stored local in the github repo, or loadable by
+  ``model_zoo.load()``.
+- ``pretrained`` controls whether to load the pre-trained weights provided by repo owners.
+- ``args`` and ``kwargs`` are passed along to the real callable function.
+- Docstring of the function works as a help message, explaining what does the model do and what
+  are the allowed arguments.
+- Entrypoint function should **ALWAYS** return a model(nn.module).
+
+Important Notice
+^^^^^^^^^^^^^^^^
+
+- The published models should be at least in a branch/tag. It can't be a random commit.
+
+Loading models from Hub
+-----------------------
+
+Users can load the pre-trained models using ``torch.hub.load()`` API.
+
 
 .. automodule:: torch.hub
 .. autofunction:: load
+
+Here's an example loading ``resnet18`` entrypoint from ``pytorch/vision`` repo.
+
+::
+
+    hub_model = hub.load(
+        'pytorch/vision:master', # repo_owner/repo_name:branch
+        'resnet18', # entrypoint
+        1234, # args for callable [not applicable to resnet]
+        pretrained=True) # kwargs for callable
+
+Where are my downloaded model & weights saved?
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The locations are used in the order of
+
+- hub_dir: user specified path. It can be set in the following ways:
+  - Setting the environment variable ``TORCH_HUB_DIR``
+  - Calling ``hub.set_dir(<PATH_TO_HUB_DIR>)``
+- ``~/.torch/hub``
+
 .. autofunction:: set_dir
+
+Caching logic
+^^^^^^^^^^^^^
+
+By default, we don't clean up files after loading it. Hub uses the cache by default if it already exists in ``hub_dir``.
+
+Users can force a reload by calling ``hub.load(..., force_reload=True)``. This will delete
+the existing github folder and downloaded weights, reinitialize a fresh download. This is useful
+when updates are published to the same branch, users can keep up with the latest release.
index 2ddfd1c..10e3db3 100644 (file)
@@ -73,9 +73,13 @@ def _load_attr_from_module(module_name, func_name):
 
 def set_dir(d):
     r"""
-    Optionally set hub_dir to a local dir to save the intermediate model & checkpoint files.
-        If this argument is not set, env variable `TORCH_HUB_DIR` will be searched first,
-        `~/.torch/hub` will be created and used as fallback.
+    Optionally set hub_dir to a local dir to save downloaded models & weights.
+
+    If this argument is not set, env variable `TORCH_HUB_DIR` will be searched first,
+    `~/.torch/hub` will be created and used as fallback.
+
+    Args:
+        d: path to a local folder to save downloaded models & weights.
     """
     global hub_dir
     hub_dir = d
@@ -89,7 +93,7 @@ def load(github, model, force_reload=False, *args, **kwargs):
         github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
             tag/branch. The default branch is `master` if not specified.
             Example: 'pytorch/vision[:hub]'
-        model: Required, a string of callable name defined in repo's hubconf.py
+        model: Required, a string of entrypoint name defined in repo's hubconf.py
         force_reload: Optional, whether to discard the existing cache and force a fresh download.
             Default is `False`.
         *args: Optional, the corresponding args for callable `model`.