From: Yi Wang Date: Tue, 17 Aug 2021 18:28:43 +0000 (-0700) Subject: Add return type hint and improve the docstring of consume_prefix_in_state_dict_if_pre... X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~956 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e000dfcf976454fdadfdc556248976e6e560d155;p=platform%2Fupstream%2Fpytorch.git Add return type hint and improve the docstring of consume_prefix_in_state_dict_if_present method (#63388) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63388 Context: https://discuss.pytorch.org/t/how-to-use-the-helper-function-consume-prefix-in-state-dict-if-present/129505/3 Make it clear that this method strips the prefix in place rather than returns a new value. Additional reformatting is also applied. ghstack-source-id: 135973393 Test Plan: waitforbuildbot Reviewed By: rohan-varma Differential Revision: D30360931 fbshipit-source-id: 1a0c7967a4c86f729e3c810686c21dec43d1dd7a --- diff --git a/torch/nn/modules/utils.py b/torch/nn/modules/utils.py index 01e92aa..b164d78 100644 --- a/torch/nn/modules/utils.py +++ b/torch/nn/modules/utils.py @@ -1,4 +1,3 @@ - import collections from itertools import repeat from typing import List, Dict, Any @@ -9,8 +8,10 @@ def _ntuple(n): if isinstance(x, collections.abc.Iterable): return tuple(x) return tuple(repeat(x, n)) + return parse + _single = _ntuple(1) _pair = _ntuple(2) _triple = _ntuple(3) @@ -30,12 +31,18 @@ def _list_with_default(out_size: List[int], defaults: List[int]) -> List[int]: if isinstance(out_size, int): return out_size if len(defaults) <= len(out_size): - raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1)) - return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])] + raise ValueError( + "Input dimension should be at least {}".format(len(out_size) + 1) + ) + return [ + v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :]) + ] -def consume_prefix_in_state_dict_if_present(state_dict: Dict[str, Any], prefix: str): - r"""Strip the prefix in state_dict, if any. +def consume_prefix_in_state_dict_if_present( + state_dict: Dict[str, Any], prefix: str +) -> None: + r"""Strip the prefix in state_dict in place, if any. ..note:: Given a `state_dict` from a DP/DDP model, a local model can load it by applying