-
import collections
from itertools import repeat
from typing import List, Dict, Any
if isinstance(x, collections.abc.Iterable):
return tuple(x)
return tuple(repeat(x, n))
+
return parse
+
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
if isinstance(out_size, int):
return out_size
if len(defaults) <= len(out_size):
- raise ValueError('Input dimension should be at least {}'.format(len(out_size) + 1))
- return [v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size):])]
+ raise ValueError(
+ "Input dimension should be at least {}".format(len(out_size) + 1)
+ )
+ return [
+ v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :])
+ ]
-def consume_prefix_in_state_dict_if_present(state_dict: Dict[str, Any], prefix: str):
- r"""Strip the prefix in state_dict, if any.
+def consume_prefix_in_state_dict_if_present(
+ state_dict: Dict[str, Any], prefix: str
+) -> None:
+ r"""Strip the prefix in state_dict in place, if any.
..note::
Given a `state_dict` from a DP/DDP model, a local model can load it by applying