Revert D30543236: Add python mode
authorRichard Zou <rzou@fb.com>
Tue, 31 Aug 2021 21:53:01 +0000 (14:53 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 31 Aug 2021 22:28:33 +0000 (15:28 -0700)
Test Plan: revert-hammer

Differential Revision:
D30543236 (https://github.com/pytorch/pytorch/commit/4bd03b02424d93b72f15e28c542ede13f88ea929)

Original commit changeset: ef5444d96a5a

fbshipit-source-id: b0042ac2c22765fa11d6d00bf751f6a4489eb6d8

19 files changed:
aten/src/ATen/PythonModeTLS.cpp [deleted file]
aten/src/ATen/PythonModeTLS.h [deleted file]
aten/src/ATen/ThreadLocalState.cpp
aten/src/ATen/ThreadLocalState.h
aten/src/ATen/core/PythonFallbackKernel.cpp
c10/core/TensorImpl.cpp
c10/core/TensorImpl.h
test/run_test.py
test/test_python_dispatch.py
tools/build_variables.bzl
torch/_C/__init__.pyi.in
torch/csrc/autograd/init.cpp
torch/csrc/autograd/python_mode.cpp [deleted file]
torch/csrc/autograd/python_mode.h [deleted file]
torch/csrc/autograd/python_variable.cpp
torch/csrc/utils/python_arg_parser.cpp
torch/csrc/utils/python_arg_parser.h
torch/csrc/utils/tensor_new.cpp
torch/utils/_python_dispatch.py [deleted file]

diff --git a/aten/src/ATen/PythonModeTLS.cpp b/aten/src/ATen/PythonModeTLS.cpp
deleted file mode 100644 (file)
index b53043c..0000000
+++ /dev/null
@@ -1,26 +0,0 @@
-#include <ATen/PythonModeTLS.h>
-
-namespace at { namespace impl {
-
-thread_local std::shared_ptr<TorchDispatchTypeObject> pythonModeState;
-
-void PythonModeTLS::set_state(const std::shared_ptr<TorchDispatchTypeObject>& state) {
-  pythonModeState = state;
-  if (state) {
-    c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
-  } else {
-    PythonModeTLS::reset_state();
-  }
-}
-
-const std::shared_ptr<TorchDispatchTypeObject>& PythonModeTLS::get_state() {
-  return pythonModeState;
-}
-
-void PythonModeTLS::reset_state() {
-  pythonModeState.reset((TorchDispatchTypeObject*)nullptr);
-  c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
-}
-
-} // namespace impl
-} // namespace at
diff --git a/aten/src/ATen/PythonModeTLS.h b/aten/src/ATen/PythonModeTLS.h
deleted file mode 100644 (file)
index be52b18..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-#pragma once
-
-#include <c10/macros/Macros.h>
-#include <torch/library.h>
-#include <ATen/core/dispatch/Dispatcher.h>
-
-namespace at {
-namespace impl {
-
-struct TORCH_API PythonModeTLS {
-  static void set_state(const std::shared_ptr<TorchDispatchTypeObject>& state);
-  static const std::shared_ptr<TorchDispatchTypeObject>& get_state();
-  static void reset_state();
-};
-
-} // namespace impl
-} // namespace at
index 19cfa89..98c2519 100644 (file)
@@ -17,7 +17,6 @@ ThreadLocalState::ThreadLocalState()
   saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
 
   bumped_record_all_functions_ = at::checkRecordAllFunctions();
-  python_mode_state_ = at::impl::PythonModeTLS::get_state();
 }
 
 void ThreadLocalState::set_grad_mode(bool enabled) {
@@ -31,8 +30,6 @@ void ThreadLocalState::setThreadLocalState(
   // restore the dispatch key set TLS at the same time.
   c10::AutogradState::set_tls_state(state.autograd_tls_);
 
-  at::impl::PythonModeTLS::set_state(state.python_mode_state_);
-
   at::set_record_function_tls_(state.rf_tls_);
 
   SavedTensorDefaultHooks::set_hooks(
index c99ca61..4114691 100644 (file)
@@ -6,7 +6,6 @@
 #include <c10/util/ThreadLocalDebugInfo.h>
 
 #include <ATen/record_function.h>
-#include <ATen/PythonModeTLS.h>
 
 namespace at {
 
@@ -41,8 +40,6 @@ class TORCH_API ThreadLocalState {
   // TLS for AutogradModes
   AutogradState autograd_tls_;
 
-  std::shared_ptr<TorchDispatchTypeObject> python_mode_state_;
-
   // TLS for saved tensors default hooks
   std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
 
index 8e77d09..276eabf 100644 (file)
@@ -1,18 +1,9 @@
 #include <torch/library.h>
 #include <ATen/core/dispatch/Dispatcher.h>
-#include <ATen/PythonModeTLS.h>
 
 namespace {
 
 void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
-  // If Python Mode is active, use its PyInterpreter for dispatch
-  const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
-  if (maybe_python_mode_state) {
-    maybe_python_mode_state->pyinterpreter()->dispatch(op, stack, maybe_python_mode_state);
-    return;
-  }
-
-  // Otherwise, find a PyInterpreter on a Tensor
   const auto& schema = op.schema();
   const auto num_arguments = schema.arguments().size();
   // It is safe to dispatch on the very first Tensor with a pyobj_interpreter
@@ -24,7 +15,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
     if (ivalue.isTensor()) {
       auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter();
       if (interpreter) {
-        interpreter->dispatch(op, stack, nullptr);
+        interpreter->dispatch(op, stack);
         return;
       }
     } else if (ivalue.isTensorList()) {
@@ -33,7 +24,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
       for (const auto& nv : ivalue.toListRef()) {
         auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter();
         if (interpreter) {
-          interpreter->dispatch(op, stack, nullptr);
+          interpreter->dispatch(op, stack);
           return;
         }
       }
index 9a72659..de829c4 100644 (file)
@@ -40,8 +40,7 @@ static c10::intrusive_ptr<TensorImpl> noop_detach_fn(
 static void noop_dispatch_fn(
     const PyInterpreter*,
     const c10::OperatorHandle& op,
-    torch::jit::Stack* stack,
-    const std::shared_ptr<TorchDispatchTypeObject>& type) {
+    torch::jit::Stack* stack) {
   TORCH_INTERNAL_ASSERT(
       0,
       "attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died");
@@ -609,23 +608,6 @@ void TensorImpl::copy_tensor_metadata(
   }
 }
 
-TorchDispatchTypeObject::TorchDispatchTypeObject(
-    PyObject* type_object,
-    c10::impl::PyInterpreter* pyinterpreter)
-    : data_(type_object), pyinterpreter_(pyinterpreter) {}
-
-TorchDispatchTypeObject::~TorchDispatchTypeObject() {
-  pyinterpreter_->decref(data_);
-}
-
-c10::impl::PyInterpreter* TorchDispatchTypeObject::pyinterpreter() const {
-  return pyinterpreter_;
-}
-
-PyObject* TorchDispatchTypeObject::ptr() const {
-  return data_;
-}
-
 namespace impl {
 
 namespace {
index d110a17..7051e36 100644 (file)
@@ -161,9 +161,6 @@ struct C10_API AutogradMetaInterface {
   virtual ~AutogradMetaInterface();
 };
 
-// forward declared
-struct TorchDispatchTypeObject;
-
 namespace impl {
 
 // Unfortunately, the definition of AutogradMeta lives in a separate
@@ -258,8 +255,7 @@ struct C10_API PyInterpreter {
   using dispatch_sig = void(
       const PyInterpreter*,
       const c10::OperatorHandle&,
-      torch::jit::Stack* stack,
-      const std::shared_ptr<TorchDispatchTypeObject>& type);
+      torch::jit::Stack* stack);
 
   PyInterpreter(
       name_sig* name_fn,
@@ -303,9 +299,8 @@ struct C10_API PyInterpreter {
   // Invoke the Python boxed fallback dispatch to go back into Python
   __ubsan_ignore_function__ void dispatch(
       const c10::OperatorHandle& op,
-      torch::jit::Stack* stack,
-      const std::shared_ptr<TorchDispatchTypeObject>& type) const {
-    return (*dispatch_fn_)(this, op, stack, type);
+      torch::jit::Stack* stack) const {
+    return (*dispatch_fn_)(this, op, stack);
   }
 
   // Disarm this PyInterpreter, making all of its methods noops.
@@ -353,30 +348,6 @@ struct C10_API NamedTensorMetaInterface {
   };
 };
 
-// NOTE [What is TorchDispatchTypeObject?]
-// A TorchDispatchTypeObject represents the type of a Tensor subclass that has
-// a __torch_dispatch__ classmethod. Concretely, it holds the class as a
-// PyObject* and a PyInterpreter* that says which python interpreter the class
-// came from.
-//
-// See NOTE [dispatch_fn's type argument] for more details
-struct C10_API TorchDispatchTypeObject {
-  // Steals a reference to type_object
-  TorchDispatchTypeObject(
-      PyObject* type_object,
-      c10::impl::PyInterpreter* pyinterpreter);
-
-  // Releases the stolen reference to type_object
-  ~TorchDispatchTypeObject();
-
-  c10::impl::PyInterpreter* pyinterpreter() const;
-  PyObject* ptr() const;
-
- private:
-  PyObject* data_;
-  c10::impl::PyInterpreter* pyinterpreter_;
-};
-
 // NOTE [ Version Counter Sharing ]
 //
 // Every Tensor has a version counter. Version counters are incremented whenever
index d0871fa..55b2f38 100755 (executable)
@@ -104,7 +104,6 @@ TESTS = [
     "test_optim",
     "test_functional_optim",
     "test_pytree",
-    "test_python_dispatch",
     "test_mobile_optimizer",
     "test_set_default_mobile_cpu_allocator",
     "test_xnnpack_integration",
index e474f1f..0f5b6b9 100644 (file)
@@ -1,7 +1,6 @@
 import torch
 from torch.testing._internal.common_utils import TestCase, run_tests
 from torch.utils._pytree import tree_map
-from torch.utils._python_dispatch import enable_python_mode
 
 from typing import Iterator, List
 import logging
@@ -51,10 +50,7 @@ class LoggingTensor(torch.Tensor):
         def wrap(e):
             return LoggingTensor(e) if isinstance(e, torch.Tensor) else e
 
-        # no_dispatch is only needed if you use enable_python_mode.
-        # It prevents infinite recursion.
-        with no_dispatch():
-            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
         logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
         return rs
 
@@ -339,81 +335,6 @@ $4 = torch._ops.aten.mul($3, tensor(2))
 $5 = torch._ops.aten.mul($4, $0)
 $6 = torch._ops.aten.add_($1, $5)''')
 
-    def test_enable_python_mode_error(self) -> None:
-        with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
-            with enable_python_mode(torch.Tensor):
-                pass
-        z = LoggingTensor(torch.empty([]))
-        with self.assertRaisesRegex(ValueError, "must be the type"):
-            with enable_python_mode(z):
-                pass
-
-    def test_enable_python_mode_basic(self) -> None:
-        with enable_python_mode(LoggingTensor):
-            z = torch.empty([])
-            self.assertTrue(isinstance(z, LoggingTensor))
-
-    def test_enable_python_mode_unrelated_tensors(self) -> None:
-        x = torch.randn([])
-        y = torch.randn([])
-        with enable_python_mode(LoggingTensor):
-            z = x + y
-            self.assertTrue(isinstance(z, LoggingTensor))
-
-    def test_enable_python_mode_subclass_priority(self) -> None:
-        class ErrorA(RuntimeError):
-            pass
-
-        class ErrorB(RuntimeError):
-            pass
-
-        class A(torch.Tensor):
-            @staticmethod
-            def __new__(cls, elem):
-                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
-
-            @classmethod
-            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-                raise ErrorA
-
-        class B(A):
-            @staticmethod
-            def __new__(cls, elem):
-                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
-
-            @classmethod
-            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-                raise ErrorB
-
-        a = A(torch.empty(1))
-        b = B(torch.empty(1))
-        with self.assertRaises(ErrorA):
-            a + a
-
-        # B has precedence over A due to the subclass relationship
-        with self.assertRaises(ErrorB):
-            with enable_python_mode(A):
-                b + b
-        with self.assertRaises(ErrorB):
-            with enable_python_mode(B):
-                a + a
-        with self.assertRaises(ErrorB):
-            with enable_python_mode(B):
-                a + b
-
-    def test_enable_python_mode_respects_no_dispatch(self) -> None:
-        with enable_python_mode(LoggingTensor):
-            z = torch.ones([2, 3])
-            self.assertTrue(isinstance(z, LoggingTensor))
-            with no_dispatch():
-                expected = torch.ones([2, 3])
-                self.assertEqual(z.elem, expected)
-
-    def test_nested_enable_python_mode(self) -> None:
-        with self.assertRaisesRegex(RuntimeError, "has already been set"):
-            with enable_python_mode(LoggingTensor):
-                with enable_python_mode(LoggingTensor):
-                    pass
 
 if __name__ == '__main__':
     run_tests()
index dd89981..34846b5 100644 (file)
@@ -666,7 +666,6 @@ libtorch_python_core_sources = [
     "torch/csrc/autograd/init.cpp",
     "torch/csrc/autograd/python_anomaly_mode.cpp",
     "torch/csrc/autograd/python_saved_variable_hooks.cpp",
-    "torch/csrc/autograd/python_mode.cpp",
     "torch/csrc/autograd/python_cpp_function.cpp",
     "torch/csrc/autograd/python_engine.cpp",
     "torch/csrc/autograd/python_function.cpp",
@@ -794,7 +793,6 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/ParallelNativeTBB.cpp",
     "aten/src/ATen/ParallelOpenMP.cpp",
     "aten/src/ATen/ParallelThreadPoolNative.cpp",
-    "aten/src/ATen/PythonModeTLS.cpp",
     "aten/src/ATen/ScalarOps.cpp",
     "aten/src/ATen/SequenceNumber.cpp",
     "aten/src/ATen/SparseTensorImpl.cpp",
index 352edbe..01fdf9e 100644 (file)
@@ -652,8 +652,6 @@ def __set_forward_AD_enabled(enabled: _bool) -> None: ...
 def __is_forward_AD_enabled() -> _bool: ...
 def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
 def _reset_default_hooks() -> None: ...
-def _enter_python_mode(cls: Type) -> None: ...
-def _exit_python_mode() -> None: ...
 
 class _InferenceMode(object):
     def __init__(self, mode: _bool) -> None: ...
index 860aaec..697ca87 100644 (file)
@@ -14,7 +14,6 @@
 #include <torch/csrc/autograd/python_saved_variable_hooks.h>
 #include <torch/csrc/autograd/utils/wrap_outputs.h>
 #include <torch/csrc/autograd/utils/python_arg_parsing.h>
-#include <torch/csrc/autograd/python_mode.h>
 #include <torch/csrc/utils/pycfunction_helpers.h>
 #include <c10/core/ScalarType.h>
 
@@ -495,20 +494,6 @@ static PyObject * python_exit_dual_level(PyObject* _unused, PyObject* args, PyOb
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject * enter_python_mode(PyObject* _unused, PyObject* arg) {
-  HANDLE_TH_ERRORS
-  PythonMode::enter(arg);
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
-static PyObject * exit_python_mode(PyObject* _unused, PyObject* arg) {
-  HANDLE_TH_ERRORS
-  PythonMode::exit();
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
 // autograd methods on torch._C
 static PyMethodDef methods[] = { // NOLINT
   {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
@@ -529,8 +514,6 @@ static PyMethodDef methods[] = { // NOLINT
   {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
   {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
   {"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr},
-  {"_enter_python_mode", enter_python_mode, METH_O, nullptr},
-  {"_exit_python_mode", exit_python_mode, METH_NOARGS, nullptr},
   {nullptr, nullptr, 0, nullptr}
 };
 
diff --git a/torch/csrc/autograd/python_mode.cpp b/torch/csrc/autograd/python_mode.cpp
deleted file mode 100644 (file)
index 4358426..0000000
+++ /dev/null
@@ -1,27 +0,0 @@
-#include <torch/csrc/autograd/python_mode.h>
-#include <torch/csrc/python_headers.h>
-#include <torch/csrc/autograd/python_variable.h>
-#include <ATen/PythonModeTLS.h>
-#include <c10/core/TensorImpl.h>
-
-namespace torch { namespace autograd {
-
-void PythonMode::enter(PyObject* type) {
-  if (at::impl::PythonModeTLS::get_state()) {
-    TORCH_CHECK(
-        false,
-        "python mode has already been set. We do not yet support nested python ",
-        "mode. Please file us an issue and reset it before setting it again.")
-  }
-  // TorchDispatchTypeObject steals a reference, See NOTE [What is TorchDispatchTypeObject?]
-  Py_INCREF(type);
-  auto state = std::make_shared<c10::TorchDispatchTypeObject>(type, getPyInterpreter());
-  at::impl::PythonModeTLS::set_state(state);
-}
-
-void PythonMode::exit() {
-  TORCH_INTERNAL_ASSERT(at::impl::PythonModeTLS::get_state(), "exiting Python Mode but it wasn't set!");
-  at::impl::PythonModeTLS::reset_state();
-}
-
-}}
diff --git a/torch/csrc/autograd/python_mode.h b/torch/csrc/autograd/python_mode.h
deleted file mode 100644 (file)
index 03da51c..0000000
+++ /dev/null
@@ -1,17 +0,0 @@
-#pragma once
-
-#include <torch/csrc/python_headers.h>
-#include <c10/core/TensorImpl.h>
-
-namespace torch { namespace autograd {
-
-struct TORCH_API PythonMode {
-  // Enter python mode, causing all operators to dispatch to the type's __torch_dispatch__.
-  // `type` is the type of a Tensor subclass that has __torch_dispatch__.
-  static void enter(PyObject* type);
-
-  // Exit the current python mode.
-  static void exit();
-};
-
-}}
index abe9010..50d6eb9 100644 (file)
@@ -32,7 +32,6 @@
 
 #include <torch/library.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
-#include <torch/csrc/autograd/python_mode.h>
 
 
 #include <ATen/ATen.h>
@@ -65,12 +64,7 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) {
     return;
 
   pybind11::gil_scoped_acquire gil;
-  // Two possibilities:
-  // 1. We are decref-ing a tensor. Then we must be careful about
-  // PyObject resurrection (this only applies to Tensors, see THPVariable_clear).
-  // 2. We are decref-ing some other Python object. We don't do
-  // PyObject resurrection on non-Tensors, so we just carry on as usual
-  if (THPVariable_Check(pyobj) && Py_REFCNT(pyobj) > 1) {
+  if (Py_REFCNT(pyobj) > 1) {
     // It's still alive!  This can happen if a weak ref resurrected
     // the PyObject without flipping ownership.  At this point it is
     // too late to rescue the object, so just stub out the PyObject
@@ -88,11 +82,7 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) {
 };
 
 c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
-void concrete_dispatch_fn(
-    const c10::impl::PyInterpreter*,
-    const c10::OperatorHandle& op,
-    torch::jit::Stack* stack,
-    const std::shared_ptr<TorchDispatchTypeObject>& type);
+void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack);
 
 class PyInterpreterHolder {
  public:
@@ -1501,19 +1491,7 @@ bool isPythonTensor(const Tensor& tensor) {
   return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
 }
 
-// NOTE [dispatch_fn's type argument]
-// `type` is nullable and represents the PythonMode going on.
-// Right now we only support a single PythonMode, but in the future we could
-// change this to a stack of PythonModes.
-//
-// If `type` isn't null, then we consider the type for dispatch by prepending
-// it to the overloaded_args list. `handle_torch_funciton_no_python_arg_parser`
-// is responsible for doing overload resolution.
-void concrete_dispatch_fn(
-    const c10::impl::PyInterpreter*,
-    const c10::OperatorHandle& op,
-    torch::jit::Stack* stack,
-    const std::shared_ptr<TorchDispatchTypeObject>& type) {
+void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   const auto& schema = op.schema();
   const auto num_returns = schema.returns().size();
 
@@ -1590,17 +1568,13 @@ void concrete_dispatch_fn(
   auto args = py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));
   py::dict kwargs;
 
-  if (type) {
-    append_overloaded_type(&overloaded_args, type->ptr());
-  }
-
   // Find overloaded tensors
   for (int64_t idx = 0; idx < arguments.size(); idx++) {
     const auto& ivalue = arguments[idx];
     if (ivalue.isTensor()) {
       const auto& tensor = ivalue.toTensor();
       if (isPythonTensor(tensor)) {
-        append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
+        append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
       }
     } else if (ivalue.isList()) {
       const auto& list = ivalue.toListRef();
@@ -1609,7 +1583,7 @@ void concrete_dispatch_fn(
         if (nv.isTensor()) {
           const auto& tensor = nv.toTensor();
           if (isPythonTensor(tensor)) {
-            append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
+            append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
           }
         }
       }
@@ -1659,7 +1633,7 @@ c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter
   Tensor self_t = Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
   auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
   TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
-  append_overloaded_tensor(&overloaded_args, self_p.ptr());
+  append_overloaded_arg(&overloaded_args, self_p.ptr());
   auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
   PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
 
index 3ee20c0..6115dcd 100644 (file)
@@ -200,28 +200,12 @@ auto handle_torch_function(PyObject* self, const std::string& func_name, PyObjec
   return ret.release().ptr();
 }
 
-// Note: [Overloaded args]
-// An overloaded arg may be one of the following:
-// - an instance of an object that has a __torch_function__ method
-// - an instance of an object that has a __torch_dispatch__ classmethod
-// - a class type that has a __torch_dispatch__ classmethod
-//
-// This function returns the type of the arg (if the arg is an instance),
-// otherwise, it returns the arg.
-static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
-  if (PyType_Check(obj_or_type)) {
-    return obj_or_type;
-  }
-  return (PyObject*)Py_TYPE(obj_or_type);
-}
-
-// See Note: [Overloaded args] for what they hold
 auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name, const char* torch_function_name) -> PyObject* {
   // overloaded_args already all have unique types
   std::vector<py::object> overloaded_types;
   overloaded_types.reserve(overloaded_args.size());
   for (auto &arg : overloaded_args) {
-    overloaded_types.push_back(py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg.ptr())));
+    overloaded_types.push_back(py::reinterpret_borrow<py::object>((PyObject *) Py_TYPE(arg.ptr())));
   }
   py::tuple py_types = py::cast(overloaded_types);
   py::object ret;
@@ -247,7 +231,7 @@ auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &o
     ss << "no implementation found for '" << module_name << "." << func_name
        << "' on types that implement " << torch_function_name << ": [";
     for (auto &arg : overloaded_args) {
-      ss << PyObject_Repr(get_type_of_overloaded_arg(arg.ptr()));
+      ss << arg.ptr()->ob_type->tp_name;
       if (!arg.is(overloaded_args.back())) {
         ss << ", ";
       }
@@ -344,11 +328,10 @@ auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* v
  *
  */
 
-static void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj, bool obj_is_type) {
+void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj) {
   bool class_not_seen_yet = true;
-  PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
   for (auto &arg : *overloaded_args) {
-    if (obj_type == get_type_of_overloaded_arg(arg.ptr())) {
+    if (Py_TYPE(obj) == Py_TYPE(arg.ptr())) {
       // obj is the same type as another parameter we've seen in a prior
       // iteration of the loop over parameters so we already have an entry
       // with the proper __torch_function__ implementation to call, so skip
@@ -360,7 +343,7 @@ static void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyOb
   if (class_not_seen_yet) {
     int arg_index = overloaded_args->size();
     for(const auto j : c10::irange(arg_index)) {
-      if (PyObject_IsSubclass(obj_type, (PyObject*)(get_type_of_overloaded_arg((*overloaded_args)[j].ptr())))) {
+      if (PyObject_IsInstance(obj, (PyObject*)(Py_TYPE((*overloaded_args)[j].ptr())))) {
         // obj is a subclass of another object we've seen already so its
         // __torch_function__ should be called first, therefore we
         // insert it into overloaded_args before the superclass
@@ -375,14 +358,6 @@ static void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyOb
   }
 }
 
-void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj) {
-  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/false);
-}
-
-void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj) {
-  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/true);
-}
-
 bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args) {
   if (THPVariable_CheckExact(obj)) {
     // torch.Tensor instances (not subclasses, except for Parameter)
@@ -391,7 +366,7 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* ove
 
   if (check_has_torch_function(obj)) {
     // tensor subclasses and unrelated objects with __torch_function__
-    append_overloaded_tensor(overloaded_args, obj);
+    append_overloaded_arg(overloaded_args, obj);
     return true;
   } else if (THPVariable_Check(obj)) {
     // tensor subclasses without __torch_function__
@@ -930,7 +905,7 @@ bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs,
 
   int i = 0;
   if (self != nullptr && check_has_torch_function(self)) {
-    append_overloaded_tensor(&this->overloaded_args, self);
+    append_overloaded_arg(&this->overloaded_args, self);
   }
   for (auto& param : params) {
     PyObject* obj = nullptr;
index 6a05807..d132185 100644 (file)
@@ -818,15 +818,6 @@ bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector<py::handle>
  * 'overloaded_args': the vector to append the overloaded args
  * 'obj': the input tensor that is overloaded
  */
-void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj);
-
-/* Given an argument that is definitely a type and is definitely overloaded,
- * append it to the overloaded arguments list. Use this only with __torch_dispatch__,
- * where we operate on classes that have a __torch_dispatch__ classmethod.
- *
- * 'overloaded_args': the vector to append the overloaded type
- * 'obj': the input class that has a __torch_dispatch__ classmethod.
- */
-void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj);
+void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj);
 
 } // namespace torch
index 25e9a59..17d7acc 100644 (file)
@@ -267,7 +267,6 @@ Tensor internal_new_from_data(
   {
     at::AutoDispatchBelowADInplaceOrView guard;  // TODO: remove
     at::tracer::impl::NoTracerDispatchMode tracer_guard;
-    c10::impl::ExcludeDispatchKeyGuard pythonmode_guard(c10::DispatchKey::Python);
     // functorch uses FuncTorchDynamicLayerBackMode as a mode key to wrap all
     // tensors returned from operators in special TensorWrapper tensor extension
     // The problem with this is that TensorWrapper does not have storage so
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
deleted file mode 100644 (file)
index a7cfae1..0000000
+++ /dev/null
@@ -1,34 +0,0 @@
-import torch
-import contextlib
-from typing import Iterator
-
-# Context manager that causes all pytorch operators to dispatch to the passed-in
-# type's __torch_dispatch__ function.
-# operation that accepts no tensors but returns a tensor.
-#
-# enable_python_mode is affected by torch._C._DisableTorchDispatch.
-#
-# NB: Calling an operator inside __torch_dispatch__ does go through
-# __torch_dispatch__ again. Please use _DisableTorchDispatch inside
-# __torch_dispatch__ to prevent infinite recursion.
-#
-# TODO: Limitations and things about enable_python_mode we should fix before exposing it:
-# - it currently cannot be nested. This should be simple to implement; we need a
-#   stack of TorchDispatchTypeObjects and the next bullet point.
-# - We need a better user-facing api for torch._C._DisableTorchDispatch that
-#   is able to selectively disable __torch_dispatch__ of a particular class.
-# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
-# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
-@contextlib.contextmanager
-def enable_python_mode(cls) -> Iterator[None]:
-    if not hasattr(cls, '__torch_dispatch__'):
-        raise ValueError('The class passed to enable_python_mode '
-                         'must have a __torch_dispatch__ classmethod')
-    if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
-        raise ValueError('The argument passed to enable_python_mode '
-                         'must be the type of a Tensor subclass')
-    torch._C._enter_python_mode(cls)
-    try:
-        yield
-    finally:
-        torch._C._exit_python_mode()