Misleading documentation for module._load_from_state_dict (#17618)
authorKai Zhang <kaizh@fb.com>
Tue, 12 Mar 2019 23:52:38 +0000 (16:52 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 12 Mar 2019 23:57:39 +0000 (16:57 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17618

Base on the code, we only add key to `missing_keys` and `unexpected_keys` if `$strict` is `True`. The documentation is confusing.

This diff also fix one FLAKE8 warning.

Reviewed By: ailzhang

Differential Revision: D14280593

fbshipit-source-id: d368f5596bdf74ff62ee4d28d79120f5af91e0a3

torch/nn/modules/module.py

index 1da58ce..afefd5f 100644 (file)
@@ -671,9 +671,9 @@ class Module(object):
             strict (bool): whether to strictly enforce that the keys in
                 :attr:`state_dict` with :attr:`prefix` match the names of
                 parameters and buffers in this module
-            missing_keys (list of str): if ``strict=False``, add missing keys to
+            missing_keys (list of str): if ``strict=True``, add missing keys to
                 this list
-            unexpected_keys (list of str): if ``strict=False``, add unexpected
+            unexpected_keys (list of str): if ``strict=True``, add unexpected
                 keys to this list
             error_msgs (list of str): error messages should be added to this
                 list, and will be reported together in
@@ -715,7 +715,7 @@ class Module(object):
                 missing_keys.append(key)
 
         if strict:
-            for key, input_param in state_dict.items():
+            for key in state_dict.keys():
                 if key.startswith(prefix):
                     input_name = key[len(prefix):]
                     input_name = input_name.split('.', 1)[0]  # get the name of param/buffer/child