self.assertEqual(torch.hub.get_dir(), dirname)
@retry(URLError, tries=3)
+ def test_hub_parse_repo_info(self):
+ # If the branch is specified we just parse the input and return
+ self.assertEqual(
+ torch.hub._parse_repo_info('a/b:c'),
+ ('a', 'b', 'c')
+ )
+ # For torchvision, the default branch is main
+ self.assertEqual(
+ torch.hub._parse_repo_info('pytorch/vision'),
+ ('pytorch', 'vision', 'main')
+ )
+ # For the torchhub_example repo, the default branch is still master
+ self.assertEqual(
+ torch.hub._parse_repo_info('ailzhang/torchhub_example'),
+ ('ailzhang', 'torchhub_example', 'master')
+ )
+
+ @retry(URLError, tries=3)
def test_load_state_dict_from_url_with_name(self):
with tempfile.TemporaryDirectory('hub_dir') as dirname:
torch.hub.set_dir(dirname)
import warnings
import zipfile
+from urllib.error import HTTPError
from urllib.request import urlopen, Request
from urllib.parse import urlparse # noqa: F401
# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
-MASTER_BRANCH = 'master'
ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
ENV_TORCH_HOME = 'TORCH_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
def _parse_repo_info(github):
- branch = MASTER_BRANCH
if ':' in github:
repo_info, branch = github.split(':')
else:
- repo_info = github
+ repo_info, branch = github, None
repo_owner, repo_name = repo_info.split('/')
+
+ if branch is None:
+ # The branch wasn't specified by the user, so we need to figure out the
+ # default branch: main or master. Our assumption is that if main exists
+ # then it's the default branch, otherwise it's master.
+ try:
+ with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
+ branch = 'main'
+ except HTTPError as e:
+ if e.code == 404:
+ branch = 'master'
+ else:
+ raise
return repo_owner, repo_name, branch
Args:
github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
- tag/branch. The default branch is ``master`` if not specified.
- Example: 'pytorch/vision[:hub]'
+ tag/branch. If ``tag_name`` is not specified, the default branch is assumed to be ``main`` if
+ it exists, and otherwise ``master``.
+ Example: 'pytorch/vision:0.10'
force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
Default is ``False``.
skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit
Args:
github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional
- tag/branch. The default branch is ``master`` if not specified.
- Example: 'pytorch/vision[:hub]'
+ tag/branch. If ``tag_name`` is not specified, the default branch is assumed to be ``main`` if
+ it exists, and otherwise ``master``.
+ Example: 'pytorch/vision:0.10'
model (string): a string of entrypoint name defined in repo's ``hubconf.py``
force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
Default is ``False``.
If ``source`` is 'github', ``repo_or_dir`` is expected to be
of the form ``repo_owner/repo_name[:tag_name]`` with an optional
- tag/branch. The default branch is ``master`` if not specified.
+ tag/branch.
If ``source`` is 'local', ``repo_or_dir`` is expected to be a
path to a local directory.
Args:
- repo_or_dir (string): repo name (``repo_owner/repo_name[:tag_name]``),
- if ``source = 'github'``; or a path to a local directory, if
- ``source = 'local'``.
+ repo_or_dir (string): If ``source`` is 'github',
+ this should correspond to a github repo with format ``repo_owner/repo_name[:tag_name]`` with
+ an optional tag/branch, for example 'pytorch/vision:0.10'. If ``tag_name`` is not specified,
+ the default branch is assumed to be ``main`` if it exists, and otherwise ``master``.
+ If ``source`` is 'local' then it should be a path to a local directory.
model (string): the name of a callable (entrypoint) defined in the
repo/dir's ``hubconf.py``.
*args (optional): the corresponding args for callable ``model``.