Add return type hint and improve the docstring of consume_prefix_in_state_dict_if_pre...
authorYi Wang <wayi@fb.com>
Tue, 17 Aug 2021 18:28:43 +0000 (11:28 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 17 Aug 2021 18:30:27 +0000 (11:30 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63388

Context: https://discuss.pytorch.org/t/how-to-use-the-helper-function-consume-prefix-in-state-dict-if-present/129505/3

Make it clear that this method strips the prefix in place rather than returns a new value.

Additional reformatting is also applied.
ghstack-source-id: 135973393

Test Plan: waitforbuildbot

Reviewed By: rohan-varma

Differential Revision: D30360931

fbshipit-source-id: 1a0c7967a4c86f729e3c810686c21dec43d1dd7a

torch/nn/modules/utils.py

index 01e92aa..b164d78 100644 (file)
@@ -1,4 +1,3 @@
-
 import collections
 from itertools import repeat
 from typing import List, Dict, Any
@@ -9,8 +8,10 @@ def _ntuple(n):
         if isinstance(x, collections.abc.Iterable):
             return tuple(x)
         return tuple(repeat(x, n))
+
     return parse
 
+
 _single = _ntuple(1)
 _pair = _ntuple(2)
 _triple = _ntuple(3)
@@ -30,12 +31,18 @@ def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]:
     if isinstance(out_size, int):
         return out_size
     if len(defaults) <= len(out_size):
-        raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1))
-    return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])]
+        raise ValueError(
+            "Input dimension should be at least {}".format(len(out_size) + 1)
+        )
+    return [
+        v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :])
+    ]
 
 
-def consume_prefix_in_state_dict_if_present(state_dict: Dict[str, Any], prefix: str):
-    r"""Strip the prefix in state_dict, if any.
+def consume_prefix_in_state_dict_if_present(
+    state_dict: Dict[str, Any], prefix: str
+) -> None:
+    r"""Strip the prefix in state_dict in place, if any.
 
     ..note::
         Given a `state_dict` from a DP/DDP model, a local model can load it by applying