+++ /dev/null
-#include <ATen/PythonModeTLS.h>
-
-namespace at { namespace impl {
-
-thread_local std::shared_ptr<TorchDispatchTypeObject> pythonModeState;
-
-void PythonModeTLS::set_state(const std::shared_ptr<TorchDispatchTypeObject>& state) {
- pythonModeState = state;
- if (state) {
- c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
- } else {
- PythonModeTLS::reset_state();
- }
-}
-
-const std::shared_ptr<TorchDispatchTypeObject>& PythonModeTLS::get_state() {
- return pythonModeState;
-}
-
-void PythonModeTLS::reset_state() {
- pythonModeState.reset((TorchDispatchTypeObject*)nullptr);
- c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
-}
-
-} // namespace impl
-} // namespace at
+++ /dev/null
-#pragma once
-
-#include <c10/macros/Macros.h>
-#include <torch/library.h>
-#include <ATen/core/dispatch/Dispatcher.h>
-
-namespace at {
-namespace impl {
-
-struct TORCH_API PythonModeTLS {
- static void set_state(const std::shared_ptr<TorchDispatchTypeObject>& state);
- static const std::shared_ptr<TorchDispatchTypeObject>& get_state();
- static void reset_state();
-};
-
-} // namespace impl
-} // namespace at
saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
bumped_record_all_functions_ = at::checkRecordAllFunctions();
- python_mode_state_ = at::impl::PythonModeTLS::get_state();
}
void ThreadLocalState::set_grad_mode(bool enabled) {
// restore the dispatch key set TLS at the same time.
c10::AutogradState::set_tls_state(state.autograd_tls_);
- at::impl::PythonModeTLS::set_state(state.python_mode_state_);
-
at::set_record_function_tls_(state.rf_tls_);
SavedTensorDefaultHooks::set_hooks(
#include <c10/util/ThreadLocalDebugInfo.h>
#include <ATen/record_function.h>
-#include <ATen/PythonModeTLS.h>
namespace at {
// TLS for AutogradModes
AutogradState autograd_tls_;
- std::shared_ptr<TorchDispatchTypeObject> python_mode_state_;
-
// TLS for saved tensors default hooks
std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
-#include <ATen/PythonModeTLS.h>
namespace {
void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
- // If Python Mode is active, use its PyInterpreter for dispatch
- const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
- if (maybe_python_mode_state) {
- maybe_python_mode_state->pyinterpreter()->dispatch(op, stack, maybe_python_mode_state);
- return;
- }
-
- // Otherwise, find a PyInterpreter on a Tensor
const auto& schema = op.schema();
const auto num_arguments = schema.arguments().size();
// It is safe to dispatch on the very first Tensor with a pyobj_interpreter
if (ivalue.isTensor()) {
auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter();
if (interpreter) {
- interpreter->dispatch(op, stack, nullptr);
+ interpreter->dispatch(op, stack);
return;
}
} else if (ivalue.isTensorList()) {
for (const auto& nv : ivalue.toListRef()) {
auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter();
if (interpreter) {
- interpreter->dispatch(op, stack, nullptr);
+ interpreter->dispatch(op, stack);
return;
}
}
static void noop_dispatch_fn(
const PyInterpreter*,
const c10::OperatorHandle& op,
- torch::jit::Stack* stack,
- const std::shared_ptr<TorchDispatchTypeObject>& type) {
+ torch::jit::Stack* stack) {
TORCH_INTERNAL_ASSERT(
0,
"attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died");
}
}
-TorchDispatchTypeObject::TorchDispatchTypeObject(
- PyObject* type_object,
- c10::impl::PyInterpreter* pyinterpreter)
- : data_(type_object), pyinterpreter_(pyinterpreter) {}
-
-TorchDispatchTypeObject::~TorchDispatchTypeObject() {
- pyinterpreter_->decref(data_);
-}
-
-c10::impl::PyInterpreter* TorchDispatchTypeObject::pyinterpreter() const {
- return pyinterpreter_;
-}
-
-PyObject* TorchDispatchTypeObject::ptr() const {
- return data_;
-}
-
namespace impl {
namespace {
virtual ~AutogradMetaInterface();
};
-// forward declared
-struct TorchDispatchTypeObject;
-
namespace impl {
// Unfortunately, the definition of AutogradMeta lives in a separate
using dispatch_sig = void(
const PyInterpreter*,
const c10::OperatorHandle&,
- torch::jit::Stack* stack,
- const std::shared_ptr<TorchDispatchTypeObject>& type);
+ torch::jit::Stack* stack);
PyInterpreter(
name_sig* name_fn,
// Invoke the Python boxed fallback dispatch to go back into Python
__ubsan_ignore_function__ void dispatch(
const c10::OperatorHandle& op,
- torch::jit::Stack* stack,
- const std::shared_ptr<TorchDispatchTypeObject>& type) const {
- return (*dispatch_fn_)(this, op, stack, type);
+ torch::jit::Stack* stack) const {
+ return (*dispatch_fn_)(this, op, stack);
}
// Disarm this PyInterpreter, making all of its methods noops.
};
};
-// NOTE [What is TorchDispatchTypeObject?]
-// A TorchDispatchTypeObject represents the type of a Tensor subclass that has
-// a __torch_dispatch__ classmethod. Concretely, it holds the class as a
-// PyObject* and a PyInterpreter* that says which python interpreter the class
-// came from.
-//
-// See NOTE [dispatch_fn's type argument] for more details
-struct C10_API TorchDispatchTypeObject {
- // Steals a reference to type_object
- TorchDispatchTypeObject(
- PyObject* type_object,
- c10::impl::PyInterpreter* pyinterpreter);
-
- // Releases the stolen reference to type_object
- ~TorchDispatchTypeObject();
-
- c10::impl::PyInterpreter* pyinterpreter() const;
- PyObject* ptr() const;
-
- private:
- PyObject* data_;
- c10::impl::PyInterpreter* pyinterpreter_;
-};
-
// NOTE [ Version Counter Sharing ]
//
// Every Tensor has a version counter. Version counters are incremented whenever
"test_optim",
"test_functional_optim",
"test_pytree",
- "test_python_dispatch",
"test_mobile_optimizer",
"test_set_default_mobile_cpu_allocator",
"test_xnnpack_integration",
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.utils._pytree import tree_map
-from torch.utils._python_dispatch import enable_python_mode
from typing import Iterator, List
import logging
def wrap(e):
return LoggingTensor(e) if isinstance(e, torch.Tensor) else e
- # no_dispatch is only needed if you use enable_python_mode.
- # It prevents infinite recursion.
- with no_dispatch():
- rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+ rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
return rs
$5 = torch._ops.aten.mul($4, $0)
$6 = torch._ops.aten.add_($1, $5)''')
- def test_enable_python_mode_error(self) -> None:
- with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
- with enable_python_mode(torch.Tensor):
- pass
- z = LoggingTensor(torch.empty([]))
- with self.assertRaisesRegex(ValueError, "must be the type"):
- with enable_python_mode(z):
- pass
-
- def test_enable_python_mode_basic(self) -> None:
- with enable_python_mode(LoggingTensor):
- z = torch.empty([])
- self.assertTrue(isinstance(z, LoggingTensor))
-
- def test_enable_python_mode_unrelated_tensors(self) -> None:
- x = torch.randn([])
- y = torch.randn([])
- with enable_python_mode(LoggingTensor):
- z = x + y
- self.assertTrue(isinstance(z, LoggingTensor))
-
- def test_enable_python_mode_subclass_priority(self) -> None:
- class ErrorA(RuntimeError):
- pass
-
- class ErrorB(RuntimeError):
- pass
-
- class A(torch.Tensor):
- @staticmethod
- def __new__(cls, elem):
- return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
-
- @classmethod
- def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- raise ErrorA
-
- class B(A):
- @staticmethod
- def __new__(cls, elem):
- return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
-
- @classmethod
- def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- raise ErrorB
-
- a = A(torch.empty(1))
- b = B(torch.empty(1))
- with self.assertRaises(ErrorA):
- a + a
-
- # B has precedence over A due to the subclass relationship
- with self.assertRaises(ErrorB):
- with enable_python_mode(A):
- b + b
- with self.assertRaises(ErrorB):
- with enable_python_mode(B):
- a + a
- with self.assertRaises(ErrorB):
- with enable_python_mode(B):
- a + b
-
- def test_enable_python_mode_respects_no_dispatch(self) -> None:
- with enable_python_mode(LoggingTensor):
- z = torch.ones([2, 3])
- self.assertTrue(isinstance(z, LoggingTensor))
- with no_dispatch():
- expected = torch.ones([2, 3])
- self.assertEqual(z.elem, expected)
-
- def test_nested_enable_python_mode(self) -> None:
- with self.assertRaisesRegex(RuntimeError, "has already been set"):
- with enable_python_mode(LoggingTensor):
- with enable_python_mode(LoggingTensor):
- pass
if __name__ == '__main__':
run_tests()
"torch/csrc/autograd/init.cpp",
"torch/csrc/autograd/python_anomaly_mode.cpp",
"torch/csrc/autograd/python_saved_variable_hooks.cpp",
- "torch/csrc/autograd/python_mode.cpp",
"torch/csrc/autograd/python_cpp_function.cpp",
"torch/csrc/autograd/python_engine.cpp",
"torch/csrc/autograd/python_function.cpp",
"aten/src/ATen/ParallelNativeTBB.cpp",
"aten/src/ATen/ParallelOpenMP.cpp",
"aten/src/ATen/ParallelThreadPoolNative.cpp",
- "aten/src/ATen/PythonModeTLS.cpp",
"aten/src/ATen/ScalarOps.cpp",
"aten/src/ATen/SequenceNumber.cpp",
"aten/src/ATen/SparseTensorImpl.cpp",
def __is_forward_AD_enabled() -> _bool: ...
def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _reset_default_hooks() -> None: ...
-def _enter_python_mode(cls: Type) -> None: ...
-def _exit_python_mode() -> None: ...
class _InferenceMode(object):
def __init__(self, mode: _bool) -> None: ...
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
-#include <torch/csrc/autograd/python_mode.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <c10/core/ScalarType.h>
END_HANDLE_TH_ERRORS
}
-static PyObject * enter_python_mode(PyObject* _unused, PyObject* arg) {
- HANDLE_TH_ERRORS
- PythonMode::enter(arg);
- Py_RETURN_NONE;
- END_HANDLE_TH_ERRORS
-}
-
-static PyObject * exit_python_mode(PyObject* _unused, PyObject* arg) {
- HANDLE_TH_ERRORS
- PythonMode::exit();
- Py_RETURN_NONE;
- END_HANDLE_TH_ERRORS
-}
-
// autograd methods on torch._C
static PyMethodDef methods[] = { // NOLINT
{"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
{"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
{"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
{"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr},
- {"_enter_python_mode", enter_python_mode, METH_O, nullptr},
- {"_exit_python_mode", exit_python_mode, METH_NOARGS, nullptr},
{nullptr, nullptr, 0, nullptr}
};
+++ /dev/null
-#include <torch/csrc/autograd/python_mode.h>
-#include <torch/csrc/python_headers.h>
-#include <torch/csrc/autograd/python_variable.h>
-#include <ATen/PythonModeTLS.h>
-#include <c10/core/TensorImpl.h>
-
-namespace torch { namespace autograd {
-
-void PythonMode::enter(PyObject* type) {
- if (at::impl::PythonModeTLS::get_state()) {
- TORCH_CHECK(
- false,
- "python mode has already been set. We do not yet support nested python ",
- "mode. Please file us an issue and reset it before setting it again.")
- }
- // TorchDispatchTypeObject steals a reference, See NOTE [What is TorchDispatchTypeObject?]
- Py_INCREF(type);
- auto state = std::make_shared<c10::TorchDispatchTypeObject>(type, getPyInterpreter());
- at::impl::PythonModeTLS::set_state(state);
-}
-
-void PythonMode::exit() {
- TORCH_INTERNAL_ASSERT(at::impl::PythonModeTLS::get_state(), "exiting Python Mode but it wasn't set!");
- at::impl::PythonModeTLS::reset_state();
-}
-
-}}
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/python_headers.h>
-#include <c10/core/TensorImpl.h>
-
-namespace torch { namespace autograd {
-
-struct TORCH_API PythonMode {
- // Enter python mode, causing all operators to dispatch to the type's __torch_dispatch__.
- // `type` is the type of a Tensor subclass that has __torch_dispatch__.
- static void enter(PyObject* type);
-
- // Exit the current python mode.
- static void exit();
-};
-
-}}
#include <torch/library.h>
#include <torch/csrc/jit/python/pybind_utils.h>
-#include <torch/csrc/autograd/python_mode.h>
#include <ATen/ATen.h>
return;
pybind11::gil_scoped_acquire gil;
- // Two possibilities:
- // 1. We are decref-ing a tensor. Then we must be careful about
- // PyObject resurrection (this only applies to Tensors, see THPVariable_clear).
- // 2. We are decref-ing some other Python object. We don't do
- // PyObject resurrection on non-Tensors, so we just carry on as usual
- if (THPVariable_Check(pyobj) && Py_REFCNT(pyobj) > 1) {
+ if (Py_REFCNT(pyobj) > 1) {
// It's still alive! This can happen if a weak ref resurrected
// the PyObject without flipping ownership. At this point it is
// too late to rescue the object, so just stub out the PyObject
};
c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
-void concrete_dispatch_fn(
- const c10::impl::PyInterpreter*,
- const c10::OperatorHandle& op,
- torch::jit::Stack* stack,
- const std::shared_ptr<TorchDispatchTypeObject>& type);
+void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack);
class PyInterpreterHolder {
public:
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
}
-// NOTE [dispatch_fn's type argument]
-// `type` is nullable and represents the PythonMode going on.
-// Right now we only support a single PythonMode, but in the future we could
-// change this to a stack of PythonModes.
-//
-// If `type` isn't null, then we consider the type for dispatch by prepending
-// it to the overloaded_args list. `handle_torch_funciton_no_python_arg_parser`
-// is responsible for doing overload resolution.
-void concrete_dispatch_fn(
- const c10::impl::PyInterpreter*,
- const c10::OperatorHandle& op,
- torch::jit::Stack* stack,
- const std::shared_ptr<TorchDispatchTypeObject>& type) {
+void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
auto args = py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));
py::dict kwargs;
- if (type) {
- append_overloaded_type(&overloaded_args, type->ptr());
- }
-
// Find overloaded tensors
for (int64_t idx = 0; idx < arguments.size(); idx++) {
const auto& ivalue = arguments[idx];
if (ivalue.isTensor()) {
const auto& tensor = ivalue.toTensor();
if (isPythonTensor(tensor)) {
- append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
+ append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
}
} else if (ivalue.isList()) {
const auto& list = ivalue.toListRef();
if (nv.isTensor()) {
const auto& tensor = nv.toTensor();
if (isPythonTensor(tensor)) {
- append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
+ append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
}
}
}
Tensor self_t = Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
- append_overloaded_tensor(&overloaded_args, self_p.ptr());
+ append_overloaded_arg(&overloaded_args, self_p.ptr());
auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
return ret.release().ptr();
}
-// Note: [Overloaded args]
-// An overloaded arg may be one of the following:
-// - an instance of an object that has a __torch_function__ method
-// - an instance of an object that has a __torch_dispatch__ classmethod
-// - a class type that has a __torch_dispatch__ classmethod
-//
-// This function returns the type of the arg (if the arg is an instance),
-// otherwise, it returns the arg.
-static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
- if (PyType_Check(obj_or_type)) {
- return obj_or_type;
- }
- return (PyObject*)Py_TYPE(obj_or_type);
-}
-
-// See Note: [Overloaded args] for what they hold
auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name, const char* torch_function_name) -> PyObject* {
// overloaded_args already all have unique types
std::vector<py::object> overloaded_types;
overloaded_types.reserve(overloaded_args.size());
for (auto &arg : overloaded_args) {
- overloaded_types.push_back(py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg.ptr())));
+ overloaded_types.push_back(py::reinterpret_borrow<py::object>((PyObject *) Py_TYPE(arg.ptr())));
}
py::tuple py_types = py::cast(overloaded_types);
py::object ret;
ss << "no implementation found for '" << module_name << "." << func_name
<< "' on types that implement " << torch_function_name << ": [";
for (auto &arg : overloaded_args) {
- ss << PyObject_Repr(get_type_of_overloaded_arg(arg.ptr()));
+ ss << arg.ptr()->ob_type->tp_name;
if (!arg.is(overloaded_args.back())) {
ss << ", ";
}
*
*/
-static void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj, bool obj_is_type) {
+void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj) {
bool class_not_seen_yet = true;
- PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
for (auto &arg : *overloaded_args) {
- if (obj_type == get_type_of_overloaded_arg(arg.ptr())) {
+ if (Py_TYPE(obj) == Py_TYPE(arg.ptr())) {
// obj is the same type as another parameter we've seen in a prior
// iteration of the loop over parameters so we already have an entry
// with the proper __torch_function__ implementation to call, so skip
if (class_not_seen_yet) {
int arg_index = overloaded_args->size();
for(const auto j : c10::irange(arg_index)) {
- if (PyObject_IsSubclass(obj_type, (PyObject*)(get_type_of_overloaded_arg((*overloaded_args)[j].ptr())))) {
+ if (PyObject_IsInstance(obj, (PyObject*)(Py_TYPE((*overloaded_args)[j].ptr())))) {
// obj is a subclass of another object we've seen already so its
// __torch_function__ should be called first, therefore we
// insert it into overloaded_args before the superclass
}
}
-void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj) {
- append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/false);
-}
-
-void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj) {
- append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/true);
-}
-
bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args) {
if (THPVariable_CheckExact(obj)) {
// torch.Tensor instances (not subclasses, except for Parameter)
if (check_has_torch_function(obj)) {
// tensor subclasses and unrelated objects with __torch_function__
- append_overloaded_tensor(overloaded_args, obj);
+ append_overloaded_arg(overloaded_args, obj);
return true;
} else if (THPVariable_Check(obj)) {
// tensor subclasses without __torch_function__
int i = 0;
if (self != nullptr && check_has_torch_function(self)) {
- append_overloaded_tensor(&this->overloaded_args, self);
+ append_overloaded_arg(&this->overloaded_args, self);
}
for (auto& param : params) {
PyObject* obj = nullptr;
* 'overloaded_args': the vector to append the overloaded args
* 'obj': the input tensor that is overloaded
*/
-void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj);
-
-/* Given an argument that is definitely a type and is definitely overloaded,
- * append it to the overloaded arguments list. Use this only with __torch_dispatch__,
- * where we operate on classes that have a __torch_dispatch__ classmethod.
- *
- * 'overloaded_args': the vector to append the overloaded type
- * 'obj': the input class that has a __torch_dispatch__ classmethod.
- */
-void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj);
+void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj);
} // namespace torch
{
at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
at::tracer::impl::NoTracerDispatchMode tracer_guard;
- c10::impl::ExcludeDispatchKeyGuard pythonmode_guard(c10::DispatchKey::Python);
// functorch uses FuncTorchDynamicLayerBackMode as a mode key to wrap all
// tensors returned from operators in special TensorWrapper tensor extension
// The problem with this is that TensorWrapper does not have storage so
+++ /dev/null
-import torch
-import contextlib
-from typing import Iterator
-
-# Context manager that causes all pytorch operators to dispatch to the passed-in
-# type's __torch_dispatch__ function.
-# operation that accepts no tensors but returns a tensor.
-#
-# enable_python_mode is affected by torch._C._DisableTorchDispatch.
-#
-# NB: Calling an operator inside __torch_dispatch__ does go through
-# __torch_dispatch__ again. Please use _DisableTorchDispatch inside
-# __torch_dispatch__ to prevent infinite recursion.
-#
-# TODO: Limitations and things about enable_python_mode we should fix before exposing it:
-# - it currently cannot be nested. This should be simple to implement; we need a
-# stack of TorchDispatchTypeObjects and the next bullet point.
-# - We need a better user-facing api for torch._C._DisableTorchDispatch that
-# is able to selectively disable __torch_dispatch__ of a particular class.
-# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
-# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
-@contextlib.contextmanager
-def enable_python_mode(cls) -> Iterator[None]:
- if not hasattr(cls, '__torch_dispatch__'):
- raise ValueError('The class passed to enable_python_mode '
- 'must have a __torch_dispatch__ classmethod')
- if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
- raise ValueError('The argument passed to enable_python_mode '
- 'must be the type of a Tensor subclass')
- torch._C._enter_python_mode(cls)
- try:
- yield
- finally:
- torch._C._exit_python_mode()