Allow arbitrary objects in state_dicts (#62976)
authorJoel Schlosser <jbschlosser@fb.com>
Wed, 25 Aug 2021 02:00:33 +0000 (19:00 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 25 Aug 2021 02:06:14 +0000 (19:06 -0700)
commit544af391b5649c8c407fa36b36631a2307997a09
tree66365d64d459649ced0930fec49c6b3a260bc38c
parent58ef99bd5aaf94c2cf5744b938ba4774773eb98d
Allow arbitrary objects in state_dicts (#62976)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/62094

Introduces functionality for adding arbitrary objects to module state_dicts. To take advantage of this, the following functions can be defined on a module:
* `get_extra_state(self) -> dict` - Returns a dict defining any extra state this module wants to save
* `set_extra_state(self, state)` - Subsumes the given state within the module

In the details, a sub-dictionary is stored in the state_dict under the key `_extra_state` for each module that requires extra state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62976

Reviewed By: heitorschueroff

Differential Revision: D30518657

Pulled By: jbschlosser

fbshipit-source-id: 5fb35ab8e3d36f35e3e96dcd4498f8c917d1f386
test/test_nn.py
torch/jit/_script.py
torch/nn/modules/module.py