#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
+#include <ATen/detail/ORTHooksInterface.h>
#include <c10/util/Exception.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/QEngine.h>
static bool hasMLC() {
return c10::impl::hasDeviceGuardImpl(at::DeviceType::MLC);
}
+ static bool hasORT() {
+ return c10::impl::hasDeviceGuardImpl(at::DeviceType::ORT);
+ }
static inline bool hasMLC() {
  return globalContext().hasMLC();
}
+static inline bool hasORT() {
+ return globalContext().hasORT();
+}
+
// Despite its name, this function returns the number of *CUDA* GPUs.
static inline size_t getNumGPUs() {
// WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
ss << detail::getCUDAHooks().showConfig();
}
+ if (hasORT()) {
+ ss << detail::getORTHooks().showConfig();
+ }
+
ss << " - Build settings: ";
for (const auto& pair : caffe2::GetBuildOptions()) {
if (!pair.second.empty()) {
_(aten, is_contiguous) \
_(aten, is_cuda) \
_(aten, is_mlc) \
+_(aten, is_ort) \
_(aten, is_distributed) \
_(aten, is_floating_point) \
_(aten, is_inference) \
* You’re writing a new operator that isn’t supposed to be part of the public PyTorch API.
* You’re writing a new operator but don’t want to change the core PyTorch code base, say you’re developing a shared library with operators.
* You’re writing a C++ extension for PyTorch or you’re using inline C++ in your .py model files.
-* You’re writing a backend library like XLA or MSNPU that adds new kernels to all operators defined in `native_functions.yaml`.
+* You’re writing a backend library like XLA or ORT that adds new kernels to all operators defined in `native_functions.yaml` (see the sketch below).
For these use cases, the custom operator API is the better solution.
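For the backend-library case, the registration boils down to a `TORCH_LIBRARY_IMPL` block keyed on the backend's dispatch key. A minimal sketch follows; the kernel body is a hypothetical placeholder, and the test extension later in this diff does the same thing for real:

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical ORT kernel for an operator that already exists in
// native_functions.yaml; a real backend would call into its runtime here.
at::Tensor ort_add(const at::Tensor& a, const at::Tensor& b, const at::Scalar& alpha) {
  return a;  // placeholder result
}

// Route aten::add.Tensor to the kernel above whenever an input carries
// the ORT dispatch key.
TORCH_LIBRARY_IMPL(aten, ORT, m) {
  m.impl("add.Tensor", ort_add);
}
```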
### What is the price for using the custom operator API instead of `native_functions.yaml`?
-If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/MSNPU example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.
+If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/ORT example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.
* It will not get a C++ API generated. There will not be `Tensor::your_op()` methods or `at::your_op()` functions to call your operator.
* The API for calling the operator from Python looks a little bit different. It needs to be called through `torch.ops.your_op()` instead of `torch._C`.
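To make these caveats concrete, here is a minimal sketch of an operator defined purely through the custom op API (`my_ops::my_op` and its kernel are hypothetical names for illustration):

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical operator that never appears in native_functions.yaml.
at::Tensor my_op(const at::Tensor& self) {
  return self.clone();
}

TORCH_LIBRARY(my_ops, m) {
  m.def("my_op(Tensor self) -> Tensor", my_op);
}

// No at::my_op() or Tensor::my_op() gets generated; from Python the only
// entry point is torch.ops.my_ops.my_op(t).
```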
--- /dev/null
+#include <ATen/detail/ORTHooksInterface.h>
+
+#include <c10/util/Exception.h>
+
+#include <cstddef>
+#include <memory>
+#include <mutex>
+
+namespace at {
+namespace detail {
+
+// See getCUDAHooks for some more commentary
+const ORTHooksInterface& getORTHooks() {
+ static std::unique_ptr<ORTHooksInterface> ort_hooks;
+ static std::once_flag once;
+ std::call_once(once, [] {
+ ort_hooks = ORTHooksRegistry()->Create("ORTHooks", {});
+ if (!ort_hooks) {
+ ort_hooks =
+ // NOLINTNEXTLINE(modernize-make-unique)
+ std::unique_ptr<ORTHooksInterface>(new ORTHooksInterface());
+ }
+ });
+ return *ort_hooks;
+}
+} // namespace detail
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+C10_DEFINE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs)
+
+} // namespace at
--- /dev/null
+#pragma once
+
+#include <c10/util/Exception.h>
+#include <c10/util/Registry.h>
+
+constexpr const char* ORT_HELP =
+ " You need to 'import torch_ort' to use the 'ort' device in PyTorch. "
+ "The 'torch_ort' module is provided by the ONNX Runtime itself "
+ "(https://onnxruntime.ai).";
+
+// NB: Class must live in `at` due to limitations of Registry.h.
+namespace at {
+
+struct TORCH_API ORTHooksInterface {
+ // This should never actually be implemented, but it is used to
+ // squelch -Werror=non-virtual-dtor
+ virtual ~ORTHooksInterface() {}
+
+ virtual std::string showConfig() const {
+ TORCH_CHECK(false, "Cannot query detailed ORT version information.", ORT_HELP);
+ }
+};
+
+// NB: dummy argument to suppress "ISO C++11 requires at least one argument
+// for the "..." in a variadic macro"
+struct TORCH_API ORTHooksArgs {};
+
+C10_DECLARE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs);
+#define REGISTER_ORT_HOOKS(clsname) \
+ C10_REGISTER_CLASS(ORTHooksRegistry, clsname, clsname)
+
+namespace detail {
+TORCH_API const ORTHooksInterface& getORTHooks();
+} // namespace detail
+
+} // namespace at
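For reference, a minimal sketch of how an out-of-tree backend might register against this interface, mirroring the CUDAHooks pattern (the class body and config string are hypothetical):

```cpp
#include <ATen/detail/ORTHooksInterface.h>

// NB: registered classes must live in `at` (see the note in the header).
namespace at {

// Hypothetical implementation that a backend library such as torch_ort
// would compile in.
struct ORTHooks : public ORTHooksInterface {
  ORTHooks(ORTHooksArgs) {}
  std::string showConfig() const override {
    return "  - ONNX Runtime backend\n";  // placeholder config string
  }
};

// Once this translation unit is linked in, at::detail::getORTHooks() hands
// out an ORTHooks instance instead of the error-raising default.
REGISTER_ORT_HOOKS(ORTHooks);

} // namespace at
```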
return impl_->is_mlc();
}
+ /// Returns if a `Tensor` is ort tensor.
+ bool is_ort() const {
+ // NB: this is not a native function to avoid dispatching overhead.
+ return impl_->is_ort();
+ }
+
/// Returns if a `Tensor` is vulkan tensor.
bool is_vulkan() const {
// NB: this is not a native function to avoid dispatching overhead.
#include <torch/csrc/jit/runtime/operator.h>
+// NB: These tests use the ORT dispatch key to test backend dispatching
+// machinery, but these tests are not specific to ORT at all. The ORT
+// backend is fully out-of-tree, so it's safe to use this key for
+// in-tree tests.
+
using namespace at;
static int test_int;
Storage(
Storage::use_byte_size_t(),
0,
- at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)),
+ at::DataPtr(nullptr, Device(DeviceType::ORT, 1)),
nullptr,
false),
- DispatchKey::MSNPU,
+ DispatchKey::ORT,
caffe2::TypeMeta::Make<float>());
return Tensor(std::move(tensor_impl));
}
Tensor add_override(const Tensor & a, const Tensor & b , const Scalar& c) {
- auto out = empty({5, 5}, at::kMSNPU); // Don't return self as-is
+ auto out = empty({5, 5}, at::kORT); // Don't return self as-is
test_int = 2;
return out;
}
return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt);
}
-TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
+TORCH_LIBRARY_IMPL(aten, ORT, m) {
m.impl("aten::empty.memory_format", empty_override);
m.impl("aten::empty_strided", empty_strided_override);
m.impl("aten::add.Tensor", add_override);
}
TEST(BackendExtensionTest, TestRegisterOp) {
- Tensor a = empty({5, 5}, at::kMSNPU);
- ASSERT_EQ(a.device().type(), at::kMSNPU);
+ Tensor a = empty({5, 5}, at::kORT);
+ ASSERT_EQ(a.device().type(), at::kORT);
ASSERT_EQ(a.device().index(), 1);
ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
ASSERT_EQ(test_int, 1);
- Tensor b = empty_like(a, at::kMSNPU);
- ASSERT_EQ(b.device().type(), at::kMSNPU);
+ Tensor b = empty_like(a, at::kORT);
+ ASSERT_EQ(b.device().type(), at::kORT);
ASSERT_EQ(b.device().index(), 1);
ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make<float>());
add(a, b);
ASSERT_EQ(test_int, 2);
- // Ensure that non-MSNPU operator still works
+ // Ensure that non-ORT operator still works
Tensor d = empty({5, 5}, at::kCPU);
ASSERT_EQ(d.device().type(), at::kCPU);
}
SparseHIP,
SparseVE,
SparseXPU,
- MSNPU,
+ ORT,
XLA,
Vulkan,
Metal,
return Backend::VE;
} else if (t == DispatchKey::FPGA) {
return Backend::FPGA;
- } else if (t == DispatchKey::MSNPU) {
- return Backend::MSNPU;
+ } else if (t == DispatchKey::ORT) {
+ return Backend::ORT;
} else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
return Backend::XLA;
} else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) {
return DispatchKey::VE;
case Backend::FPGA:
return DispatchKey::FPGA;
- case Backend::MSNPU:
- return DispatchKey::MSNPU;
+ case Backend::ORT:
+ return DispatchKey::ORT;
case Backend::XLA:
return DispatchKey::XLA;
case Backend::Lazy:
return DeviceType::VE;
case Backend::FPGA:
return DeviceType::FPGA;
- case Backend::MSNPU:
- return DeviceType::MSNPU;
+ case Backend::ORT:
+ return DeviceType::ORT;
case Backend::XLA:
return DeviceType::XLA;
case Backend::Lazy:
return "FPGA";
case Backend::XPU:
return "XPU";
- case Backend::MSNPU:
- return "MSNPU";
+ case Backend::ORT:
+ return "ORT";
case Backend::XLA:
return "XLA";
case Backend::Lazy:
{"hip", DeviceType::HIP},
{"ve", DeviceType::VE},
{"fpga", DeviceType::FPGA},
- {"msnpu", DeviceType::MSNPU},
+ {"ort", DeviceType::ORT},
{"xla", DeviceType::XLA},
{"lazy", DeviceType::Lazy},
{"vulkan", DeviceType::Vulkan},
}
TORCH_CHECK(
false,
- "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, msnpu, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
+ "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
device_string);
}
enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };
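For illustration, the parser now accepts the new name in device strings; a minimal sketch (the function name is illustrative):

```cpp
#include <c10/core/Device.h>

void device_string_example() {
  c10::Device plain("ort");      // DeviceType::ORT, unspecified index
  c10::Device indexed("ort:0");  // DeviceType::ORT, index 0
  // "msnpu" is no longer a recognized device string after this change.
}
```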
return lower_case ? "ve" : "VE";
case DeviceType::FPGA:
return lower_case ? "fpga" : "FPGA";
- case DeviceType::MSNPU:
- return lower_case ? "msnpu" : "MSNPU";
+ case DeviceType::ORT:
+ return lower_case ? "ort" : "ORT";
case DeviceType::XLA:
return lower_case ? "xla" : "XLA";
case DeviceType::Lazy:
case DeviceType::HIP:
case DeviceType::VE:
case DeviceType::FPGA:
- case DeviceType::MSNPU:
+ case DeviceType::ORT:
case DeviceType::XLA:
case DeviceType::Lazy:
case DeviceType::MLC:
IDEEP = 5, // IDEEP.
HIP = 6, // AMD HIP
FPGA = 7, // FPGA
- MSNPU = 8, // MSNPU
+ ORT = 8, // ONNX Runtime / Microsoft
XLA = 9, // XLA / TPU
Vulkan = 10, // Vulkan
Metal = 11, // Metal
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
-constexpr DeviceType kMSNPU = DeviceType::MSNPU;
+constexpr DeviceType kORT = DeviceType::ORT;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMLC = DeviceType::MLC;
constexpr DeviceType kMeta = DeviceType::Meta;
return "FPGA";
case DispatchKey::XPU:
return "XPU";
- case DispatchKey::MSNPU:
- return "MSNPU";
+ case DispatchKey::ORT:
+ return "ORT";
case DispatchKey::XLA:
return "XLA";
case DispatchKey::Lazy:
// CUDA]
FPGA, // Xilinx support lives out of tree at
// https://gitlab.com/pytorch-complex/vitis_kernels
- MSNPU, // unused externally, but tested at
- // test/cpp_extensions/msnpu_extension.cpp
+
+ // ONNX Runtime, lives out of tree at https://github.com/pytorch/ort and
+ // https://github.com/microsoft/onnxruntime, and is also used to test general
+ // backend/extension machinery in the core. cf:
+ // - test/cpp_extensions/ort_extension.cpp
+ // - test/test_torch.py
+ // - aten/src/ATen/test/extension_backend_test.cpp
+ ORT,
+
XLA, // lives out of tree at https://github.com/pytorch/xla
MLC, // lives out of tree at https://github.com/pytorch/MLCompute
Vulkan,
// Here are reserved backends for user-defined backends, see Note [Private use
// DispatchKey]
- // To see some example about how to use this, check out MSNPU
+ // For an example of how to use this, check out ORT
PrivateUse1,
PrivateUse2,
PrivateUse3,
DispatchKey::PrivateUse3,
DispatchKey::MLC,
DispatchKey::HPU,
+ DispatchKey::ORT,
DispatchKey::Meta,
});
{DispatchKey::HIP,
DispatchKey::VE,
DispatchKey::FPGA,
- DispatchKey::MSNPU,
+ DispatchKey::ORT,
DispatchKey::Vulkan,
DispatchKey::Metal,
DispatchKey::QuantizedCPU,
return key_set_.has(DispatchKey::MLC);
}
+ bool is_ort() const {
+ return key_set_.has(DispatchKey::ORT);
+ }
+
// TODO: remove this once we don't automatically enabled Autograd dispatch
// keys
// in TensorImpl constructor.
return DispatchKey::VE;
case DeviceType::FPGA:
return DispatchKey::FPGA;
- case DeviceType::MSNPU:
- return DispatchKey::MSNPU;
+ case DeviceType::ORT:
+ return DispatchKey::ORT;
case DeviceType::XLA:
return DispatchKey::XLA;
case DeviceType::Lazy:
case DispatchKey::HPU:
case DispatchKey::AutogradHPU:
return DeviceType::HPU;
-
- // stuff that isn't real
- case DispatchKey::MSNPU:
- return DeviceType::MSNPU;
+ case DispatchKey::ORT:
+ return DeviceType::ORT;
default:
TORCH_CHECK(
false,
PROTO_IDEEP = 5; // IDEEP.
PROTO_HIP = 6; // AMD HIP
PROTO_FPGA = 7; // FPGA
- PROTO_MSNPU = 8; // MSNPU
+ PROTO_ORT = 8; // ONNX Runtime
PROTO_XLA = 9; // XLA / TPU
PROTO_MLC = 10; // ML Compute
// Change the following number if you add more devices in the code.
PROTO_IDEEP = DeviceTypeProto.V(5)
PROTO_HIP = DeviceTypeProto.V(6)
PROTO_FPGA = DeviceTypeProto.V(7)
- PROTO_MSNPU = DeviceTypeProto.V(8)
+ PROTO_ORT = DeviceTypeProto.V(8)
PROTO_XLA = DeviceTypeProto.V(9)
PROTO_MLC = DeviceTypeProto.V(10)
PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11)
PROTO_IDEEP = DeviceTypeProto.V(5)
PROTO_HIP = DeviceTypeProto.V(6)
PROTO_FPGA = DeviceTypeProto.V(7)
-PROTO_MSNPU = DeviceTypeProto.V(8)
+PROTO_ORT = DeviceTypeProto.V(8)
PROTO_XLA = DeviceTypeProto.V(9)
PROTO_MLC = DeviceTypeProto.V(10)
PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11)
Storage(
Storage::use_byte_size_t(),
0,
- at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)),
+ at::DataPtr(nullptr, Device(DeviceType::ORT, 0)),
nullptr,
false),
- DispatchKey::MSNPU,
+ DispatchKey::ORT,
dtype);
// This is a hack to workaround the shape checks in _convolution.
tensor_impl->set_sizes_contiguous(size);
get_tensor(input.dtype(), {}));
}
-TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
+TORCH_LIBRARY_IMPL(aten, ORT, m) {
m.impl("empty.memory_format", empty_override);
m.impl("add.out", add_out_override);
m.impl("convolution_overrideable", fake_convolution);
// TODO: Extend this to exercise multi-device setting. In that case,
// we need to add a thread local variable to track the current device.
-struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
- static constexpr DeviceType static_type = DeviceType::MSNPU;
- MSNPUGuardImpl() {}
- MSNPUGuardImpl(DeviceType t) {
- AT_ASSERT(t == DeviceType::MSNPU);
+struct ORTGuardImpl final : public c10::impl::DeviceGuardImplInterface {
+ static constexpr DeviceType static_type = DeviceType::ORT;
+ ORTGuardImpl() {}
+ ORTGuardImpl(DeviceType t) {
+ AT_ASSERT(t == DeviceType::ORT);
}
DeviceType type() const override {
- return DeviceType::MSNPU;
+ return DeviceType::ORT;
}
Device exchangeDevice(Device d) const override {
- AT_ASSERT(d.type() == DeviceType::MSNPU);
+ AT_ASSERT(d.type() == DeviceType::ORT);
AT_ASSERT(d.index() == 0);
return d;
}
Device getDevice() const override {
- return Device(DeviceType::MSNPU, 0);
+ return Device(DeviceType::ORT, 0);
}
void setDevice(Device d) const override {
- AT_ASSERT(d.type() == DeviceType::MSNPU);
+ AT_ASSERT(d.type() == DeviceType::ORT);
AT_ASSERT(d.index() == 0);
}
void uncheckedSetDevice(Device d) const noexcept override {
}
Stream getStream(Device d) const noexcept override {
- return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
+ return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0));
}
Stream exchangeStream(Stream s) const noexcept override {
- return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
+ return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0));
}
DeviceIndex deviceCount() const noexcept override {
return 1;
const Stream& stream,
const DeviceIndex device_index,
const EventFlag flag) const override {
- TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+ TORCH_CHECK(false, "ORT backend doesn't support events.");
}
void block(
void* event,
const Stream& stream) const override {
- TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+ TORCH_CHECK(false, "ORT backend doesn't support events.");
}
bool queryEvent(void* event) const override {
- TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+ TORCH_CHECK(false, "ORT backend doesn't support events.");
}
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override { }
};
-constexpr DeviceType MSNPUGuardImpl::static_type;
-C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);
+constexpr DeviceType ORTGuardImpl::static_type;
+C10_REGISTER_GUARD_IMPL(ORT, ORTGuardImpl);
int get_test_int() {
return test_int;
'torch_test_cpp_extension.cpp', ['extension.cpp'],
extra_compile_args=CXX_FLAGS),
CppExtension(
- 'torch_test_cpp_extension.msnpu', ['msnpu_extension.cpp'],
+ 'torch_test_cpp_extension.ort', ['ort_extension.cpp'],
extra_compile_args=CXX_FLAGS),
CppExtension(
'torch_test_cpp_extension.rng', ['rng_extension.cpp'],
try:
if HAS_PYTEST:
cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
- msnpu_extension = pytest.importorskip("torch_test_cpp_extension.msnpu")
+ ort_extension = pytest.importorskip("torch_test_cpp_extension.ort")
rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
else:
import torch_test_cpp_extension.cpp as cpp_extension
- import torch_test_cpp_extension.msnpu as msnpu_extension
+ import torch_test_cpp_extension.ort as ort_extension
import torch_test_cpp_extension.rng as rng_extension
except ImportError as e:
raise RuntimeError(
self.assertFalse(has_value)
-class TestMSNPUTensor(common.TestCase):
+class TestORTTensor(common.TestCase):
def test_unregistered(self):
a = torch.arange(0, 10, device='cpu')
with self.assertRaisesRegex(RuntimeError, "Could not run"):
- b = torch.arange(0, 10, device='msnpu')
+ b = torch.arange(0, 10, device='ort')
def test_zeros(self):
a = torch.empty(5, 5, device='cpu')
self.assertEqual(a.device, torch.device('cpu'))
- b = torch.empty(5, 5, device='msnpu')
- self.assertEqual(b.device, torch.device('msnpu', 0))
- self.assertEqual(msnpu_extension.get_test_int(), 0)
+ b = torch.empty(5, 5, device='ort')
+ self.assertEqual(b.device, torch.device('ort', 0))
+ self.assertEqual(ort_extension.get_test_int(), 0)
self.assertEqual(torch.get_default_dtype(), b.dtype)
- c = torch.empty((5, 5), dtype=torch.int64, device='msnpu')
- self.assertEqual(msnpu_extension.get_test_int(), 0)
+ c = torch.empty((5, 5), dtype=torch.int64, device='ort')
+ self.assertEqual(ort_extension.get_test_int(), 0)
self.assertEqual(torch.int64, c.dtype)
def test_add(self):
- a = torch.empty(5, 5, device='msnpu', requires_grad=True)
- self.assertEqual(msnpu_extension.get_test_int(), 0)
+ a = torch.empty(5, 5, device='ort', requires_grad=True)
+ self.assertEqual(ort_extension.get_test_int(), 0)
- b = torch.empty(5, 5, device='msnpu')
- self.assertEqual(msnpu_extension.get_test_int(), 0)
+ b = torch.empty(5, 5, device='ort')
+ self.assertEqual(ort_extension.get_test_int(), 0)
c = a + b
- self.assertEqual(msnpu_extension.get_test_int(), 1)
+ self.assertEqual(ort_extension.get_test_int(), 1)
def test_conv_backend_override(self):
# To simplify tests, we use 4d input here to avoid doing view4d( which
# needs more overrides) in _convolution.
- input = torch.empty(2, 4, 10, 2, device='msnpu', requires_grad=True)
- weight = torch.empty(6, 4, 2, 2, device='msnpu', requires_grad=True)
- bias = torch.empty(6, device='msnpu')
+ input = torch.empty(2, 4, 10, 2, device='ort', requires_grad=True)
+ weight = torch.empty(6, 4, 2, 2, device='ort', requires_grad=True)
+ bias = torch.empty(6, device='ort')
# Make sure forward is overridden
out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
- self.assertEqual(msnpu_extension.get_test_int(), 2)
+ self.assertEqual(ort_extension.get_test_int(), 2)
self.assertEqual(out.shape[0], input.shape[0])
self.assertEqual(out.shape[1], weight.shape[0])
# Double backward is dispatched to _convolution_double_backward.
# It is not tested here as it involves more computation/overrides.
grad = torch.autograd.grad(out, input, out, create_graph=True)
- self.assertEqual(msnpu_extension.get_test_int(), 3)
+ self.assertEqual(ort_extension.get_test_int(), 3)
self.assertEqual(grad[0].shape, input.shape)
self.assertExpectedInline(output_error, '''Found an invalid operator name: abs_BAD''')
# The backend is valid, but doesn't have a valid autograd key. They can't override autograd kernels in that case.
- # Only using MSNPU here because it has a valid backend key but not an autograd key- if this changes we can update the test.
+ # Only using Vulkan here because it has a valid backend key but not an autograd key; if this changes we can update the test.
def test_backend_has_no_autograd_key_but_provides_entries(self):
yaml_str = '''\
-backend: MSNPU
-cpp_namespace: torch_msnpu
+backend: Vulkan
+cpp_namespace: torch_vulkan
supported:
- add
autograd:
def test_backend_autograd_kernel_mismatch_out_functional(self):
yaml_str = '''\
backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
supported:
- add.Tensor
autograd:
def test_backend_autograd_kernel_mismatch_functional_inplace(self):
yaml_str = '''\
backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
supported:
- add.Tensor
autograd:
def test_op_appears_in_supported_and_autograd_lists(self):
yaml_str = '''\
backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
supported:
- add.Tensor
autograd:
# TODO: add torch.* tests when we have proper namespacing on ATen functions
# test_namespace(torch)
- def test_msnpu_error(self):
+ def test_ort_error(self):
with self.assertRaisesRegex(RuntimeError,
- "Could not run 'aten::empty.memory_format' with arguments from the 'MSNPU' backend"):
- torch.zeros(1, device=torch.device('msnpu'))
+ "Could not run 'aten::empty.memory_format' with arguments from the 'ORT' backend"):
+ torch.zeros(1, device=torch.device('ort'))
def test_has_storage(self):
self.assertIsNotNone(torch.tensor([]).storage())
"aten/src/ATen/detail/CPUGuardImpl.cpp",
"aten/src/ATen/detail/CUDAHooksInterface.cpp",
"aten/src/ATen/detail/HIPHooksInterface.cpp",
+ "aten/src/ATen/detail/ORTHooksInterface.cpp",
"aten/src/ATen/metal/Context.cpp",
"aten/src/ATen/native/AutogradComposite.cpp",
"aten/src/ATen/native/BatchLinearAlgebraKernel.cpp",
CUDA = auto()
HIP = auto()
FPGA = auto()
- MSNPU = auto()
+ ORT = auto()
XLA = auto()
Lazy = auto()
Vulkan = auto()
'is_sparse_csr' : ['is_sparse_csr: _bool'],
'is_quantized': ['is_quantized: _bool'],
'is_meta': ['is_meta: _bool'],
+ 'is_ort': ['is_ort: _bool'],
'is_mkldnn': ['is_mkldnn: _bool'],
'is_vulkan': ['is_vulkan: _bool'],
'storage_offset': ['def storage_offset(self) -> _int: ...'],
IDEEP = ...
HIP = ...
FPGA = ...
- MSNPU = ...
+ ORT = ...
XLA = ...
MLC = ...
HPU = ...
# does accurate alias tracking; however, the code below
# doesn't work because of
# https://github.com/pytorch/pytorch/issues/47442
- if self.is_sparse or self.device.type in ['xla', 'mlc', 'meta']:
+ if self.is_sparse or self.device.type in ['xla', 'mlc', 'ort', 'meta']:
new_tensor = self.clone()
else:
new_storage = self.storage().__deepcopy__(memo)
# See Note [Don't serialize hooks]
torch.utils.hooks.warn_if_has_hooks(self)
backward_hooks: Dict[Any, Any] = OrderedDict()
- # Note: Numpy array is chosen to be the rebuild component for XLA Tensor.
+ # Note: Numpy array is chosen to be the rebuild component for XLA, ORT, and MLC tensors.
# We considered a few options:
# 1. CPU tensor can't be used here.
# Otherwise in torch.load CPU storage is reconstructed with randomly
- # initialized data, moved onto XLA device, and then storage is updated
- # to the serialized content. This works perfectly for CPU/CUDA but not XLA.
- # XLA tensor is disconnected with storage so it doesn't get the update.
+ # initialized data, moved onto the backend device, and then the storage is updated
+ # to the serialized content. This works perfectly for CPU/CUDA but not for these backends;
+ # their tensors are disconnected from storage, so they don't get the update.
# 2. Python list is not a good fit due to performance reason.
# `tolist()` converts every single element in the tensor into python objects
# and serialize them one by one.
- if self.device.type == 'xla':
- arg_xla = (self.cpu().numpy(),
- self.dtype,
- str(self.device),
- self.requires_grad)
- return (torch._utils._rebuild_xla_tensor, arg_xla)
- if self.device.type == 'mlc':
- arg_mlc = (self.cpu().numpy(),
- self.dtype,
- str(self.device),
- self.requires_grad)
- return (torch._utils._rebuild_mlc_tensor, arg_mlc)
+ if self.device.type in ['xla', 'ort', 'mlc']:
+ return (torch._utils._rebuild_device_tensor_from_numpy, (self.cpu().numpy(),
+ self.dtype,
+ str(self.device),
+ self.requires_grad))
if self.device.type == 'meta':
# NB: This implementation BREAKS storage sharing. Current
# hypothesis is that no one cares for meta tensors.
raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout))
-def _rebuild_xla_tensor(data, dtype, device, requires_grad):
+def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
tensor.requires_grad = requires_grad
return tensor
-def _rebuild_mlc_tensor(data, dtype, device, requires_grad):
- tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
- tensor.requires_grad = requires_grad
- return tensor
+# Should not be used, only here to be able to load Tensors serialized with older versions of PyTorch
+_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy
+_rebuild_mlc_tensor = _rebuild_device_tensor_from_numpy
def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
return Py_TYPE(obj) == &THPDeviceType;
}
-PyObject * THPDevice_New(const at::Device& device);
+TORCH_API PyObject * THPDevice_New(const at::Device& device);
-void THPDevice_init(PyObject *module);
+TORCH_API void THPDevice_init(PyObject *module);
.value("IDEEP", c10::DeviceType::IDEEP)
.value("HIP", c10::DeviceType::HIP)
.value("FPGA", c10::DeviceType::FPGA)
- .value("MSNPU", c10::DeviceType::MSNPU)
+ .value("ORT", c10::DeviceType::ORT)
.value("XLA", c10::DeviceType::XLA)
.value("Lazy", c10::DeviceType::Lazy)
.value("MLC", c10::DeviceType::MLC)
END_HANDLE_TH_ERRORS
}
+PyObject *THPVariable_is_ort(THPVariable *self, void *unused)
+{
+ HANDLE_TH_ERRORS
+ if (check_has_torch_function((PyObject *)self)) {
+ return handle_torch_function_getter(self, "is_ort");
+ }
+ auto& self_ = THPVariable_Unpack(self);
+ return torch::autograd::utils::wrap(self_.is_ort());
+ END_HANDLE_TH_ERRORS
+}
+
PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
{
HANDLE_TH_ERRORS
{"is_sparse_csr", (getter)THPVariable_is_sparse_csr, nullptr, nullptr, nullptr},
{"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
{"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr},
+ {"is_ort", (getter)THPVariable_is_ort, nullptr, nullptr, nullptr},
{"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
{"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
{"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},
{"layout", "prim"}, {"T", "prim"},
{"ndim", "prim"}, {"name", "prim"},
{"real", "aten"}, {"imag", "aten"},
- {"retains_grad", "aten"},
+ {"retains_grad", "aten"}, {"is_ort", "prim"},
}},
{TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
auto kind = value_->type()->kind();
},
aliasAnalysisFromSchema()),
OperatorGenerator(
+ TORCH_SELECTIVE_SCHEMA("prim::is_ort(Tensor a) -> bool"),
+ [](Stack* stack) {
+ at::Tensor a;
+ pop(stack, a);
+ push(stack, a.is_ort());
+ },
+ aliasAnalysisFromSchema()),
+ OperatorGenerator(
TORCH_SELECTIVE_SCHEMA("prim::name(Tensor a) -> str?"),
[](Stack* stack) {
at::Tensor a;
return c10::DispatchKey::Meta;
case c10::DeviceType::HIP:
return c10::DispatchKey::HIP;
- case c10::DeviceType::MSNPU:
- return c10::DispatchKey::MSNPU;
+ case c10::DeviceType::ORT:
+ return c10::DispatchKey::ORT;
case c10::DeviceType::HPU:
return c10::DispatchKey::HPU;
default:
Tensor.retains_grad.__get__: lambda self: -1,
Tensor.is_meta.__get__: lambda self: -1,
Tensor.is_mlc.__get__: lambda self: -1,
+ Tensor.is_ort.__get__: lambda self: -1,
Tensor.is_mkldnn.__get__: lambda self: -1,
Tensor.is_quantized.__get__: lambda self: -1,
Tensor.is_sparse.__get__: lambda self: -1,