From 5734e9677564743fc4000cfb955fb42046689be9 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Fri, 7 Dec 2018 14:56:56 -0800 Subject: [PATCH] Improve hub documentation (#14862) Summary: Added a few examples and explains to how publish/load models. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14862 Differential Revision: D13384790 Pulled By: ailzhang fbshipit-source-id: 008166e84e59dcb62c0be38a87982579524fb20e --- docs/source/hub.rst | 95 +++++++++++++++++++++++++++++++++++++++++++++++++++++ torch/hub.py | 12 ++++--- 2 files changed, 103 insertions(+), 4 deletions(-) diff --git a/docs/source/hub.rst b/docs/source/hub.rst index 2966d0d..fd252b0 100644 --- a/docs/source/hub.rst +++ b/docs/source/hub.rst @@ -1,6 +1,101 @@ torch.hub =================================== +Pytorch Hub is a pre-trained model repository designed to facilitate research reproducibility. + +Publishing models +----------------- + +Pytorch Hub supports publishing pre-trained models(model definitions and pre-trained weights) +to a github repository by adding a simple ``hubconf.py`` file; + +``hubconf.py`` can have multiple entrypoints. Each entrypoint is defined as a python function with +the following signature. + +:: + + def entrypoint_name(pretrained=False, *args, **kwargs): + ... + +How to implement an entrypoint? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Here is a code snipet from pytorch/vision repository, which specifies an entrypoint +for ``resnet18`` model. You can see a full script in +`pytorch/vision repo `_ + +:: + + dependencies = ['torch', 'math'] + + def resnet18(pretrained=False, *args, **kwargs): + """ + Resnet18 model + pretrained (bool): a recommended kwargs for all entrypoints + args & kwargs are arguments for the function + """ + ######## Call the model in the repo ############### + from torchvision.models.resnet import resnet18 as _resnet18 + model = _resnet18(*args, **kwargs) + ######## End of call ############################## + # The following logic is REQUIRED + if pretrained: + # For weights saved in local repo + # model.load_state_dict() + + # For weights saved elsewhere + checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth' + model.load_state_dict(model_zoo.load_url(checkpoint, progress=False)) + return model + +- ``dependencies`` variable is a **list** of package names required to to run the model. +- Pretrained weights can either be stored local in the github repo, or loadable by + ``model_zoo.load()``. +- ``pretrained`` controls whether to load the pre-trained weights provided by repo owners. +- ``args`` and ``kwargs`` are passed along to the real callable function. +- Docstring of the function works as a help message, explaining what does the model do and what + are the allowed arguments. +- Entrypoint function should **ALWAYS** return a model(nn.module). + +Important Notice +^^^^^^^^^^^^^^^^ + +- The published models should be at least in a branch/tag. It can't be a random commit. + +Loading models from Hub +----------------------- + +Users can load the pre-trained models using ``torch.hub.load()`` API. + .. automodule:: torch.hub .. autofunction:: load + +Here's an example loading ``resnet18`` entrypoint from ``pytorch/vision`` repo. + +:: + + hub_model = hub.load( + 'pytorch/vision:master', # repo_owner/repo_name:branch + 'resnet18', # entrypoint + 1234, # args for callable [not applicable to resnet] + pretrained=True) # kwargs for callable + +Where are my downloaded model & weights saved? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The locations are used in the order of + +- hub_dir: user specified path. It can be set in the following ways: + - Setting the environment variable ``TORCH_HUB_DIR`` + - Calling ``hub.set_dir()`` +- ``~/.torch/hub`` + .. autofunction:: set_dir + +Caching logic +^^^^^^^^^^^^^ + +By default, we don't clean up files after loading it. Hub uses the cache by default if it already exists in ``hub_dir``. + +Users can force a reload by calling ``hub.load(..., force_reload=True)``. This will delete +the existing github folder and downloaded weights, reinitialize a fresh download. This is useful +when updates are published to the same branch, users can keep up with the latest release. diff --git a/torch/hub.py b/torch/hub.py index 2ddfd1c..10e3db3 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -73,9 +73,13 @@ def _load_attr_from_module(module_name, func_name): def set_dir(d): r""" - Optionally set hub_dir to a local dir to save the intermediate model & checkpoint files. - If this argument is not set, env variable `TORCH_HUB_DIR` will be searched first, - `~/.torch/hub` will be created and used as fallback. + Optionally set hub_dir to a local dir to save downloaded models & weights. + + If this argument is not set, env variable `TORCH_HUB_DIR` will be searched first, + `~/.torch/hub` will be created and used as fallback. + + Args: + d: path to a local folder to save downloaded models & weights. """ global hub_dir hub_dir = d @@ -89,7 +93,7 @@ def load(github, model, force_reload=False, *args, **kwargs): github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional tag/branch. The default branch is `master` if not specified. Example: 'pytorch/vision[:hub]' - model: Required, a string of callable name defined in repo's hubconf.py + model: Required, a string of entrypoint name defined in repo's hubconf.py force_reload: Optional, whether to discard the existing cache and force a fresh download. Default is `False`. *args: Optional, the corresponding args for callable `model`. -- 2.7.4