Add python mode (#63496)
authorRichard Zou <zou3519@gmail.com>
Tue, 31 Aug 2021 01:39:50 +0000 (18:39 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 01:44:35 +0000 (18:44 -0700)
commit4bd03b02424d93b72f15e28c542ede13f88ea929
treed6379eb1b5fab0c415c49074ed76b175d464d98d
parentebc0aacf83a0446ed798a96059c05da815c73d3d
Add python mode (#63496)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63496

This PR adds a (private) enable_python_mode context manager.
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.

Example usage:
```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])
    assert isinstance(z, LoggingTensor)
```

There are quite a few changes that were made to support this.

First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.

Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.

To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_arg`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_dispatch_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.

Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (because we need
to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.

There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor and the exclude guard is to prevent a bug.

Future:
- This PR does not allow for the nesting of Python modes. In the future we
should be able to enable this with a more sane no_dispatch API and by changing
the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.

Test Plan: - new tests

Reviewed By: malfet, albanD

Differential Revision: D30543236

Pulled By: zou3519

fbshipit-source-id: ef5444d96a5a957d1657b7e37dce80f9a497d452
19 files changed:
aten/src/ATen/PythonModeTLS.cpp [new file with mode: 0644]
aten/src/ATen/PythonModeTLS.h [new file with mode: 0644]
aten/src/ATen/ThreadLocalState.cpp
aten/src/ATen/ThreadLocalState.h
aten/src/ATen/core/PythonFallbackKernel.cpp
c10/core/TensorImpl.cpp
c10/core/TensorImpl.h
test/run_test.py
test/test_python_dispatch.py
tools/build_variables.bzl
torch/_C/__init__.pyi.in
torch/csrc/autograd/init.cpp
torch/csrc/autograd/python_mode.cpp [new file with mode: 0644]
torch/csrc/autograd/python_mode.h [new file with mode: 0644]
torch/csrc/autograd/python_variable.cpp
torch/csrc/utils/python_arg_parser.cpp
torch/csrc/utils/python_arg_parser.h
torch/csrc/utils/tensor_new.cpp
torch/utils/_python_dispatch.py [new file with mode: 0644]