Add python mode (#63496)
authorRichard Zou <zou3519@gmail.com>
Tue, 31 Aug 2021 01:39:50 +0000 (18:39 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 01:44:35 +0000 (18:44 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63496

This PR adds a (private) enable_python_mode context manager.
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.

Example usage:
```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])
    assert isinstance(z, LoggingTensor)
```

There are quite a few changes that were made to support this.

First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.

Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.

To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_arg`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_dispatch_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.

Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (because we need
to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.

There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor and the exclude guard is to prevent a bug.

Future:
- This PR does not allow for the nesting of Python modes. In the future we
should be able to enable this with a more sane no_dispatch API and by changing
the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.

Test Plan: - new tests

Reviewed By: malfet, albanD

Differential Revision: D30543236

Pulled By: zou3519

fbshipit-source-id: ef5444d96a5a957d1657b7e37dce80f9a497d452

19 files changed:
aten/src/ATen/PythonModeTLS.cpp [new file with mode: 0644]
aten/src/ATen/PythonModeTLS.h [new file with mode: 0644]
aten/src/ATen/ThreadLocalState.cpp
aten/src/ATen/ThreadLocalState.h
aten/src/ATen/core/PythonFallbackKernel.cpp
c10/core/TensorImpl.cpp
c10/core/TensorImpl.h
test/run_test.py
test/test_python_dispatch.py
tools/build_variables.bzl
torch/_C/__init__.pyi.in
torch/csrc/autograd/init.cpp
torch/csrc/autograd/python_mode.cpp [new file with mode: 0644]
torch/csrc/autograd/python_mode.h [new file with mode: 0644]
torch/csrc/autograd/python_variable.cpp
torch/csrc/utils/python_arg_parser.cpp
torch/csrc/utils/python_arg_parser.h
torch/csrc/utils/tensor_new.cpp
torch/utils/_python_dispatch.py [new file with mode: 0644]

diff --git a/aten/src/ATen/PythonModeTLS.cpp b/aten/src/ATen/PythonModeTLS.cpp
new file mode 100644 (file)
index 0000000..b53043c
--- /dev/null
@@ -0,0 +1,26 @@
+#include <ATen/PythonModeTLS.h>
+
+namespace at { namespace impl {
+
+thread_local std::shared_ptr<TorchDispatchTypeObject> pythonModeState;
+
+void PythonModeTLS::set_state(const std::shared_ptr<TorchDispatchTypeObject>& state) {
+  pythonModeState = state;
+  if (state) {
+    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
+  } else {
+    PythonModeTLS::reset_state();
+  }
+}
+
+const std::shared_ptr<TorchDispatchTypeObject>& PythonModeTLS::get_state() {
+  return pythonModeState;
+}
+
+void PythonModeTLS::reset_state() {
+  pythonModeState.reset((TorchDispatchTypeObject*)nullptr);
+  c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
+}
+
+} // namespace impl
+} // namespace at
diff --git a/aten/src/ATen/PythonModeTLS.h b/aten/src/ATen/PythonModeTLS.h
new file mode 100644 (file)
index 0000000..be52b18
--- /dev/null
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <torch/library.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+
+namespace at {
+namespace impl {
+
+struct TORCH_API PythonModeTLS {
+  static void set_state(const std::shared_ptr<TorchDispatchTypeObject>& state);
+  static const std::shared_ptr<TorchDispatchTypeObject>& get_state();
+  static void reset_state();
+};
+
+} // namespace impl
+} // namespace at
index 98c2519..19cfa89 100644 (file)
@@ -17,6 +17,7 @@ ThreadLocalState::ThreadLocalState()
   saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
 
   bumped_record_all_functions_ = at::checkRecordAllFunctions();
+  python_mode_state_ = at::impl::PythonModeTLS::get_state();
 }
 
 void ThreadLocalState::set_grad_mode(bool enabled) {
@@ -30,6 +31,8 @@ void ThreadLocalState::setThreadLocalState(
   // restore the dispatch key set TLS at the same time.
   c10::AutogradState::set_tls_state(state.autograd_tls_);
 
+  at::impl::PythonModeTLS::set_state(state.python_mode_state_);
+
   at::set_record_function_tls_(state.rf_tls_);
 
   SavedTensorDefaultHooks::set_hooks(
index 4114691..c99ca61 100644 (file)
@@ -6,6 +6,7 @@
 #include <c10/util/ThreadLocalDebugInfo.h>
 
 #include <ATen/record_function.h>
+#include <ATen/PythonModeTLS.h>
 
 namespace at {
 
@@ -40,6 +41,8 @@ class TORCH_API ThreadLocalState {
   // TLS for AutogradModes
   AutogradState autograd_tls_;
 
+  std::shared_ptr<TorchDispatchTypeObject> python_mode_state_;
+
   // TLS for saved tensors default hooks
   std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
 
index 276eabf..8e77d09 100644 (file)
@@ -1,9 +1,18 @@
 #include <torch/library.h>
 #include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/PythonModeTLS.h>
 
 namespace {
 
 void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  // If Python Mode is active, use its PyInterpreter for dispatch
+  const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
+  if (maybe_python_mode_state) {
+    maybe_python_mode_state->pyinterpreter()->dispatch(op, stack, maybe_python_mode_state);
+    return;
+  }
+
+  // Otherwise, find a PyInterpreter on a Tensor
   const auto& schema = op.schema();
   const auto num_arguments = schema.arguments().size();
   // It is safe to dispatch on the very first Tensor with a pyobj_interpreter
@@ -15,7 +24,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
     if (ivalue.isTensor()) {
       auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter();
       if (interpreter) {
-        interpreter->dispatch(op, stack);
+        interpreter->dispatch(op, stack, nullptr);
         return;
       }
     } else if (ivalue.isTensorList()) {
@@ -24,7 +33,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
       for (const auto& nv : ivalue.toListRef()) {
         auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter();
         if (interpreter) {
-          interpreter->dispatch(op, stack);
+          interpreter->dispatch(op, stack, nullptr);
           return;
         }
       }
index de829c4..9a72659 100644 (file)
@@ -40,7 +40,8 @@ static c10::intrusive_ptr<TensorImpl> noop_detach_fn(
 static void noop_dispatch_fn(
     const PyInterpreter*,
     const c10::OperatorHandle& op,
-    torch::jit::Stack* stack) {
+    torch::jit::Stack* stack,
+    const std::shared_ptr<TorchDispatchTypeObject>& type) {
   TORCH_INTERNAL_ASSERT(
       0,
       "attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died");
@@ -608,6 +609,23 @@ void TensorImpl::copy_tensor_metadata(
   }
 }
 
+TorchDispatchTypeObject::TorchDispatchTypeObject(
+    PyObject* type_object,
+    c10::impl::PyInterpreter* pyinterpreter)
+    : data_(type_object), pyinterpreter_(pyinterpreter) {}
+
+TorchDispatchTypeObject::~TorchDispatchTypeObject() {
+  pyinterpreter_->decref(data_);
+}
+
+c10::impl::PyInterpreter* TorchDispatchTypeObject::pyinterpreter() const {
+  return pyinterpreter_;
+}
+
+PyObject* TorchDispatchTypeObject::ptr() const {
+  return data_;
+}
+
 namespace impl {
 
 namespace {
index 7051e36..d110a17 100644 (file)
@@ -161,6 +161,9 @@ struct C10_API AutogradMetaInterface {
   virtual ~AutogradMetaInterface();
 };
 
+// forward declared
+struct TorchDispatchTypeObject;
+
 namespace impl {
 
 // Unfortunately, the definition of AutogradMeta lives in a separate
@@ -255,7 +258,8 @@ struct C10_API PyInterpreter {
   using dispatch_sig = void(
       const PyInterpreter*,
       const c10::OperatorHandle&,
-      torch::jit::Stack* stack);
+      torch::jit::Stack* stack,
+      const std::shared_ptr<TorchDispatchTypeObject>& type);
 
   PyInterpreter(
       name_sig* name_fn,
@@ -299,8 +303,9 @@ struct C10_API PyInterpreter {
   // Invoke the Python boxed fallback dispatch to go back into Python
   __ubsan_ignore_function__ void dispatch(
       const c10::OperatorHandle& op,
-      torch::jit::Stack* stack) const {
-    return (*dispatch_fn_)(this, op, stack);
+      torch::jit::Stack* stack,
+      const std::shared_ptr<TorchDispatchTypeObject>& type) const {
+    return (*dispatch_fn_)(this, op, stack, type);
   }
 
   // Disarm this PyInterpreter, making all of its methods noops.
@@ -348,6 +353,30 @@ struct C10_API NamedTensorMetaInterface {
   };
 };
 
+// NOTE [What is TorchDispatchTypeObject?]
+// A TorchDispatchTypeObject represents the type of a Tensor subclass that has
+// a __torch_dispatch__ classmethod. Concretely, it holds the class as a
+// PyObject* and a PyInterpreter* that says which python interpreter the class
+// came from.
+//
+// See NOTE [dispatch_fn's type argument] for more details
+struct C10_API TorchDispatchTypeObject {
+  // Steals a reference to type_object
+  TorchDispatchTypeObject(
+      PyObject* type_object,
+      c10::impl::PyInterpreter* pyinterpreter);
+
+  // Releases the stolen reference to type_object
+  ~TorchDispatchTypeObject();
+
+  c10::impl::PyInterpreter* pyinterpreter() const;
+  PyObject* ptr() const;
+
+ private:
+  PyObject* data_;
+  c10::impl::PyInterpreter* pyinterpreter_;
+};
+
 // NOTE [ Version Counter Sharing ]
 //
 // Every Tensor has a version counter. Version counters are incremented whenever
index dd95e13..615aaf9 100755 (executable)
@@ -103,6 +103,7 @@ TESTS = [
     "test_optim",
     "test_functional_optim",
     "test_pytree",
+    "test_python_dispatch",
     "test_mobile_optimizer",
     "test_set_default_mobile_cpu_allocator",
     "test_xnnpack_integration",
index 0f5b6b9..e474f1f 100644 (file)
@@ -1,6 +1,7 @@
 import torch
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.utils._pytree import tree_map
+from torch.utils._python_dispatch import enable_python_mode
 
 from typing import Iterator, List
 import logging
@@ -50,7 +51,10 @@ class LoggingTensor(torch.Tensor):
         def wrap(e):
             return LoggingTensor(e) if isinstance(e, torch.Tensor) else e
 
-        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+        # no_dispatch is only needed if you use enable_python_mode.
+        # It prevents infinite recursion.
+        with no_dispatch():
+            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
         logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
         return rs
 
@@ -335,6 +339,81 @@ $4 = torch._ops.aten.mul($3, tensor(2))
 $5 = torch._ops.aten.mul($4, $0)
 $6 = torch._ops.aten.add_($1, $5)''')
 
+    def test_enable_python_mode_error(self) -> None:
+        with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
+            with enable_python_mode(torch.Tensor):
+                pass
+        z = LoggingTensor(torch.empty([]))
+        with self.assertRaisesRegex(ValueError, "must be the type"):
+            with enable_python_mode(z):
+                pass
+
+    def test_enable_python_mode_basic(self) -> None:
+        with enable_python_mode(LoggingTensor):
+            z = torch.empty([])
+            self.assertTrue(isinstance(z, LoggingTensor))
+
+    def test_enable_python_mode_unrelated_tensors(self) -> None:
+        x = torch.randn([])
+        y = torch.randn([])
+        with enable_python_mode(LoggingTensor):
+            z = x + y
+            self.assertTrue(isinstance(z, LoggingTensor))
+
+    def test_enable_python_mode_subclass_priority(self) -> None:
+        class ErrorA(RuntimeError):
+            pass
+
+        class ErrorB(RuntimeError):
+            pass
+
+        class A(torch.Tensor):
+            @staticmethod
+            def __new__(cls, elem):
+                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                raise ErrorA
+
+        class B(A):
+            @staticmethod
+            def __new__(cls, elem):
+                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                raise ErrorB
+
+        a = A(torch.empty(1))
+        b = B(torch.empty(1))
+        with self.assertRaises(ErrorA):
+            a + a
+
+        # B has precedence over A due to the subclass relationship
+        with self.assertRaises(ErrorB):
+            with enable_python_mode(A):
+                b + b
+        with self.assertRaises(ErrorB):
+            with enable_python_mode(B):
+                a + a
+        with self.assertRaises(ErrorB):
+            with enable_python_mode(B):
+                a + b
+
+    def test_enable_python_mode_respects_no_dispatch(self) -> None:
+        with enable_python_mode(LoggingTensor):
+            z = torch.ones([2, 3])
+            self.assertTrue(isinstance(z, LoggingTensor))
+            with no_dispatch():
+                expected = torch.ones([2, 3])
+                self.assertEqual(z.elem, expected)
+
+    def test_nested_enable_python_mode(self) -> None:
+        with self.assertRaisesRegex(RuntimeError, "has already been set"):
+            with enable_python_mode(LoggingTensor):
+                with enable_python_mode(LoggingTensor):
+                    pass
 
 if __name__ == '__main__':
     run_tests()
index 34846b5..dd89981 100644 (file)
@@ -666,6 +666,7 @@ libtorch_python_core_sources = [
     "torch/csrc/autograd/init.cpp",
     "torch/csrc/autograd/python_anomaly_mode.cpp",
     "torch/csrc/autograd/python_saved_variable_hooks.cpp",
+    "torch/csrc/autograd/python_mode.cpp",
     "torch/csrc/autograd/python_cpp_function.cpp",
     "torch/csrc/autograd/python_engine.cpp",
     "torch/csrc/autograd/python_function.cpp",
@@ -793,6 +794,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/ParallelNativeTBB.cpp",
     "aten/src/ATen/ParallelOpenMP.cpp",
     "aten/src/ATen/ParallelThreadPoolNative.cpp",
+    "aten/src/ATen/PythonModeTLS.cpp",
     "aten/src/ATen/ScalarOps.cpp",
     "aten/src/ATen/SequenceNumber.cpp",
     "aten/src/ATen/SparseTensorImpl.cpp",
index 3629150..c847e8d 100644 (file)
@@ -652,6 +652,8 @@ def __set_forward_AD_enabled(enabled: _bool) -> None: ...
 def __is_forward_AD_enabled() -> _bool: ...
 def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
 def _reset_default_hooks() -> None: ...
+def _enter_python_mode(cls: Type) -> None: ...
+def _exit_python_mode() -> None: ...
 
 class _InferenceMode(object):
     def __init__(self, mode: _bool) -> None: ...
index 697ca87..860aaec 100644 (file)
@@ -14,6 +14,7 @@
 #include <torch/csrc/autograd/python_saved_variable_hooks.h>
 #include <torch/csrc/autograd/utils/wrap_outputs.h>
 #include <torch/csrc/autograd/utils/python_arg_parsing.h>
+#include <torch/csrc/autograd/python_mode.h>
 #include <torch/csrc/utils/pycfunction_helpers.h>
 #include <c10/core/ScalarType.h>
 
@@ -494,6 +495,20 @@ static PyObject * python_exit_dual_level(PyObject* _unused, PyObject* args, PyOb
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject * enter_python_mode(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  PythonMode::enter(arg);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * exit_python_mode(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  PythonMode::exit();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 // autograd methods on torch._C
 static PyMethodDef methods[] = { // NOLINT
   {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
@@ -514,6 +529,8 @@ static PyMethodDef methods[] = { // NOLINT
   {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
   {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
   {"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr},
+  {"_enter_python_mode", enter_python_mode, METH_O, nullptr},
+  {"_exit_python_mode", exit_python_mode, METH_NOARGS, nullptr},
   {nullptr, nullptr, 0, nullptr}
 };
 
diff --git a/torch/csrc/autograd/python_mode.cpp b/torch/csrc/autograd/python_mode.cpp
new file mode 100644 (file)
index 0000000..4358426
--- /dev/null
@@ -0,0 +1,27 @@
+#include <torch/csrc/autograd/python_mode.h>
+#include <torch/csrc/python_headers.h>
+#include <torch/csrc/autograd/python_variable.h>
+#include <ATen/PythonModeTLS.h>
+#include <c10/core/TensorImpl.h>
+
+namespace torch { namespace autograd {
+
+void PythonMode::enter(PyObject* type) {
+  if (at::impl::PythonModeTLS::get_state()) {
+    TORCH_CHECK(
+        false,
+        "python mode has already been set. We do not yet support nested python ",
+        "mode. Please file us an issue and reset it before setting it again.")
+  }
+  // TorchDispatchTypeObject steals a reference, See NOTE [What is TorchDispatchTypeObject?]
+  Py_INCREF(type);
+  auto state = std::make_shared<c10::TorchDispatchTypeObject>(type, getPyInterpreter());
+  at::impl::PythonModeTLS::set_state(state);
+}
+
+void PythonMode::exit() {
+  TORCH_INTERNAL_ASSERT(at::impl::PythonModeTLS::get_state(), "exiting Python Mode but it wasn't set!");
+  at::impl::PythonModeTLS::reset_state();
+}
+
+}}
diff --git a/torch/csrc/autograd/python_mode.h b/torch/csrc/autograd/python_mode.h
new file mode 100644 (file)
index 0000000..03da51c
--- /dev/null
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <torch/csrc/python_headers.h>
+#include <c10/core/TensorImpl.h>
+
+namespace torch { namespace autograd {
+
+struct TORCH_API PythonMode {
+  // Enter python mode, causing all operators to dispatch to the type's __torch_dispatch__.
+  // `type` is the type of a Tensor subclass that has __torch_dispatch__.
+  static void enter(PyObject* type);
+
+  // Exit the current python mode.
+  static void exit();
+};
+
+}}
index 50d6eb9..abe9010 100644 (file)
@@ -32,6 +32,7 @@
 
 #include <torch/library.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
+#include <torch/csrc/autograd/python_mode.h>
 
 
 #include <ATen/ATen.h>
@@ -64,7 +65,12 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) {
     return;
 
   pybind11::gil_scoped_acquire gil;
-  if (Py_REFCNT(pyobj) > 1) {
+  // Two possibilities:
+  // 1. We are decref-ing a tensor. Then we must be careful about
+  // PyObject resurrection (this only applies to Tensors, see THPVariable_clear).
+  // 2. We are decref-ing some other Python object. We don't do
+  // PyObject resurrection on non-Tensors, so we just carry on as usual
+  if (THPVariable_Check(pyobj) && Py_REFCNT(pyobj) > 1) {
     // It's still alive!  This can happen if a weak ref resurrected
     // the PyObject without flipping ownership.  At this point it is
     // too late to rescue the object, so just stub out the PyObject
@@ -82,7 +88,11 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) {
 };
 
 c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
-void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack);
+void concrete_dispatch_fn(
+    const c10::impl::PyInterpreter*,
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack,
+    const std::shared_ptr<TorchDispatchTypeObject>& type);
 
 class PyInterpreterHolder {
  public:
@@ -1491,7 +1501,19 @@ bool isPythonTensor(const Tensor& tensor) {
   return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
 }
 
-void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+// NOTE [dispatch_fn's type argument]
+// `type` is nullable and represents the PythonMode going on.
+// Right now we only support a single PythonMode, but in the future we could
+// change this to a stack of PythonModes.
+//
+// If `type` isn't null, then we consider the type for dispatch by prepending
+// it to the overloaded_args list. `handle_torch_funciton_no_python_arg_parser`
+// is responsible for doing overload resolution.
+void concrete_dispatch_fn(
+    const c10::impl::PyInterpreter*,
+    const c10::OperatorHandle& op,
+    torch::jit::Stack* stack,
+    const std::shared_ptr<TorchDispatchTypeObject>& type) {
   const auto& schema = op.schema();
   const auto num_returns = schema.returns().size();
 
@@ -1568,13 +1590,17 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
   auto args = py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));
   py::dict kwargs;
 
+  if (type) {
+    append_overloaded_type(&overloaded_args, type->ptr());
+  }
+
   // Find overloaded tensors
   for (int64_t idx = 0; idx < arguments.size(); idx++) {
     const auto& ivalue = arguments[idx];
     if (ivalue.isTensor()) {
       const auto& tensor = ivalue.toTensor();
       if (isPythonTensor(tensor)) {
-        append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
+        append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
       }
     } else if (ivalue.isList()) {
       const auto& list = ivalue.toListRef();
@@ -1583,7 +1609,7 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
         if (nv.isTensor()) {
           const auto& tensor = nv.toTensor();
           if (isPythonTensor(tensor)) {
-            append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
+            append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
           }
         }
       }
@@ -1633,7 +1659,7 @@ c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter
   Tensor self_t = Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
   auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
   TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
-  append_overloaded_arg(&overloaded_args, self_p.ptr());
+  append_overloaded_tensor(&overloaded_args, self_p.ptr());
   auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
   PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
 
index 6115dcd..3ee20c0 100644 (file)
@@ -200,12 +200,28 @@ auto handle_torch_function(PyObject* self, const std::string& func_name, PyObjec
   return ret.release().ptr();
 }
 
+// Note: [Overloaded args]
+// An overloaded arg may be one of the following:
+// - an instance of an object that has a __torch_function__ method
+// - an instance of an object that has a __torch_dispatch__ classmethod
+// - a class type that has a __torch_dispatch__ classmethod
+//
+// This function returns the type of the arg (if the arg is an instance),
+// otherwise, it returns the arg.
+static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
+  if (PyType_Check(obj_or_type)) {
+    return obj_or_type;
+  }
+  return (PyObject*)Py_TYPE(obj_or_type);
+}
+
+// See Note: [Overloaded args] for what they hold
 auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name, const char* torch_function_name) -> PyObject* {
   // overloaded_args already all have unique types
   std::vector<py::object> overloaded_types;
   overloaded_types.reserve(overloaded_args.size());
   for (auto &arg : overloaded_args) {
-    overloaded_types.push_back(py::reinterpret_borrow<py::object>((PyObject *) Py_TYPE(arg.ptr())));
+    overloaded_types.push_back(py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg.ptr())));
   }
   py::tuple py_types = py::cast(overloaded_types);
   py::object ret;
@@ -231,7 +247,7 @@ auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &o
     ss << "no implementation found for '" << module_name << "." << func_name
        << "' on types that implement " << torch_function_name << ": [";
     for (auto &arg : overloaded_args) {
-      ss << arg.ptr()->ob_type->tp_name;
+      ss << PyObject_Repr(get_type_of_overloaded_arg(arg.ptr()));
       if (!arg.is(overloaded_args.back())) {
         ss << ", ";
       }
@@ -328,10 +344,11 @@ auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* v
  *
  */
 
-void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj) {
+static void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj, bool obj_is_type) {
   bool class_not_seen_yet = true;
+  PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
   for (auto &arg : *overloaded_args) {
-    if (Py_TYPE(obj) == Py_TYPE(arg.ptr())) {
+    if (obj_type == get_type_of_overloaded_arg(arg.ptr())) {
       // obj is the same type as another parameter we've seen in a prior
       // iteration of the loop over parameters so we already have an entry
       // with the proper __torch_function__ implementation to call, so skip
@@ -343,7 +360,7 @@ void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* o
   if (class_not_seen_yet) {
     int arg_index = overloaded_args->size();
     for(const auto j : c10::irange(arg_index)) {
-      if (PyObject_IsInstance(obj, (PyObject*)(Py_TYPE((*overloaded_args)[j].ptr())))) {
+      if (PyObject_IsSubclass(obj_type, (PyObject*)(get_type_of_overloaded_arg((*overloaded_args)[j].ptr())))) {
         // obj is a subclass of another object we've seen already so its
         // __torch_function__ should be called first, therefore we
         // insert it into overloaded_args before the superclass
@@ -358,6 +375,14 @@ void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* o
   }
 }
 
+void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj) {
+  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/false);
+}
+
+void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj) {
+  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/true);
+}
+
 bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args) {
   if (THPVariable_CheckExact(obj)) {
     // torch.Tensor instances (not subclasses, except for Parameter)
@@ -366,7 +391,7 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* ove
 
   if (check_has_torch_function(obj)) {
     // tensor subclasses and unrelated objects with __torch_function__
-    append_overloaded_arg(overloaded_args, obj);
+    append_overloaded_tensor(overloaded_args, obj);
     return true;
   } else if (THPVariable_Check(obj)) {
     // tensor subclasses without __torch_function__
@@ -905,7 +930,7 @@ bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs,
 
   int i = 0;
   if (self != nullptr && check_has_torch_function(self)) {
-    append_overloaded_arg(&this->overloaded_args, self);
+    append_overloaded_tensor(&this->overloaded_args, self);
   }
   for (auto& param : params) {
     PyObject* obj = nullptr;
index d132185..6a05807 100644 (file)
@@ -818,6 +818,15 @@ bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>
  * 'overloaded_args': the vector to append the overloaded args
  * 'obj': the input tensor that is overloaded
  */
-void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj);
+void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj);
+
+/* Given an argument that is definitely a type and is definitely overloaded,
+ * append it to the overloaded arguments list. Use this only with __torch_dispatch__,
+ * where we operate on classes that have a __torch_dispatch__ classmethod.
+ *
+ * 'overloaded_args': the vector to append the overloaded type
+ * 'obj': the input class that has a __torch_dispatch__ classmethod.
+ */
+void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj);
 
 } // namespace torch
index 17d7acc..25e9a59 100644 (file)
@@ -267,6 +267,7 @@ Tensor internal_new_from_data(
   {
     at::AutoDispatchBelowADInplaceOrView guard;  // TODO: remove
     at::tracer::impl::NoTracerDispatchMode tracer_guard;
+    c10::impl::ExcludeDispatchKeyGuard pythonmode_guard(c10::DispatchKey::Python);
     // functorch uses FuncTorchDynamicLayerBackMode as a mode key to wrap all
     // tensors returned from operators in special TensorWrapper tensor extension
     // The problem with this is that TensorWrapper does not have storage so
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
new file mode 100644 (file)
index 0000000..a7cfae1
--- /dev/null
@@ -0,0 +1,34 @@
+import torch
+import contextlib
+from typing import Iterator
+
+# Context manager that causes all pytorch operators to dispatch to the passed-in
+# type's __torch_dispatch__ function.
+# operation that accepts no tensors but returns a tensor.
+#
+# enable_python_mode is affected by torch._C._DisableTorchDispatch.
+#
+# NB: Calling an operator inside __torch_dispatch__ does go through
+# __torch_dispatch__ again. Please use _DisableTorchDispatch inside
+# __torch_dispatch__ to prevent infinite recursion.
+#
+# TODO: Limitations and things about enable_python_mode we should fix before exposing it:
+# - it currently cannot be nested. This should be simple to implement; we need a
+#   stack of TorchDispatchTypeObjects and the next bullet point.
+# - We need a better user-facing api for torch._C._DisableTorchDispatch that
+#   is able to selectively disable __torch_dispatch__ of a particular class.
+# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
+# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
+@contextlib.contextmanager
+def enable_python_mode(cls) -> Iterator[None]:
+    if not hasattr(cls, '__torch_dispatch__'):
+        raise ValueError('The class passed to enable_python_mode '
+                         'must have a __torch_dispatch__ classmethod')
+    if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
+        raise ValueError('The argument passed to enable_python_mode '
+                         'must be the type of a Tensor subclass')
+    torch._C._enter_python_mode(cls)
+    try:
+        yield
+    finally:
+        torch._C._exit_python_mode()