Torchhub: More robust assumption regarding main or master branch (#64364)
authorNicolas Hug <nicolashug@fb.com>
Mon, 20 Sep 2021 17:27:12 +0000 (10:27 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 20 Sep 2021 17:36:13 +0000 (10:36 -0700)
Summary:
Closes https://github.com/pytorch/pytorch/issues/63753

This PR changes the assumption regarding the default branch of a repo to the following:

> If main exist then use main,otherwise use master

This will make torchhub more robust w.r.t. to the ongoing changes where repo use `main` instead of `master` as the development / default branch.

cc nairbv NicolasHug

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64364

Reviewed By: saketh-are

Differential Revision: D30731551

Pulled By: NicolasHug

fbshipit-source-id: 7232a30e956dcccca21933a29de5eddd711aa99b

test/test_utils.py
torch/hub.py

index 34f4406..b1b7cc0 100644 (file)
@@ -689,6 +689,24 @@ class TestHub(TestCase):
             self.assertEqual(torch.hub.get_dir(), dirname)
 
     @retry(URLError, tries=3)
+    def test_hub_parse_repo_info(self):
+        # If the branch is specified we just parse the input and return
+        self.assertEqual(
+            torch.hub._parse_repo_info('a/b:c'),
+            ('a', 'b', 'c')
+        )
+        # For torchvision, the default branch is main
+        self.assertEqual(
+            torch.hub._parse_repo_info('pytorch/vision'),
+            ('pytorch', 'vision', 'main')
+        )
+        # For the torchhub_example repo, the default branch is still master
+        self.assertEqual(
+            torch.hub._parse_repo_info('ailzhang/torchhub_example'),
+            ('ailzhang', 'torchhub_example', 'master')
+        )
+
+    @retry(URLError, tries=3)
     def test_load_state_dict_from_url_with_name(self):
         with tempfile.TemporaryDirectory('hub_dir') as dirname:
             torch.hub.set_dir(dirname)
index 82287d8..bfc24dc 100644 (file)
@@ -10,6 +10,7 @@ import torch
 import warnings
 import zipfile
 
+from urllib.error import HTTPError
 from urllib.request import urlopen, Request
 from urllib.parse import urlparse  # noqa: F401
 
@@ -55,7 +56,6 @@ except ImportError:
 # matches bfd8deac from resnet18-bfd8deac.pth
 HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
 
-MASTER_BRANCH = 'master'
 ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
 ENV_TORCH_HOME = 'TORCH_HOME'
 ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
@@ -105,12 +105,24 @@ def _get_torch_home():
 
 
 def _parse_repo_info(github):
-    branch = MASTER_BRANCH
     if ':' in github:
         repo_info, branch = github.split(':')
     else:
-        repo_info = github
+        repo_info, branch = github, None
     repo_owner, repo_name = repo_info.split('/')
+
+    if branch is None:
+        # The branch wasn't specified by the user, so we need to figure out the
+        # default branch: main or master. Our assumption is that if main exists
+        # then it's the default branch, otherwise it's master.
+        try:
+            with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
+                branch = 'main'
+        except HTTPError as e:
+            if e.code == 404:
+                branch = 'master'
+            else:
+                raise
     return repo_owner, repo_name, branch
 
 
@@ -261,8 +273,9 @@ def list(github, force_reload=False, skip_validation=False):
 
     Args:
         github (string): a string with format "repo_owner/repo_name[:tag_name]" with an optional
-            tag/branch. The default branch is ``master`` if not specified.
-            Example: 'pytorch/vision[:hub]'
+            tag/branch. If ``tag_name`` is not specified, the default branch is assumed to be ``main`` if
+            it exists, and otherwise ``master``.
+            Example: 'pytorch/vision:0.10'
         force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
             Default is ``False``.
         skip_validation (bool, optional): if ``False``, torchhub will check that the branch or commit
@@ -296,8 +309,9 @@ def help(github, model, force_reload=False, skip_validation=False):
 
     Args:
         github (string): a string with format <repo_owner/repo_name[:tag_name]> with an optional
-            tag/branch. The default branch is ``master`` if not specified.
-            Example: 'pytorch/vision[:hub]'
+            tag/branch. If ``tag_name`` is not specified, the default branch is assumed to be ``main`` if
+            it exists, and otherwise ``master``.
+            Example: 'pytorch/vision:0.10'
         model (string): a string of entrypoint name defined in repo's ``hubconf.py``
         force_reload (bool, optional): whether to discard the existing cache and force a fresh download.
             Default is ``False``.
@@ -332,15 +346,17 @@ def load(repo_or_dir, model, *args, source='github', force_reload=False, verbose
 
     If ``source`` is 'github', ``repo_or_dir`` is expected to be
     of the form ``repo_owner/repo_name[:tag_name]`` with an optional
-    tag/branch. The default branch is ``master`` if not specified.
+    tag/branch.
 
     If ``source`` is 'local', ``repo_or_dir`` is expected to be a
     path to a local directory.
 
     Args:
-        repo_or_dir (string): repo name (``repo_owner/repo_name[:tag_name]``),
-            if ``source = 'github'``; or a path to a local directory, if
-            ``source = 'local'``.
+        repo_or_dir (string): If ``source`` is 'github',
+            this should correspond to a github repo with format ``repo_owner/repo_name[:tag_name]`` with
+            an optional tag/branch, for example 'pytorch/vision:0.10'. If ``tag_name`` is not specified,
+            the default branch is assumed to be ``main`` if it exists, and otherwise ``master``.
+            If ``source`` is 'local'  then it should be a path to a local directory.
         model (string): the name of a callable (entrypoint) defined in the
             repo/dir's ``hubconf.py``.
         *args (optional): the corresponding args for callable ``model``.