Further improvements of nn.container docs
authorTongzhou Wang <ssnl@users.noreply.github.com>
Mon, 11 Mar 2019 01:26:20 +0000 (18:26 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 11 Mar 2019 01:30:39 +0000 (18:30 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17731

Differential Revision: D14401894

Pulled By: soumith

fbshipit-source-id: cebb25859f78589cc4f4f8afb1e84c97f82b6962

torch/nn/modules/container.py

index e8f870c..09c6b22 100644 (file)
@@ -101,8 +101,9 @@ class Sequential(Module):
 class ModuleList(Module):
     r"""Holds submodules in a list.
 
-    ModuleList can be indexed like a regular Python list, but modules it
-    contains are properly registered, and will be visible by all Module methods.
+    :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
+    modules it contains are properly registered, and will be visible by all
+    :class:`~torch.nn.Module` methods.
 
     Arguments:
         modules (iterable, optional): an iterable of modules to add
@@ -219,13 +220,12 @@ class ModuleDict(Module):
       or another :class:`~torch.nn.ModuleDict` (the argument to :meth:`~torch.nn.ModuleDict.update`).
 
     Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
-    types (e.g., Python's plain ``dict``) doesn't not preserve order of the
+    types (e.g., Python's plain ``dict``) does not preserve the order of the
     merged mapping.
 
-
     Arguments:
         modules (iterable, optional): a mapping (dictionary) of (string: module)
-            or an iterable of key/value pairs of type (string, module)
+            or an iterable of key-value pairs of type (string, module)
 
     Example::
 
@@ -301,12 +301,16 @@ class ModuleDict(Module):
         return self._modules.values()
 
     def update(self, modules):
-        r"""Update the ModuleDict with the key/value pairs from a mapping or
-        an iterable, overwriting existing keys.
+        r"""Update the :class:`~torch.nn.ModuleDict` with the key-value pairs from a
+        mapping or an iterable, overwriting existing keys.
+
+        .. note::
+            If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
+            an iterable of key-value pairs, the order of new elements in it is preserved.
 
         Arguments:
-            modules (iterable): a mapping (dictionary) of (string: :class:`~torch.nn.Module``) or
-                an iterable of key/value pairs of type (string, :class:`~torch.nn.Module``)
+            modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
+                or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
         """
         if not isinstance(modules, container_abcs.Iterable):
             raise TypeError("ModuleDict.update should be called with an "
@@ -336,8 +340,9 @@ class ModuleDict(Module):
 class ParameterList(Module):
     r"""Holds parameters in a list.
 
-    ParameterList can be indexed like a regular Python list, but parameters it
-    contains are properly registered, and will be visible by all Module methods.
+    :class:`~torch.nn.ParameterList` can be indexed like a regular Python
+    list, but parameters it contains are properly registered, and will be
+    visible by all :class:`~torch.nn.Module` methods.
 
     Arguments:
         parameters (iterable, optional): an iterable of :class:`~torch.nn.Parameter` to add
@@ -436,9 +441,21 @@ class ParameterDict(Module):
     ParameterDict can be indexed like a regular Python dictionary, but parameters it
     contains are properly registered, and will be visible by all Module methods.
 
+    :class:`~torch.nn.ParameterDict` is an **ordered** dictionary that respects
+
+    * the order of insertion, and
+
+    * in :meth:`~torch.nn.ParameterDict.update`, the order of the merged ``OrderedDict``
+      or another :class:`~torch.nn.ParameterDict` (the argument to
+      :meth:`~torch.nn.ParameterDict.update`).
+
+    Note that :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
+    types (e.g., Python's plain ``dict``) does not preserve the order of the
+    merged mapping.
+
     Arguments:
         parameters (iterable, optional): a mapping (dictionary) of
-            (string : :class:`~torch.nn.Parameter`) or an iterable of key,value pairs
+            (string : :class:`~torch.nn.Parameter`) or an iterable of key-value pairs
             of type (string, :class:`~torch.nn.Parameter`)
 
     Example::
@@ -510,13 +527,17 @@ class ParameterDict(Module):
         return self._parameters.values()
 
     def update(self, parameters):
-        r"""Update the ParameterDict with the key/value pairs from a mapping or
-        an iterable, overwriting existing keys.
+        r"""Update the :class:`~torch.nn.ParameterDict` with the key-value pairs from a
+        mapping or an iterable, overwriting existing keys.
+
+        .. note::
+            If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
+            an iterable of key-value pairs, the order of new elements in it is preserved.
 
         Arguments:
-            parameters (iterable): a mapping (dictionary) of
-                (string : :class:`~torch.nn.Parameter`) or an iterable of
-                key/value pairs of type (string, :class:`~torch.nn.Parameter`)
+            parameters (iterable): a mapping (dictionary) from string to
+                :class:`~torch.nn.Parameter`, or an iterable of
+                key-value pairs of type (string, :class:`~torch.nn.Parameter`)
         """
         if not isinstance(parameters, container_abcs.Iterable):
             raise TypeError("ParametersDict.update should be called with an "
@@ -524,7 +545,7 @@ class ParameterDict(Module):
                             type(parameters).__name__)
 
         if isinstance(parameters, container_abcs.Mapping):
-            if isinstance(parameters, OrderedDict):
+            if isinstance(parameters, (OrderedDict, ParameterDict)):
                 for key, parameter in parameters.items():
                     self[key] = parameter
             else: