Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63496
This PR adds a (private) `enable_python_mode` context manager
(see torch/utils/_python_dispatch.py).
`enable_python_mode` accepts the type of a `__torch_dispatch__` object
as its argument. Whenever an operator gets called inside the
context manager, it dispatches to the `__torch_dispatch__` of
the passed-in type.
Example usage:
```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])
    assert isinstance(z, LoggingTensor)
```
Quite a few changes were made to support this.
First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.
Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.
To get that to work, we changed how `handle_torch_function_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_tensor`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_function_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors
(a Python sketch of the resolution follows this list).
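To make the resolution order concrete, here is a rough Python model of how the
overloaded-args list is now consumed. This is a sketch only; the real logic is
the C++ in `handle_torch_function_no_python_arg_parser`, and
`resolve_torch_dispatch` is a hypothetical name:
```
# Sketch: `overloaded` may hold Tensor instances and/or classes; the python
# mode's type, if active, was prepended so it is tried first.
def resolve_torch_dispatch(overloaded, func, args, kwargs):
    types = tuple(e if isinstance(e, type) else type(e) for e in overloaded)
    for entry in overloaded:
        # Instances contribute their class; types (added via
        # append_overloaded_type) are used directly.
        cls = entry if isinstance(entry, type) else type(entry)
        result = cls.__torch_dispatch__(func, types, args, kwargs)
        if result is not NotImplemented:
            return result
    raise TypeError(f"no implementation found for {func}")
```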
Then, there are PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (we need
to save the TLS in ATen/ThreadLocalState).
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it (a rough sketch follows this list).
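In Python-flavored pseudocode, the fallback kernel's decision looks roughly
like this. It is a sketch of the C++ `pythonFallback` shown in the diff below;
`get_python_mode_state` and the attribute names are hypothetical stand-ins:
```
def python_fallback(op, stack):
    state = get_python_mode_state()        # PythonModeTLS::get_state()
    if state is not None:
        # Mode active: dispatch via the mode's interpreter, passing the
        # stored type so it can be prepended to the overloaded args.
        return state.interpreter.dispatch(op, stack, state)
    # No mode: dispatch on the first Tensor argument with a PyInterpreter.
    for arg in stack:
        interp = getattr(arg, "pyobj_interpreter", None)
        if interp is not None:
            return interp.dispatch(op, stack, None)
```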
There are two more miscellaneous changes:
- PyInterpreter's decref now checks THPVariable_Check before attempting
PyObject resurrection, because decref can now be called on non-Tensor
objects (the stored `__torch_dispatch__` type).
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor; the exclude guard prevents a bug there (illustrated below).
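To illustrate the limitation (a sketch, assuming a LoggingTensor class like the
one in test/test_python_dispatch.py is in scope; this is not a test from this
PR):
```
import torch
from torch.utils._python_dispatch import enable_python_mode

with enable_python_mode(LoggingTensor):
    z = torch.empty([])          # intercepted: z is a LoggingTensor
    t = torch.tensor([1., 2.])   # not handled by the mode (hence the exclude guard)
```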
Future:
- This PR does not allow nesting of Python modes. In the future we should be
able to enable nesting with a saner no_dispatch API and by changing the TLS
to a stack (a sketch follows below). For now I did not need nesting for
CompositeImplicitAutograd testing.
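One possible shape for that future stack-based API, as a hypothetical Python
sketch (names invented; this PR supports exactly one active mode):
```
import contextlib

_mode_stack: list = []  # innermost mode would win dispatch

@contextlib.contextmanager
def enable_python_mode_stacked(cls):
    _mode_stack.append(cls)
    try:
        yield
    finally:
        _mode_stack.pop()
```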
Test Plan: - new tests
Reviewed By: malfet, albanD
Differential Revision: D30543236
Pulled By: zou3519
fbshipit-source-id: ef5444d96a5a957d1657b7e37dce80f9a497d452
--- /dev/null
+++ b/aten/src/ATen/PythonModeTLS.cpp
+#include <ATen/PythonModeTLS.h>
+
+namespace at { namespace impl {
+
+thread_local std::shared_ptr<TorchDispatchTypeObject> pythonModeState;
+
+void PythonModeTLS::set_state(const std::shared_ptr<TorchDispatchTypeObject>& state) {
+ pythonModeState = state;
+ if (state) {
+ c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
+ } else {
+ PythonModeTLS::reset_state();
+ }
+}
+
+const std::shared_ptr<TorchDispatchTypeObject>& PythonModeTLS::get_state() {
+ return pythonModeState;
+}
+
+void PythonModeTLS::reset_state() {
+ pythonModeState.reset();
+ c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
+}
+
+} // namespace impl
+} // namespace at
--- /dev/null
+++ b/aten/src/ATen/PythonModeTLS.h
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <torch/library.h>
+#include <ATen/core/dispatch/Dispatcher.h>
+
+namespace at {
+namespace impl {
+
+struct TORCH_API PythonModeTLS {
+ static void set_state(const std::shared_ptr<TorchDispatchTypeObject>& state);
+ static const std::shared_ptr<TorchDispatchTypeObject>& get_state();
+ static void reset_state();
+};
+
+} // namespace impl
+} // namespace at
saved_tensors_default_hooks_ = SavedTensorDefaultHooks::get_hooks();
bumped_record_all_functions_ = at::checkRecordAllFunctions();
+ python_mode_state_ = at::impl::PythonModeTLS::get_state();
}
void ThreadLocalState::set_grad_mode(bool enabled) {
// restore the dispatch key set TLS at the same time.
c10::AutogradState::set_tls_state(state.autograd_tls_);
+ at::impl::PythonModeTLS::set_state(state.python_mode_state_);
+
at::set_record_function_tls_(state.rf_tls_);
SavedTensorDefaultHooks::set_hooks(
#include <c10/util/ThreadLocalDebugInfo.h>
#include <ATen/record_function.h>
+#include <ATen/PythonModeTLS.h>
namespace at {
// TLS for AutogradModes
AutogradState autograd_tls_;
+ std::shared_ptr<TorchDispatchTypeObject> python_mode_state_;
+
// TLS for saved tensors default hooks
std::pair<PyObject*, PyObject*> saved_tensors_default_hooks_;
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
+#include <ATen/PythonModeTLS.h>
namespace {
void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+ // If Python Mode is active, use its PyInterpreter for dispatch
+ const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
+ if (maybe_python_mode_state) {
+ maybe_python_mode_state->pyinterpreter()->dispatch(op, stack, maybe_python_mode_state);
+ return;
+ }
+
+ // Otherwise, find a PyInterpreter on a Tensor
const auto& schema = op.schema();
const auto num_arguments = schema.arguments().size();
// It is safe to dispatch on the very first Tensor with a pyobj_interpreter
if (ivalue.isTensor()) {
auto* interpreter = ivalue.unsafeToTensorImpl()->pyobj_interpreter();
if (interpreter) {
- interpreter->dispatch(op, stack);
+ interpreter->dispatch(op, stack, nullptr);
return;
}
} else if (ivalue.isTensorList()) {
for (const auto& nv : ivalue.toListRef()) {
auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter();
if (interpreter) {
- interpreter->dispatch(op, stack);
+ interpreter->dispatch(op, stack, nullptr);
return;
}
}
static void noop_dispatch_fn(
const PyInterpreter*,
const c10::OperatorHandle& op,
- torch::jit::Stack* stack) {
+ torch::jit::Stack* stack,
+ const std::shared_ptr<TorchDispatchTypeObject>& type) {
TORCH_INTERNAL_ASSERT(
0,
"attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died");
}
}
+TorchDispatchTypeObject::TorchDispatchTypeObject(
+ PyObject* type_object,
+ c10::impl::PyInterpreter* pyinterpreter)
+ : data_(type_object), pyinterpreter_(pyinterpreter) {}
+
+TorchDispatchTypeObject::~TorchDispatchTypeObject() {
+ pyinterpreter_->decref(data_);
+}
+
+c10::impl::PyInterpreter* TorchDispatchTypeObject::pyinterpreter() const {
+ return pyinterpreter_;
+}
+
+PyObject* TorchDispatchTypeObject::ptr() const {
+ return data_;
+}
+
namespace impl {
namespace {
virtual ~AutogradMetaInterface();
};
+// forward declared
+struct TorchDispatchTypeObject;
+
namespace impl {
// Unfortunately, the definition of AutogradMeta lives in a separate
using dispatch_sig = void(
const PyInterpreter*,
const c10::OperatorHandle&,
- torch::jit::Stack* stack);
+ torch::jit::Stack* stack,
+ const std::shared_ptr<TorchDispatchTypeObject>& type);
PyInterpreter(
name_sig* name_fn,
// Invoke the Python boxed fallback dispatch to go back into Python
__ubsan_ignore_function__ void dispatch(
const c10::OperatorHandle& op,
- torch::jit::Stack* stack) const {
- return (*dispatch_fn_)(this, op, stack);
+ torch::jit::Stack* stack,
+ const std::shared_ptr<TorchDispatchTypeObject>& type) const {
+ return (*dispatch_fn_)(this, op, stack, type);
}
// Disarm this PyInterpreter, making all of its methods noops.
};
};
+// NOTE [What is TorchDispatchTypeObject?]
+// A TorchDispatchTypeObject represents the type of a Tensor subclass that has
+// a __torch_dispatch__ classmethod. Concretely, it holds the class as a
+// PyObject* and a PyInterpreter* that says which python interpreter the class
+// came from.
+//
+// See NOTE [dispatch_fn's type argument] for more details
+struct C10_API TorchDispatchTypeObject {
+ // Steals a reference to type_object
+ TorchDispatchTypeObject(
+ PyObject* type_object,
+ c10::impl::PyInterpreter* pyinterpreter);
+
+ // Releases the stolen reference to type_object
+ ~TorchDispatchTypeObject();
+
+ c10::impl::PyInterpreter* pyinterpreter() const;
+ PyObject* ptr() const;
+
+ private:
+ PyObject* data_;
+ c10::impl::PyInterpreter* pyinterpreter_;
+};
+
// NOTE [ Version Counter Sharing ]
//
// Every Tensor has a version counter. Version counters are incremented whenever
"test_optim",
"test_functional_optim",
"test_pytree",
+ "test_python_dispatch",
"test_mobile_optimizer",
"test_set_default_mobile_cpu_allocator",
"test_xnnpack_integration",
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.utils._pytree import tree_map
+from torch.utils._python_dispatch import enable_python_mode
from typing import Iterator, List
import logging
def wrap(e):
return LoggingTensor(e) if isinstance(e, torch.Tensor) else e
- rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+ # no_dispatch is only needed if you use enable_python_mode.
+ # It prevents infinite recursion.
+ with no_dispatch():
+ rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
return rs
$5 = torch._ops.aten.mul($4, $0)
$6 = torch._ops.aten.add_($1, $5)''')
+ def test_enable_python_mode_error(self) -> None:
+ with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
+ with enable_python_mode(torch.Tensor):
+ pass
+ z = LoggingTensor(torch.empty([]))
+ with self.assertRaisesRegex(ValueError, "must be the type"):
+ with enable_python_mode(z):
+ pass
+
+ def test_enable_python_mode_basic(self) -> None:
+ with enable_python_mode(LoggingTensor):
+ z = torch.empty([])
+ self.assertTrue(isinstance(z, LoggingTensor))
+
+ def test_enable_python_mode_unrelated_tensors(self) -> None:
+ x = torch.randn([])
+ y = torch.randn([])
+ with enable_python_mode(LoggingTensor):
+ z = x + y
+ self.assertTrue(isinstance(z, LoggingTensor))
+
+ def test_enable_python_mode_subclass_priority(self) -> None:
+ class ErrorA(RuntimeError):
+ pass
+
+ class ErrorB(RuntimeError):
+ pass
+
+ class A(torch.Tensor):
+ @staticmethod
+ def __new__(cls, elem):
+ return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+ raise ErrorA
+
+ class B(A):
+ @staticmethod
+ def __new__(cls, elem):
+ return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+ raise ErrorB
+
+ a = A(torch.empty(1))
+ b = B(torch.empty(1))
+ with self.assertRaises(ErrorA):
+ a + a
+
+ # B has precedence over A due to the subclass relationship
+ with self.assertRaises(ErrorB):
+ with enable_python_mode(A):
+ b + b
+ with self.assertRaises(ErrorB):
+ with enable_python_mode(B):
+ a + a
+ with self.assertRaises(ErrorB):
+ with enable_python_mode(B):
+ a + b
+
+ def test_enable_python_mode_respects_no_dispatch(self) -> None:
+ with enable_python_mode(LoggingTensor):
+ z = torch.ones([2, 3])
+ self.assertTrue(isinstance(z, LoggingTensor))
+ with no_dispatch():
+ expected = torch.ones([2, 3])
+ self.assertEqual(z.elem, expected)
+
+ def test_nested_enable_python_mode(self) -> None:
+ with self.assertRaisesRegex(RuntimeError, "has already been set"):
+ with enable_python_mode(LoggingTensor):
+ with enable_python_mode(LoggingTensor):
+ pass
if __name__ == '__main__':
run_tests()
"torch/csrc/autograd/init.cpp",
"torch/csrc/autograd/python_anomaly_mode.cpp",
"torch/csrc/autograd/python_saved_variable_hooks.cpp",
+ "torch/csrc/autograd/python_mode.cpp",
"torch/csrc/autograd/python_cpp_function.cpp",
"torch/csrc/autograd/python_engine.cpp",
"torch/csrc/autograd/python_function.cpp",
"aten/src/ATen/ParallelNativeTBB.cpp",
"aten/src/ATen/ParallelOpenMP.cpp",
"aten/src/ATen/ParallelThreadPoolNative.cpp",
+ "aten/src/ATen/PythonModeTLS.cpp",
"aten/src/ATen/ScalarOps.cpp",
"aten/src/ATen/SequenceNumber.cpp",
"aten/src/ATen/SparseTensorImpl.cpp",
def __is_forward_AD_enabled() -> _bool: ...
def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
def _reset_default_hooks() -> None: ...
+def _enter_python_mode(cls: Type) -> None: ...
+def _exit_python_mode() -> None: ...
class _InferenceMode(object):
def __init__(self, mode: _bool) -> None: ...
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
+#include <torch/csrc/autograd/python_mode.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <c10/core/ScalarType.h>
END_HANDLE_TH_ERRORS
}
+static PyObject * enter_python_mode(PyObject* _unused, PyObject* arg) {
+ HANDLE_TH_ERRORS
+ PythonMode::enter(arg);
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * exit_python_mode(PyObject* _unused, PyObject* arg) {
+ HANDLE_TH_ERRORS
+ PythonMode::exit();
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
// autograd methods on torch._C
static PyMethodDef methods[] = { // NOLINT
{"_set_grad_enabled", set_grad_enabled, METH_O, nullptr},
{"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
{"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
{"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr},
+ {"_enter_python_mode", enter_python_mode, METH_O, nullptr},
+ {"_exit_python_mode", exit_python_mode, METH_NOARGS, nullptr},
{nullptr, nullptr, 0, nullptr}
};
--- /dev/null
+++ b/torch/csrc/autograd/python_mode.cpp
+#include <torch/csrc/autograd/python_mode.h>
+#include <torch/csrc/python_headers.h>
+#include <torch/csrc/autograd/python_variable.h>
+#include <ATen/PythonModeTLS.h>
+#include <c10/core/TensorImpl.h>
+
+namespace torch { namespace autograd {
+
+void PythonMode::enter(PyObject* type) {
+ if (at::impl::PythonModeTLS::get_state()) {
+ TORCH_CHECK(
+ false,
+ "python mode has already been set. We do not yet support nested python ",
+ "mode. Please file us an issue and reset it before setting it again.")
+ }
+ // TorchDispatchTypeObject steals a reference, See NOTE [What is TorchDispatchTypeObject?]
+ Py_INCREF(type);
+ auto state = std::make_shared<c10::TorchDispatchTypeObject>(type, getPyInterpreter());
+ at::impl::PythonModeTLS::set_state(state);
+}
+
+void PythonMode::exit() {
+ TORCH_INTERNAL_ASSERT(at::impl::PythonModeTLS::get_state(), "exiting Python Mode but it wasn't set!");
+ at::impl::PythonModeTLS::reset_state();
+}
+
+}}
--- /dev/null
+++ b/torch/csrc/autograd/python_mode.h
+#pragma once
+
+#include <torch/csrc/python_headers.h>
+#include <c10/core/TensorImpl.h>
+
+namespace torch { namespace autograd {
+
+struct TORCH_API PythonMode {
+ // Enter python mode, causing all operators to dispatch to the type's __torch_dispatch__.
+ // `type` is the type of a Tensor subclass that has __torch_dispatch__.
+ static void enter(PyObject* type);
+
+ // Exit the current python mode.
+ static void exit();
+};
+
+}}
#include <torch/library.h>
#include <torch/csrc/jit/python/pybind_utils.h>
+#include <torch/csrc/autograd/python_mode.h>
#include <ATen/ATen.h>
return;
pybind11::gil_scoped_acquire gil;
- if (Py_REFCNT(pyobj) > 1) {
+ // Two possibilities:
+ // 1. We are decref-ing a tensor. Then we must be careful about
+ // PyObject resurrection (this only applies to Tensors, see THPVariable_clear).
+ // 2. We are decref-ing some other Python object. We don't do
+ // PyObject resurrection on non-Tensors, so we just carry on as usual
+ if (THPVariable_Check(pyobj) && Py_REFCNT(pyobj) > 1) {
// It's still alive! This can happen if a weak ref resurrected
// the PyObject without flipping ownership. At this point it is
// too late to rescue the object, so just stub out the PyObject
};
c10::intrusive_ptr<TensorImpl> concrete_detach_fn(const c10::impl::PyInterpreter*, const c10::TensorImpl* self);
-void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack);
+void concrete_dispatch_fn(
+ const c10::impl::PyInterpreter*,
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack,
+ const std::shared_ptr<TorchDispatchTypeObject>& type);
class PyInterpreterHolder {
public:
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
}
-void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+// NOTE [dispatch_fn's type argument]
+// `type` is nullable and represents the PythonMode going on.
+// Right now we only support a single PythonMode, but in the future we could
+// change this to a stack of PythonModes.
+//
+// If `type` isn't null, then we consider the type for dispatch by prepending
+// it to the overloaded_args list. `handle_torch_function_no_python_arg_parser`
+// is responsible for doing overload resolution.
+void concrete_dispatch_fn(
+ const c10::impl::PyInterpreter*,
+ const c10::OperatorHandle& op,
+ torch::jit::Stack* stack,
+ const std::shared_ptr<TorchDispatchTypeObject>& type) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
auto args = py::reinterpret_steal<py::object>(PyTuple_New(positional_default_start));
py::dict kwargs;
+ if (type) {
+ append_overloaded_type(&overloaded_args, type->ptr());
+ }
+
// Find overloaded tensors
for (int64_t idx = 0; idx < arguments.size(); idx++) {
const auto& ivalue = arguments[idx];
if (ivalue.isTensor()) {
const auto& tensor = ivalue.toTensor();
if (isPythonTensor(tensor)) {
- append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
+ append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
}
} else if (ivalue.isList()) {
const auto& list = ivalue.toListRef();
if (nv.isTensor()) {
const auto& tensor = nv.toTensor();
if (isPythonTensor(tensor)) {
- append_overloaded_arg(&overloaded_args, py::cast(tensor).ptr());
+ append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr());
}
}
}
Tensor self_t = Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self)));
auto self_p = py::reinterpret_steal<py::object>(THPVariable_Wrap(self_t));
TORCH_INTERNAL_ASSERT(isPythonTensor(self_t));
- append_overloaded_arg(&overloaded_args, self_p.ptr());
+ append_overloaded_tensor(&overloaded_args, self_p.ptr());
auto args = py::reinterpret_steal<py::object>(PyTuple_New(1));
PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr());
return ret.release().ptr();
}
+// Note: [Overloaded args]
+// An overloaded arg may be one of the following:
+// - an instance of an object that has a __torch_function__ method
+// - an instance of an object that has a __torch_dispatch__ classmethod
+// - a class type that has a __torch_dispatch__ classmethod
+//
+// This function returns the type of the arg (if the arg is an instance),
+// otherwise, it returns the arg.
+static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
+ if (PyType_Check(obj_or_type)) {
+ return obj_or_type;
+ }
+ return (PyObject*)Py_TYPE(obj_or_type);
+}
+
+// See Note: [Overloaded args] for what they hold
auto handle_torch_function_no_python_arg_parser(const std::vector<py::handle> &overloaded_args, PyObject* args, PyObject* kwargs, const char* func_name, PyObject* torch_api_function, const char* module_name, const char* torch_function_name) -> PyObject* {
// overloaded_args already all have unique types
std::vector<py::object> overloaded_types;
overloaded_types.reserve(overloaded_args.size());
for (auto &arg : overloaded_args) {
- overloaded_types.push_back(py::reinterpret_borrow<py::object>((PyObject *) Py_TYPE(arg.ptr())));
+ overloaded_types.push_back(py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg.ptr())));
}
py::tuple py_types = py::cast(overloaded_types);
py::object ret;
ss << "no implementation found for '" << module_name << "." << func_name
<< "' on types that implement " << torch_function_name << ": [";
for (auto &arg : overloaded_args) {
- ss << arg.ptr()->ob_type->tp_name;
+ ss << ((PyTypeObject*)get_type_of_overloaded_arg(arg.ptr()))->tp_name;
if (!arg.is(overloaded_args.back())) {
ss << ", ";
}
*
*/
-void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj) {
+static void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj, bool obj_is_type) {
bool class_not_seen_yet = true;
+ PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
for (auto &arg : *overloaded_args) {
- if (Py_TYPE(obj) == Py_TYPE(arg.ptr())) {
+ if (obj_type == get_type_of_overloaded_arg(arg.ptr())) {
// obj is the same type as another parameter we've seen in a prior
// iteration of the loop over parameters so we already have an entry
// with the proper __torch_function__ implementation to call, so skip
if (class_not_seen_yet) {
int arg_index = overloaded_args->size();
for(const auto j : c10::irange(arg_index)) {
- if (PyObject_IsInstance(obj, (PyObject*)(Py_TYPE((*overloaded_args)[j].ptr())))) {
+ if (PyObject_IsSubclass(obj_type, (PyObject*)(get_type_of_overloaded_arg((*overloaded_args)[j].ptr())))) {
// obj is a subclass of another object we've seen already so its
// __torch_function__ should be called first, therefore we
// insert it into overloaded_args before the superclass
}
}
+void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj) {
+ append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/false);
+}
+
+void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj) {
+ append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/true);
+}
+
bool is_tensor_and_append_overloaded(PyObject* obj, std::vector<py::handle>* overloaded_args) {
if (THPVariable_CheckExact(obj)) {
// torch.Tensor instances (not subclasses, except for Parameter)
if (check_has_torch_function(obj)) {
// tensor subclasses and unrelated objects with __torch_function__
- append_overloaded_arg(overloaded_args, obj);
+ append_overloaded_tensor(overloaded_args, obj);
return true;
} else if (THPVariable_Check(obj)) {
// tensor subclasses without __torch_function__
int i = 0;
if (self != nullptr && check_has_torch_function(self)) {
- append_overloaded_arg(&this->overloaded_args, self);
+ append_overloaded_tensor(&this->overloaded_args, self);
}
for (auto& param : params) {
PyObject* obj = nullptr;
* 'overloaded_args': the vector to append the overloaded args
* 'obj': the input tensor that is overloaded
*/
-void append_overloaded_arg(std::vector<py::handle>* overloaded_args, PyObject* obj);
+void append_overloaded_tensor(std::vector<py::handle>* overloaded_args, PyObject* obj);
+
+/* Given an argument that is definitely a type and is definitely overloaded,
+ * append it to the overloaded arguments list. Use this only with __torch_dispatch__,
+ * where we operate on classes that have a __torch_dispatch__ classmethod.
+ *
+ * 'overloaded_args': the vector to append the overloaded type
+ * 'obj': the input class that has a __torch_dispatch__ classmethod.
+ */
+void append_overloaded_type(std::vector<py::handle>* overloaded_args, PyObject* obj);
} // namespace torch
{
at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
at::tracer::impl::NoTracerDispatchMode tracer_guard;
+ c10::impl::ExcludeDispatchKeyGuard pythonmode_guard(c10::DispatchKey::Python);
// functorch uses FuncTorchDynamicLayerBackMode as a mode key to wrap all
// tensors returned from operators in special TensorWrapper tensor extension
// The problem with this is that TensorWrapper does not have storage so
--- /dev/null
+++ b/torch/utils/_python_dispatch.py
+import torch
+import contextlib
+from typing import Iterator
+
+# Context manager that causes all pytorch operators to dispatch to the passed-in
+# type's __torch_dispatch__ function. This includes factory functions: any
+# operation that accepts no tensors but returns a tensor.
+#
+# enable_python_mode is affected by torch._C._DisableTorchDispatch.
+#
+# NB: Calling an operator inside __torch_dispatch__ does go through
+# __torch_dispatch__ again. Please use _DisableTorchDispatch inside
+# __torch_dispatch__ to prevent infinite recursion.
+#
+# TODO: Limitations and things about enable_python_mode we should fix before exposing it:
+# - it currently cannot be nested. This should be simple to implement; we need a
+# stack of TorchDispatchTypeObjects and the next bullet point.
+# - We need a better user-facing api for torch._C._DisableTorchDispatch that
+# is able to selectively disable __torch_dispatch__ of a particular class.
+# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
+# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
+@contextlib.contextmanager
+def enable_python_mode(cls) -> Iterator[None]:
+ if not hasattr(cls, '__torch_dispatch__'):
+ raise ValueError('The class passed to enable_python_mode '
+ 'must have a __torch_dispatch__ classmethod')
+ if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
+ raise ValueError('The argument passed to enable_python_mode '
+ 'must be the type of a Tensor subclass')
+ torch._C._enter_python_mode(cls)
+ try:
+ yield
+ finally:
+ torch._C._exit_python_mode()