From 0457a85d459479881ad07e84a8e9f53bf82bb48d Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Tue, 31 Aug 2021 14:53:01 -0700 Subject: [PATCH] Revert D30543236: Add python mode Test Plan: revert-hammer Differential Revision: D30543236 (https://github.com/pytorch/pytorch/commit/4bd03b02424d93b72f15e28c542ede13f88ea929) Original commit changeset: ef5444d96a5a fbshipit-source-id: b0042ac2c22765fa11d6d00bf751f6a4489eb6d8 --- aten/src/ATen/PythonModeTLS.cpp | 26 --------- aten/src/ATen/PythonModeTLS.h | 17 ------ aten/src/ATen/ThreadLocalState.cpp | 3 -- aten/src/ATen/ThreadLocalState.h | 3 -- aten/src/ATen/core/PythonFallbackKernel.cpp | 13 +---- c10/core/TensorImpl.cpp | 20 +------ c10/core/TensorImpl.h | 35 ++----------- test/run_test.py | 1 - test/test_python_dispatch.py | 81 +---------------------------- tools/build_variables.bzl | 2 - torch/_C/__init__.pyi.in | 2 - torch/csrc/autograd/init.cpp | 17 ------ torch/csrc/autograd/python_mode.cpp | 27 ---------- torch/csrc/autograd/python_mode.h | 17 ------ torch/csrc/autograd/python_variable.cpp | 38 +++----------- torch/csrc/utils/python_arg_parser.cpp | 39 +++----------- torch/csrc/utils/python_arg_parser.h | 11 +--- torch/csrc/utils/tensor_new.cpp | 1 - torch/utils/_python_dispatch.py | 34 ------------ 19 files changed, 21 insertions(+), 366 deletions(-) delete mode 100644 aten/src/ATen/PythonModeTLS.cpp delete mode 100644 aten/src/ATen/PythonModeTLS.h delete mode 100644 torch/csrc/autograd/python_mode.cpp delete mode 100644 torch/csrc/autograd/python_mode.h delete mode 100644 torch/utils/_python_dispatch.py diff --git a/aten/src/ATen/PythonModeTLS.cpp b/aten/src/ATen/PythonModeTLS.cpp deleted file mode 100644 index b53043c..0000000 --- a/aten/src/ATen/PythonModeTLS.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include - -namespace at { namespace impl { - -thread_local std::shared_ptr pythonModeState; - -void PythonModeTLS::set_state(const std::shared_ptr& state) { - pythonModeState = state; - if (state) { - c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true); - } else { - PythonModeTLS::reset_state(); - } -} - -const std::shared_ptr& PythonModeTLS::get_state() { - return pythonModeState; -} - -void PythonModeTLS::reset_state() { - pythonModeState.reset((TorchDispatchTypeObject*)nullptr); - c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false); -} - -} // namespace impl -} // namespace at diff --git a/aten/src/ATen/PythonModeTLS.h b/aten/src/ATen/PythonModeTLS.h deleted file mode 100644 index be52b18..0000000 --- a/aten/src/ATen/PythonModeTLS.h +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace at { -namespace impl { - -struct TORCH_API PythonModeTLS { - static void set_state(const std::shared_ptr& state); - static const std::shared_ptr& get_state(); - static void reset_state(); -}; - -} // namespace impl -} // namespace at diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp index 19cfa89..98c2519 100644 --- a/aten/src/ATen/ThreadLocalState.cpp +++ b/aten/src/ATen/ThreadLocalState.cpp @@ -17,7 +17,6 @@ ThreadLocalState::ThreadLocalState() saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks(); bumped_record_all_functions_ = at::checkRecordAllFunctions(); - python_mode_state_ = at::impl::PythonModeTLS::get_state(); } void ThreadLocalState::set_grad_mode(bool enabled) { @@ -31,8 +30,6 @@ void ThreadLocalState::setThreadLocalState( // restore the dispatch key set TLS at the same time. c10::AutogradState::set_tls_state(state.autograd_tls_); - at::impl::PythonModeTLS::set_state(state.python_mode_state_); - at::set_record_function_tls_(state.rf_tls_); SavedTensorDefaultHooks::set_hooks( diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h index c99ca61..4114691 100644 --- a/aten/src/ATen/ThreadLocalState.h +++ b/aten/src/ATen/ThreadLocalState.h @@ -6,7 +6,6 @@ #include #include -#include namespace at { @@ -41,8 +40,6 @@ class TORCH_API ThreadLocalState { // TLS for AutogradModes AutogradState autograd_tls_; - std::shared_ptr python_mode_state_; - // TLS for saved tensors default hooks std::pair saved_tensors_default_hooks_; diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp index 8e77d09..276eabf 100644 --- a/aten/src/ATen/core/PythonFallbackKernel.cpp +++ b/aten/src/ATen/core/PythonFallbackKernel.cpp @@ -1,18 +1,9 @@ #include #include -#include namespace { void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { - // If Python Mode is active, use its PyInterpreter for dispatch - const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state(); - if (maybe_python_mode_state) { - maybe_python_mode_state->pyinterpreter()->dispatch(op, stack, maybe_python_mode_state); - return; - } - - // Otherwise, find a PyInterpreter on a Tensor const auto& schema = op.schema(); const auto num_arguments = schema.arguments().size(); // It is safe to dispatch on the very first Tensor with a pyobj_interpreter @@ -24,7 +15,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { if (ivalue.isTensor()) { auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter(); if (interpreter) { - interpreter->dispatch(op, stack, nullptr); + interpreter->dispatch(op, stack); return; } } else if (ivalue.isTensorList()) { @@ -33,7 +24,7 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { for (const auto& nv : ivalue.toListRef()) { auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter(); if (interpreter) { - interpreter->dispatch(op, stack, nullptr); + interpreter->dispatch(op, stack); return; } } diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 9a72659..de829c4 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -40,8 +40,7 @@ static c10::intrusive_ptr noop_detach_fn( static void noop_dispatch_fn( const PyInterpreter*, const c10::OperatorHandle& op, - torch::jit::Stack* stack, - const std::shared_ptr& type) { + torch::jit::Stack* stack) { TORCH_INTERNAL_ASSERT( 0, "attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died"); @@ -609,23 +608,6 @@ void TensorImpl::copy_tensor_metadata( } } -TorchDispatchTypeObject::TorchDispatchTypeObject( - PyObject* type_object, - c10::impl::PyInterpreter* pyinterpreter) - : data_(type_object), pyinterpreter_(pyinterpreter) {} - -TorchDispatchTypeObject::~TorchDispatchTypeObject() { - pyinterpreter_->decref(data_); -} - -c10::impl::PyInterpreter* TorchDispatchTypeObject::pyinterpreter() const { - return pyinterpreter_; -} - -PyObject* TorchDispatchTypeObject::ptr() const { - return data_; -} - namespace impl { namespace { diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index d110a17..7051e36 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -161,9 +161,6 @@ struct C10_API AutogradMetaInterface { virtual ~AutogradMetaInterface(); }; -// forward declared -struct TorchDispatchTypeObject; - namespace impl { // Unfortunately, the definition of AutogradMeta lives in a separate @@ -258,8 +255,7 @@ struct C10_API PyInterpreter { using dispatch_sig = void( const PyInterpreter*, const c10::OperatorHandle&, - torch::jit::Stack* stack, - const std::shared_ptr& type); + torch::jit::Stack* stack); PyInterpreter( name_sig* name_fn, @@ -303,9 +299,8 @@ struct C10_API PyInterpreter { // Invoke the Python boxed fallback dispatch to go back into Python __ubsan_ignore_function__ void dispatch( const c10::OperatorHandle& op, - torch::jit::Stack* stack, - const std::shared_ptr& type) const { - return (*dispatch_fn_)(this, op, stack, type); + torch::jit::Stack* stack) const { + return (*dispatch_fn_)(this, op, stack); } // Disarm this PyInterpreter, making all of its methods noops. @@ -353,30 +348,6 @@ struct C10_API NamedTensorMetaInterface { }; }; -// NOTE [What is TorchDispatchTypeObject?] -// A TorchDispatchTypeObject represents the type of a Tensor subclass that has -// a __torch_dispatch__ classmethod. Concretely, it holds the class as a -// PyObject* and a PyInterpreter* that says which python interpreter the class -// came from. -// -// See NOTE [dispatch_fn's type argument] for more details -struct C10_API TorchDispatchTypeObject { - // Steals a reference to type_object - TorchDispatchTypeObject( - PyObject* type_object, - c10::impl::PyInterpreter* pyinterpreter); - - // Releases the stolen reference to type_object - ~TorchDispatchTypeObject(); - - c10::impl::PyInterpreter* pyinterpreter() const; - PyObject* ptr() const; - - private: - PyObject* data_; - c10::impl::PyInterpreter* pyinterpreter_; -}; - // NOTE [ Version Counter Sharing ] // // Every Tensor has a version counter. Version counters are incremented whenever diff --git a/test/run_test.py b/test/run_test.py index d0871fa..55b2f38 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -104,7 +104,6 @@ TESTS = [ "test_optim", "test_functional_optim", "test_pytree", - "test_python_dispatch", "test_mobile_optimizer", "test_set_default_mobile_cpu_allocator", "test_xnnpack_integration", diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index e474f1f..0f5b6b9 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -1,7 +1,6 @@ import torch from torch.testing._internal.common_utils import TestCase, run_tests from torch.utils._pytree import tree_map -from torch.utils._python_dispatch import enable_python_mode from typing import Iterator, List import logging @@ -51,10 +50,7 @@ class LoggingTensor(torch.Tensor): def wrap(e): return LoggingTensor(e) if isinstance(e, torch.Tensor) else e - # no_dispatch is only needed if you use enable_python_mode. - # It prevents infinite recursion. - with no_dispatch(): - rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) + rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs) return rs @@ -339,81 +335,6 @@ $4 = torch._ops.aten.mul($3, tensor(2)) $5 = torch._ops.aten.mul($4, $0) $6 = torch._ops.aten.add_($1, $5)''') - def test_enable_python_mode_error(self) -> None: - with self.assertRaisesRegex(ValueError, "__torch_dispatch__"): - with enable_python_mode(torch.Tensor): - pass - z = LoggingTensor(torch.empty([])) - with self.assertRaisesRegex(ValueError, "must be the type"): - with enable_python_mode(z): - pass - - def test_enable_python_mode_basic(self) -> None: - with enable_python_mode(LoggingTensor): - z = torch.empty([]) - self.assertTrue(isinstance(z, LoggingTensor)) - - def test_enable_python_mode_unrelated_tensors(self) -> None: - x = torch.randn([]) - y = torch.randn([]) - with enable_python_mode(LoggingTensor): - z = x + y - self.assertTrue(isinstance(z, LoggingTensor)) - - def test_enable_python_mode_subclass_priority(self) -> None: - class ErrorA(RuntimeError): - pass - - class ErrorB(RuntimeError): - pass - - class A(torch.Tensor): - @staticmethod - def __new__(cls, elem): - return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - raise ErrorA - - class B(A): - @staticmethod - def __new__(cls, elem): - return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - raise ErrorB - - a = A(torch.empty(1)) - b = B(torch.empty(1)) - with self.assertRaises(ErrorA): - a + a - - # B has precedence over A due to the subclass relationship - with self.assertRaises(ErrorB): - with enable_python_mode(A): - b + b - with self.assertRaises(ErrorB): - with enable_python_mode(B): - a + a - with self.assertRaises(ErrorB): - with enable_python_mode(B): - a + b - - def test_enable_python_mode_respects_no_dispatch(self) -> None: - with enable_python_mode(LoggingTensor): - z = torch.ones([2, 3]) - self.assertTrue(isinstance(z, LoggingTensor)) - with no_dispatch(): - expected = torch.ones([2, 3]) - self.assertEqual(z.elem, expected) - - def test_nested_enable_python_mode(self) -> None: - with self.assertRaisesRegex(RuntimeError, "has already been set"): - with enable_python_mode(LoggingTensor): - with enable_python_mode(LoggingTensor): - pass if __name__ == '__main__': run_tests() diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index dd89981..34846b5 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -666,7 +666,6 @@ libtorch_python_core_sources = [ "torch/csrc/autograd/init.cpp", "torch/csrc/autograd/python_anomaly_mode.cpp", "torch/csrc/autograd/python_saved_variable_hooks.cpp", - "torch/csrc/autograd/python_mode.cpp", "torch/csrc/autograd/python_cpp_function.cpp", "torch/csrc/autograd/python_engine.cpp", "torch/csrc/autograd/python_function.cpp", @@ -794,7 +793,6 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/ParallelNativeTBB.cpp", "aten/src/ATen/ParallelOpenMP.cpp", "aten/src/ATen/ParallelThreadPoolNative.cpp", - "aten/src/ATen/PythonModeTLS.cpp", "aten/src/ATen/ScalarOps.cpp", "aten/src/ATen/SequenceNumber.cpp", "aten/src/ATen/SparseTensorImpl.cpp", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 352edbe..01fdf9e 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -652,8 +652,6 @@ def __set_forward_AD_enabled(enabled: _bool) -> None: ... def __is_forward_AD_enabled() -> _bool: ... def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ... def _reset_default_hooks() -> None: ... -def _enter_python_mode(cls: Type) -> None: ... -def _exit_python_mode() -> None: ... class _InferenceMode(object): def __init__(self, mode: _bool) -> None: ... diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 860aaec..697ca87 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -14,7 +14,6 @@ #include #include #include -#include #include #include @@ -495,20 +494,6 @@ static PyObject * python_exit_dual_level(PyObject* _unused, PyObject* args, PyOb END_HANDLE_TH_ERRORS } -static PyObject * enter_python_mode(PyObject* _unused, PyObject* arg) { - HANDLE_TH_ERRORS - PythonMode::enter(arg); - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - -static PyObject * exit_python_mode(PyObject* _unused, PyObject* arg) { - HANDLE_TH_ERRORS - PythonMode::exit(); - Py_RETURN_NONE; - END_HANDLE_TH_ERRORS -} - // autograd methods on torch._C static PyMethodDef methods[] = { // NOLINT {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr}, @@ -529,8 +514,6 @@ static PyMethodDef methods[] = { // NOLINT {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr}, {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr}, {"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr}, - {"_enter_python_mode", enter_python_mode, METH_O, nullptr}, - {"_exit_python_mode", exit_python_mode, METH_NOARGS, nullptr}, {nullptr, nullptr, 0, nullptr} }; diff --git a/torch/csrc/autograd/python_mode.cpp b/torch/csrc/autograd/python_mode.cpp deleted file mode 100644 index 4358426..0000000 --- a/torch/csrc/autograd/python_mode.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include -#include -#include -#include -#include - -namespace torch { namespace autograd { - -void PythonMode::enter(PyObject* type) { - if (at::impl::PythonModeTLS::get_state()) { - TORCH_CHECK( - false, - "python mode has already been set. We do not yet support nested python ", - "mode. Please file us an issue and reset it before setting it again.") - } - // TorchDispatchTypeObject steals a reference, See NOTE [What is TorchDispatchTypeObject?] - Py_INCREF(type); - auto state = std::make_shared(type, getPyInterpreter()); - at::impl::PythonModeTLS::set_state(state); -} - -void PythonMode::exit() { - TORCH_INTERNAL_ASSERT(at::impl::PythonModeTLS::get_state(), "exiting Python Mode but it wasn't set!"); - at::impl::PythonModeTLS::reset_state(); -} - -}} diff --git a/torch/csrc/autograd/python_mode.h b/torch/csrc/autograd/python_mode.h deleted file mode 100644 index 03da51c..0000000 --- a/torch/csrc/autograd/python_mode.h +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once - -#include -#include - -namespace torch { namespace autograd { - -struct TORCH_API PythonMode { - // Enter python mode, causing all operators to dispatch to the type's __torch_dispatch__. - // `type` is the type of a Tensor subclass that has __torch_dispatch__. - static void enter(PyObject* type); - - // Exit the current python mode. - static void exit(); -}; - -}} diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index abe9010..50d6eb9 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -32,7 +32,6 @@ #include #include -#include #include @@ -65,12 +64,7 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) { return; pybind11::gil_scoped_acquire gil; - // Two possibilities: - // 1. We are decref-ing a tensor. Then we must be careful about - // PyObject resurrection (this only applies to Tensors, see THPVariable_clear). - // 2. We are decref-ing some other Python object. We don't do - // PyObject resurrection on non-Tensors, so we just carry on as usual - if (THPVariable_Check(pyobj) && Py_REFCNT(pyobj) > 1) { + if (Py_REFCNT(pyobj) > 1) { // It's still alive! This can happen if a weak ref resurrected // the PyObject without flipping ownership. At this point it is // too late to rescue the object, so just stub out the PyObject @@ -88,11 +82,7 @@ void concrete_decref_fn(const c10::impl::PyInterpreter* self, PyObject* pyobj) { }; c10::intrusive_ptr concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self); -void concrete_dispatch_fn( - const c10::impl::PyInterpreter*, - const c10::OperatorHandle& op, - torch::jit::Stack* stack, - const std::shared_ptr& type); +void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack); class PyInterpreterHolder { public: @@ -1501,19 +1491,7 @@ bool isPythonTensor(const Tensor& tensor) { return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python); } -// NOTE [dispatch_fn's type argument] -// `type` is nullable and represents the PythonMode going on. -// Right now we only support a single PythonMode, but in the future we could -// change this to a stack of PythonModes. -// -// If `type` isn't null, then we consider the type for dispatch by prepending -// it to the overloaded_args list. `handle_torch_funciton_no_python_arg_parser` -// is responsible for doing overload resolution. -void concrete_dispatch_fn( - const c10::impl::PyInterpreter*, - const c10::OperatorHandle& op, - torch::jit::Stack* stack, - const std::shared_ptr& type) { +void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack) { const auto& schema = op.schema(); const auto num_returns = schema.returns().size(); @@ -1590,17 +1568,13 @@ void concrete_dispatch_fn( auto args = py::reinterpret_steal(PyTuple_New(positional_default_start)); py::dict kwargs; - if (type) { - append_overloaded_type(&overloaded_args, type->ptr()); - } - // Find overloaded tensors for (int64_t idx = 0; idx < arguments.size(); idx++) { const auto& ivalue = arguments[idx]; if (ivalue.isTensor()) { const auto& tensor = ivalue.toTensor(); if (isPythonTensor(tensor)) { - append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr()); + append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr()); } } else if (ivalue.isList()) { const auto& list = ivalue.toListRef(); @@ -1609,7 +1583,7 @@ void concrete_dispatch_fn( if (nv.isTensor()) { const auto& tensor = nv.toTensor(); if (isPythonTensor(tensor)) { - append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr()); + append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr()); } } } @@ -1659,7 +1633,7 @@ c10::intrusive_ptr concrete_detach_fn(const c10::impl::PyInterpreter Tensor self_t = Tensor(c10::intrusive_ptr::unsafe_reclaim_from_nonowning(const_cast(self))); auto self_p = py::reinterpret_steal(THPVariable_Wrap(self_t)); TORCH_INTERNAL_ASSERT(isPythonTensor(self_t)); - append_overloaded_tensor(&overloaded_args, self_p.ptr()); + append_overloaded_arg(&overloaded_args, self_p.ptr()); auto args = py::reinterpret_steal(PyTuple_New(1)); PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr()); diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 3ee20c0..6115dcd 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -200,28 +200,12 @@ auto handle_torch_function(PyObject* self, const std::string& func_name, PyObjec return ret.release().ptr(); } -// Note: [Overloaded args] -// An overloaded arg may be one of the following: -// - an instance of an object that has a __torch_function__ method -// - an instance of an object that has a __torch_dispatch__ classmethod -// - a class type that has a __torch_dispatch__ classmethod -// -// This function returns the type of the arg (if the arg is an instance), -// otherwise, it returns the arg. -static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) { - if (PyType_Check(obj_or_type)) { - return obj_or_type; - } - return (PyObject*)Py_TYPE(obj_or_type); -} - -// See Note: [Overloaded args] for what they hold auto handle_torch_function_no_python_arg_parser(const std::vector &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name, const char* torch_function_name) -> PyObject* { // overloaded_args already all have unique types std::vector overloaded_types; overloaded_types.reserve(overloaded_args.size()); for (auto &arg : overloaded_args) { - overloaded_types.push_back(py::reinterpret_borrow(get_type_of_overloaded_arg(arg.ptr()))); + overloaded_types.push_back(py::reinterpret_borrow((PyObject *) Py_TYPE(arg.ptr()))); } py::tuple py_types = py::cast(overloaded_types); py::object ret; @@ -247,7 +231,7 @@ auto handle_torch_function_no_python_arg_parser(const std::vector &o ss << "no implementation found for '" << module_name << "." << func_name << "' on types that implement " << torch_function_name << ": ["; for (auto &arg : overloaded_args) { - ss << PyObject_Repr(get_type_of_overloaded_arg(arg.ptr())); + ss << arg.ptr()->ob_type->tp_name; if (!arg.is(overloaded_args.back())) { ss << ", "; } @@ -344,11 +328,10 @@ auto handle_torch_function_indexing(PyObject* self, PyObject* index, PyObject* v * */ -static void append_overloaded_arg(std::vector* overloaded_args, PyObject* obj, bool obj_is_type) { +void append_overloaded_arg(std::vector* overloaded_args, PyObject* obj) { bool class_not_seen_yet = true; - PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj); for (auto &arg : *overloaded_args) { - if (obj_type == get_type_of_overloaded_arg(arg.ptr())) { + if (Py_TYPE(obj) == Py_TYPE(arg.ptr())) { // obj is the same type as another parameter we've seen in a prior // iteration of the loop over parameters so we already have an entry // with the proper __torch_function__ implementation to call, so skip @@ -360,7 +343,7 @@ static void append_overloaded_arg(std::vector* overloaded_args, PyOb if (class_not_seen_yet) { int arg_index = overloaded_args->size(); for(const auto j : c10::irange(arg_index)) { - if (PyObject_IsSubclass(obj_type, (PyObject*)(get_type_of_overloaded_arg((*overloaded_args)[j].ptr())))) { + if (PyObject_IsInstance(obj, (PyObject*)(Py_TYPE((*overloaded_args)[j].ptr())))) { // obj is a subclass of another object we've seen already so its // __torch_function__ should be called first, therefore we // insert it into overloaded_args before the superclass @@ -375,14 +358,6 @@ static void append_overloaded_arg(std::vector* overloaded_args, PyOb } } -void append_overloaded_tensor(std::vector* overloaded_args, PyObject* obj) { - append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/false); -} - -void append_overloaded_type(std::vector* overloaded_args, PyObject* obj) { - append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/true); -} - bool is_tensor_and_append_overloaded(PyObject* obj, std::vector* overloaded_args) { if (THPVariable_CheckExact(obj)) { // torch.Tensor instances (not subclasses, except for Parameter) @@ -391,7 +366,7 @@ bool is_tensor_and_append_overloaded(PyObject* obj, std::vector* ove if (check_has_torch_function(obj)) { // tensor subclasses and unrelated objects with __torch_function__ - append_overloaded_tensor(overloaded_args, obj); + append_overloaded_arg(overloaded_args, obj); return true; } else if (THPVariable_Check(obj)) { // tensor subclasses without __torch_function__ @@ -930,7 +905,7 @@ bool FunctionSignature::parse(PyObject* self, PyObject* args, PyObject* kwargs, int i = 0; if (self != nullptr && check_has_torch_function(self)) { - append_overloaded_tensor(&this->overloaded_args, self); + append_overloaded_arg(&this->overloaded_args, self); } for (auto& param : params) { PyObject* obj = nullptr; diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 6a05807..d132185 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -818,15 +818,6 @@ bool is_tensor_list_and_append_overloaded(PyObject* obj, std::vector * 'overloaded_args': the vector to append the overloaded args * 'obj': the input tensor that is overloaded */ -void append_overloaded_tensor(std::vector* overloaded_args, PyObject* obj); - -/* Given an argument that is definitely a type and is definitely overloaded, - * append it to the overloaded arguments list. Use this only with __torch_dispatch__, - * where we operate on classes that have a __torch_dispatch__ classmethod. - * - * 'overloaded_args': the vector to append the overloaded type - * 'obj': the input class that has a __torch_dispatch__ classmethod. - */ -void append_overloaded_type(std::vector* overloaded_args, PyObject* obj); +void append_overloaded_arg(std::vector* overloaded_args, PyObject* obj); } // namespace torch diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 25e9a59..17d7acc 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -267,7 +267,6 @@ Tensor internal_new_from_data( { at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove at::tracer::impl::NoTracerDispatchMode tracer_guard; - c10::impl::ExcludeDispatchKeyGuard pythonmode_guard(c10::DispatchKey::Python); // functorch uses FuncTorchDynamicLayerBackMode as a mode key to wrap all // tensors returned from operators in special TensorWrapper tensor extension // The problem with this is that TensorWrapper does not have storage so diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py deleted file mode 100644 index a7cfae1..0000000 --- a/torch/utils/_python_dispatch.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -import contextlib -from typing import Iterator - -# Context manager that causes all pytorch operators to dispatch to the passed-in -# type's __torch_dispatch__ function. -# operation that accepts no tensors but returns a tensor. -# -# enable_python_mode is affected by torch._C._DisableTorchDispatch. -# -# NB: Calling an operator inside __torch_dispatch__ does go through -# __torch_dispatch__ again. Please use _DisableTorchDispatch inside -# __torch_dispatch__ to prevent infinite recursion. -# -# TODO: Limitations and things about enable_python_mode we should fix before exposing it: -# - it currently cannot be nested. This should be simple to implement; we need a -# stack of TorchDispatchTypeObjects and the next bullet point. -# - We need a better user-facing api for torch._C._DisableTorchDispatch that -# is able to selectively disable __torch_dispatch__ of a particular class. -# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor) -# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694) -@contextlib.contextmanager -def enable_python_mode(cls) -> Iterator[None]: - if not hasattr(cls, '__torch_dispatch__'): - raise ValueError('The class passed to enable_python_mode ' - 'must have a __torch_dispatch__ classmethod') - if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)): - raise ValueError('The argument passed to enable_python_mode ' - 'must be the type of a Tensor subclass') - torch._C._enter_python_mode(cls) - try: - yield - finally: - torch._C._exit_python_mode() -- 2.7.4