Add support for the ONNX Runtime Eager Mode backend (#58248)
authorAaron Bockover <abock@microsoft.com>
Fri, 20 Aug 2021 18:11:47 +0000 (11:11 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 18:17:13 +0000 (11:17 -0700)
Summary:
This PR implements the necessary hooks/stubs/enums/etc for complete ONNX Runtime (ORT) Eager Mode integration. The actual extension will live out of tree at https://github.com/pytorch/ort.

We have been [working on this at Microsoft](https://github.com/microsoft/onnxruntime-pytorch/tree/eager-ort/torch_onnxruntime) for the last few months, and are finally ready to contribute the PyTorch core changes upstream (nothing major or exciting, just the usual boilerplate for adding new backends).

The ORT backend will allow us to ferry [almost] all torch ops into granular ONNX kernels that ORT will eagerly execute against any devices it supports (therefore, we only need a single ORT backend from a PyTorch perspective).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58248

Reviewed By: astaff

Differential Revision: D30344992

Pulled By: albanD

fbshipit-source-id: 69082b32121246340d686e16653626114b7714b2

38 files changed:
aten/src/ATen/Context.h
aten/src/ATen/Version.cpp
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/core/op_registration/README.md
aten/src/ATen/detail/ORTHooksInterface.cpp [new file with mode: 0644]
aten/src/ATen/detail/ORTHooksInterface.h [new file with mode: 0644]
aten/src/ATen/templates/TensorBody.h
aten/src/ATen/test/extension_backend_test.cpp
c10/core/Backend.h
c10/core/Device.cpp
c10/core/DeviceType.cpp
c10/core/DeviceType.h
c10/core/DispatchKey.cpp
c10/core/DispatchKey.h
c10/core/DispatchKeySet.cpp
c10/core/DispatchKeySet.h
c10/core/TensorImpl.h
c10/core/TensorOptions.h
caffe2/proto/caffe2.proto
caffe2/proto/caffe2_pb2.pyi
test/cpp_extensions/ort_extension.cpp [moved from test/cpp_extensions/msnpu_extension.cpp with 78% similarity]
test/cpp_extensions/setup.py
test/test_cpp_extensions_aot.py
test/test_gen_backend_stubs.py
test/test_torch.py
tools/build_variables.bzl
tools/codegen/model.py
tools/pyi/gen_pyi.py
torch/_C/_autograd.pyi
torch/_tensor.py
torch/_utils.py
torch/csrc/Device.h
torch/csrc/autograd/init.cpp
torch/csrc/autograd/python_variable.cpp
torch/csrc/jit/frontend/sugared_value.cpp
torch/csrc/jit/runtime/register_prim_ops.cpp
torch/library.h
torch/overrides.py

index 26f1d11..4a45ac6 100644 (file)
@@ -9,6 +9,7 @@
 #include <ATen/core/LegacyTypeDispatch.h>
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <ATen/detail/HIPHooksInterface.h>
+#include <ATen/detail/ORTHooksInterface.h>
 #include <c10/util/Exception.h>
 #include <c10/core/impl/DeviceGuardImplInterface.h>
 #include <c10/core/QEngine.h>
@@ -79,6 +80,9 @@ class TORCH_API Context {
   static bool hasMLC() {
     return c10::impl::hasDeviceGuardImpl(at::DeviceType::MLC);
   }
+  static bool hasORT() {
+    return c10::impl::hasDeviceGuardImpl(at::DeviceType::ORT);
+  }
   // defined in header so that getNonVariableType has ability to inline
   // call_once check. getNonVariableType is called fairly frequently
   THCState* lazyInitCUDA() {
@@ -292,6 +296,10 @@ static inline bool hasMLC() {
   return globalContext().hasMLC();
 }
 
+static inline bool hasORT() {
+  return globalContext().hasORT();
+}
+
 // Despite its name, this function returns the number of *CUDA* GPUs.
 static inline size_t getNumGPUs() {
   // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
index 750c90b..0c0ea61 100644 (file)
@@ -184,6 +184,10 @@ std::string show_config() {
     ss << detail::getCUDAHooks().showConfig();
   }
 
+  if (hasORT()) {
+    ss << detail::getORTHooks().showConfig();
+  }
+
   ss << "  - Build settings: ";
   for (const auto& pair : caffe2::GetBuildOptions()) {
     if (!pair.second.empty()) {
index 584e3db..abdf397 100644 (file)
@@ -405,6 +405,7 @@ _(aten, is_complex) \
 _(aten, is_contiguous) \
 _(aten, is_cuda) \
 _(aten, is_mlc) \
+_(aten, is_ort) \
 _(aten, is_distributed) \
 _(aten, is_floating_point) \
 _(aten, is_inference) \
index edd9f91..5605e96 100644 (file)
@@ -13,13 +13,13 @@ There’s four main use cases
 * You’re writing a new operator that isn’t supposed to be part of the public PyTorch API.
 * You’re writing a new operator but don’t want to change the core pytorch code base, say you’re developing a shared library with operators.
 * You’re writing a C++ extension for PyTorch or you’re using inline c++ in your .py model files.
-* You’re writing a backend library like XLA or MSNPU that adds new kernels to all operators defined in `native_functions.yaml`.
+* You’re writing a backend library like XLA or ORT that adds new kernels to all operators defined in `native_functions.yaml`.
 
 For these use cases, the custom operator API is the better solution.
 
 ### What is the price for using the custom operator API instead of `native_functions.yaml`?
 
-If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/MSNPU example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.
+If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/ORT example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.
 
 * It will not get a C++ API generated. There will not be `Tensor::your_op()` methods or `at::your_op()` functions to call your operator.
 * The API for calling the operator from Python looks a little bit different. It needs to be called through `torch.ops.your_op()` instead of `torch._C`.
diff --git a/aten/src/ATen/detail/ORTHooksInterface.cpp b/aten/src/ATen/detail/ORTHooksInterface.cpp
new file mode 100644 (file)
index 0000000..33f7093
--- /dev/null
@@ -0,0 +1,31 @@
+#include <ATen/detail/ORTHooksInterface.h>
+
+#include <c10/util/Exception.h>
+
+#include <cstddef>
+#include <memory>
+#include <mutex>
+
+namespace at {
+namespace detail {
+
+// See getCUDAHooks for some more commentary
+const ORTHooksInterface& getORTHooks() {
+  static std::unique_ptr<ORTHooksInterface> ort_hooks;
+  static std::once_flag once;
+  std::call_once(once, [] {
+    ort_hooks = ORTHooksRegistry()->Create("ORTHooks", {});
+    if (!ort_hooks) {
+      ort_hooks =
+          // NOLINTNEXTLINE(modernize-make-unique)
+          std::unique_ptr<ORTHooksInterface>(new ORTHooksInterface());
+    }
+  });
+  return *ort_hooks;
+}
+} // namespace detail
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+C10_DEFINE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs)
+
+} // namespace at
diff --git a/aten/src/ATen/detail/ORTHooksInterface.h b/aten/src/ATen/detail/ORTHooksInterface.h
new file mode 100644 (file)
index 0000000..caee55c
--- /dev/null
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <c10/util/Exception.h>
+#include <c10/util/Registry.h>
+
+constexpr const char* ORT_HELP =
+  " You need to 'import torch_ort' to use the 'ort' device in PyTorch. "
+  "The 'torch_ort' module is provided by the ONNX Runtime itself "
+  "(https://onnxruntime.ai).";
+
+// NB: Class must live in `at` due to limitations of Registry.h.
+namespace at {
+
+struct TORCH_API ORTHooksInterface {
+  // This should never actually be implemented, but it is used to
+  // squelch -Werror=non-virtual-dtor
+  virtual ~ORTHooksInterface() {}
+
+  virtual std::string showConfig() const {
+    TORCH_CHECK(false, "Cannot query detailed ORT version information.", ORT_HELP);
+  }
+};
+
+// NB: dummy argument to suppress "ISO C++11 requires at least one argument
+// for the "..." in a variadic macro"
+struct TORCH_API ORTHooksArgs {};
+
+C10_DECLARE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs);
+#define REGISTER_ORT_HOOKS(clsname) \
+  C10_REGISTER_CLASS(ORTHooksRegistry, clsname, clsname)
+
+namespace detail {
+TORCH_API const ORTHooksInterface& getORTHooks();
+} // namespace detail
+
+} // namespace at
index be14980..a6e6583 100644 (file)
@@ -492,6 +492,12 @@ class TORCH_API Tensor {
     return impl_->is_mlc();
   }
 
+  /// Returns if a `Tensor` is ort tensor.
+  bool is_ort() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_ort();
+  }
+
   /// Returns if a `Tensor` is vulkan tensor.
   bool is_vulkan() const {
     // NB: this is not a native function to avoid dispatching overhead.
index 531507e..9b215a9 100644 (file)
@@ -6,6 +6,11 @@
 
 #include <torch/csrc/jit/runtime/operator.h>
 
+// NB. These tests use the ORT dispatch key to test backend dispatching
+// machinery, but these tests are not specific to ORT at all. The ORT
+// backend is fully out-of-tree, so it's safe to use this key for
+// in-tree tests.
+
 using namespace at;
 
 static int test_int;
@@ -17,16 +22,16 @@ Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::op
       Storage(
           Storage::use_byte_size_t(),
           0,
-          at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)),
+          at::DataPtr(nullptr, Device(DeviceType::ORT, 1)),
           nullptr,
           false),
-      DispatchKey::MSNPU,
+      DispatchKey::ORT,
       caffe2::TypeMeta::Make<float>());
   return Tensor(std::move(tensor_impl));
 }
 
 Tensor add_override(const Tensor & a, const Tensor & b , const Scalar& c) {
-  auto out = empty({5, 5}, at::kMSNPU);  // Don't return self as-is
+  auto out = empty({5, 5}, at::kORT);  // Don't return self as-is
   test_int = 2;
   return out;
 }
@@ -42,28 +47,28 @@ Tensor empty_strided_override(
   return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt);
 }
 
-TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
+TORCH_LIBRARY_IMPL(aten, ORT, m) {
   m.impl("aten::empty.memory_format",  empty_override);
   m.impl("aten::empty_strided",        empty_strided_override);
   m.impl("aten::add.Tensor",           add_override);
 }
 
 TEST(BackendExtensionTest, TestRegisterOp) {
-  Tensor a = empty({5, 5}, at::kMSNPU);
-  ASSERT_EQ(a.device().type(), at::kMSNPU);
+  Tensor a = empty({5, 5}, at::kORT);
+  ASSERT_EQ(a.device().type(), at::kORT);
   ASSERT_EQ(a.device().index(), 1);
   ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
   ASSERT_EQ(test_int, 1);
 
-  Tensor b = empty_like(a, at::kMSNPU);
-  ASSERT_EQ(b.device().type(), at::kMSNPU);
+  Tensor b = empty_like(a, at::kORT);
+  ASSERT_EQ(b.device().type(), at::kORT);
   ASSERT_EQ(b.device().index(), 1);
   ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make<float>());
 
   add(a, b);
   ASSERT_EQ(test_int, 2);
 
-  // Ensure that non-MSNPU operator still works
+  // Ensure that non-ORT operator still works
   Tensor d = empty({5, 5}, at::kCPU);
   ASSERT_EQ(d.device().type(), at::kCPU);
 }
index 2f07134..e17a1bc 100644 (file)
@@ -40,7 +40,7 @@ enum class Backend {
   SparseHIP,
   SparseVE,
   SparseXPU,
-  MSNPU,
+  ORT,
   XLA,
   Vulkan,
   Metal,
@@ -66,8 +66,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::VE;
   } else if (t == DispatchKey::FPGA) {
     return Backend::FPGA;
-  } else if (t == DispatchKey::MSNPU) {
-    return Backend::MSNPU;
+  } else if (t == DispatchKey::ORT) {
+    return Backend::ORT;
   } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
     return Backend::XLA;
   } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) {
@@ -123,8 +123,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
       return DispatchKey::VE;
     case Backend::FPGA:
       return DispatchKey::FPGA;
-    case Backend::MSNPU:
-      return DispatchKey::MSNPU;
+    case Backend::ORT:
+      return DispatchKey::ORT;
     case Backend::XLA:
       return DispatchKey::XLA;
     case Backend::Lazy:
@@ -178,8 +178,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
       return DeviceType::VE;
     case Backend::FPGA:
       return DeviceType::FPGA;
-    case Backend::MSNPU:
-      return DeviceType::MSNPU;
+    case Backend::ORT:
+      return DeviceType::ORT;
     case Backend::XLA:
       return DeviceType::XLA;
     case Backend::Lazy:
@@ -235,8 +235,8 @@ static inline const char* toString(Backend b) {
       return "FPGA";
     case Backend::XPU:
       return "XPU";
-    case Backend::MSNPU:
-      return "MSNPU";
+    case Backend::ORT:
+      return "ORT";
     case Backend::XLA:
       return "XLA";
     case Backend::Lazy:
index 2709c29..2531e39 100644 (file)
@@ -28,7 +28,7 @@ DeviceType parse_type(const std::string& device_string) {
           {"hip", DeviceType::HIP},
           {"ve", DeviceType::VE},
           {"fpga", DeviceType::FPGA},
-          {"msnpu", DeviceType::MSNPU},
+          {"ort", DeviceType::ORT},
           {"xla", DeviceType::XLA},
           {"lazy", DeviceType::Lazy},
           {"vulkan", DeviceType::Vulkan},
@@ -47,7 +47,7 @@ DeviceType parse_type(const std::string& device_string) {
   }
   TORCH_CHECK(
       false,
-      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, msnpu, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
+      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
       device_string);
 }
 enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };
index 4ff9398..4635acd 100644 (file)
@@ -25,8 +25,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
       return lower_case ? "ve" : "VE";
     case DeviceType::FPGA:
       return lower_case ? "fpga" : "FPGA";
-    case DeviceType::MSNPU:
-      return lower_case ? "msnpu" : "MSNPU";
+    case DeviceType::ORT:
+      return lower_case ? "ort" : "ORT";
     case DeviceType::XLA:
       return lower_case ? "xla" : "XLA";
     case DeviceType::Lazy:
@@ -75,7 +75,7 @@ bool isValidDeviceType(DeviceType d) {
     case DeviceType::HIP:
     case DeviceType::VE:
     case DeviceType::FPGA:
-    case DeviceType::MSNPU:
+    case DeviceType::ORT:
     case DeviceType::XLA:
     case DeviceType::Lazy:
     case DeviceType::MLC:
index 2ae028d..c6bd569 100644 (file)
@@ -21,7 +21,7 @@ enum class DeviceType : int8_t {
   IDEEP = 5, // IDEEP.
   HIP = 6, // AMD HIP
   FPGA = 7, // FPGA
-  MSNPU = 8, // MSNPU
+  ORT = 8, // ONNX Runtime / Microsoft
   XLA = 9, // XLA / TPU
   Vulkan = 10, // Vulkan
   Metal = 11, // Metal
@@ -42,7 +42,7 @@ constexpr DeviceType kCPU = DeviceType::CPU;
 constexpr DeviceType kCUDA = DeviceType::CUDA;
 constexpr DeviceType kHIP = DeviceType::HIP;
 constexpr DeviceType kFPGA = DeviceType::FPGA;
-constexpr DeviceType kMSNPU = DeviceType::MSNPU;
+constexpr DeviceType kORT = DeviceType::ORT;
 constexpr DeviceType kXLA = DeviceType::XLA;
 constexpr DeviceType kMLC = DeviceType::MLC;
 constexpr DeviceType kMeta = DeviceType::Meta;
index 5c41448..18aa4fc 100644 (file)
@@ -19,8 +19,8 @@ const char* toString(DispatchKey t) {
       return "FPGA";
     case DispatchKey::XPU:
       return "XPU";
-    case DispatchKey::MSNPU:
-      return "MSNPU";
+    case DispatchKey::ORT:
+      return "ORT";
     case DispatchKey::XLA:
       return "XLA";
     case DispatchKey::Lazy:
index 5b20a1c..07222b7 100644 (file)
@@ -59,8 +59,15 @@ enum class DispatchKey : uint8_t {
   // CUDA]
   FPGA, // Xilinx support lives out of tree at
   // https://gitlab.com/pytorch-complex/vitis_kernels
-  MSNPU, // unused externally, but tested at
-  // test/cpp_extensions/msnpu_extension.cpp
+
+  // ONNX Runtime, lives out of tree at https://github.com/pytorch/ort and
+  // https://github.com/microsoft/onnxruntime, and is also used to test general
+  // backend/extension machinery in the core. cf:
+  // - test/cpp_extensions/ort_extension.cpp
+  // - test/test_torch.py
+  // - aten/src/ATen/test/extension_backend_test.cpp
+  ORT,
+
   XLA, // lives out of tree at https://github.com/pytorch/xla
   MLC, // lives out of tree at https://github.com/pytorch/MLCompute
   Vulkan,
@@ -114,7 +121,7 @@ enum class DispatchKey : uint8_t {
 
   // Here are reserved backends for user-defined backends, see Note [Private use
   // DispatchKey]
-  // To see some example about how to use this, check out MSNPU
+  // To see some example about how to use this, check out ORT
   PrivateUse1,
   PrivateUse2,
   PrivateUse3,
index b796114..404acc7 100644 (file)
@@ -19,6 +19,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
         DispatchKey::PrivateUse3,
         DispatchKey::MLC,
         DispatchKey::HPU,
+        DispatchKey::ORT,
         DispatchKey::Meta,
     });
 
index 0d3a25e..b1f5f04 100644 (file)
@@ -248,7 +248,7 @@ constexpr DispatchKeySet autogradother_backends = DispatchKeySet(
     {DispatchKey::HIP,
      DispatchKey::VE,
      DispatchKey::FPGA,
-     DispatchKey::MSNPU,
+     DispatchKey::ORT,
      DispatchKey::Vulkan,
      DispatchKey::Metal,
      DispatchKey::QuantizedCPU,
index 65d7af3..7051e36 100644 (file)
@@ -873,6 +873,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return key_set_.has(DispatchKey::MLC);
   }
 
+  bool is_ort() const {
+    return key_set_.has(DispatchKey::ORT);
+  }
+
   // TODO: remove this once we don't automatically enabled Autograd dispatch
   // keys
   //       in TensorImpl constructor.
index fff9433..287b2fa 100644 (file)
@@ -663,8 +663,8 @@ inline DispatchKey computeDispatchKey(
           return DispatchKey::VE;
         case DeviceType::FPGA:
           return DispatchKey::FPGA;
-        case DeviceType::MSNPU:
-          return DispatchKey::MSNPU;
+        case DeviceType::ORT:
+          return DispatchKey::ORT;
         case DeviceType::XLA:
           return DispatchKey::XLA;
         case DeviceType::Lazy:
@@ -790,10 +790,8 @@ inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
     case DispatchKey::HPU:
     case DispatchKey::AutogradHPU:
       return DeviceType::HPU;
-
-    // stuff that isn't real
-    case DispatchKey::MSNPU:
-      return DeviceType::MSNPU;
+    case DispatchKey::ORT:
+      return DeviceType::ORT;
     default:
       TORCH_CHECK(
           false,
index 6e05577..90a2020 100644 (file)
@@ -219,7 +219,7 @@ enum DeviceTypeProto {
   PROTO_IDEEP = 5;                  // IDEEP.
   PROTO_HIP = 6;                    // AMD HIP
   PROTO_FPGA = 7;                   // FPGA
-  PROTO_MSNPU = 8;                  // MSNPU
+  PROTO_ORT = 8;                    // ONNX Runtime
   PROTO_XLA = 9;                    // XLA / TPU
   PROTO_MLC = 10;                   // ML Compute
   // Change the following number if you add more devices in the code.
index 1258664..f7f4430 100644 (file)
@@ -23,7 +23,7 @@ class _DeviceTypeProto(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapp
     PROTO_IDEEP = DeviceTypeProto.V(5)
     PROTO_HIP = DeviceTypeProto.V(6)
     PROTO_FPGA = DeviceTypeProto.V(7)
-    PROTO_MSNPU = DeviceTypeProto.V(8)
+    PROTO_ORT = DeviceTypeProto.V(8)
     PROTO_XLA = DeviceTypeProto.V(9)
     PROTO_MLC = DeviceTypeProto.V(10)
     PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11)
@@ -37,7 +37,7 @@ PROTO_OPENCL = DeviceTypeProto.V(4)
 PROTO_IDEEP = DeviceTypeProto.V(5)
 PROTO_HIP = DeviceTypeProto.V(6)
 PROTO_FPGA = DeviceTypeProto.V(7)
-PROTO_MSNPU = DeviceTypeProto.V(8)
+PROTO_ORT = DeviceTypeProto.V(8)
 PROTO_XLA = DeviceTypeProto.V(9)
 PROTO_MLC = DeviceTypeProto.V(10)
 PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11)
similarity index 78%
rename from test/cpp_extensions/msnpu_extension.cpp
rename to test/cpp_extensions/ort_extension.cpp
index e47347c..b646f3b 100644 (file)
@@ -10,10 +10,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
       Storage(
           Storage::use_byte_size_t(),
           0,
-          at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)),
+          at::DataPtr(nullptr, Device(DeviceType::ORT, 0)),
           nullptr,
           false),
-      DispatchKey::MSNPU,
+      DispatchKey::ORT,
       dtype);
   // This is a hack to workaround the shape checks in _convolution.
   tensor_impl->set_sizes_contiguous(size);
@@ -52,7 +52,7 @@ std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
             get_tensor(input.dtype(), {}));
 }
 
-TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
+TORCH_LIBRARY_IMPL(aten, ORT, m) {
   m.impl("empty.memory_format",                empty_override);
   m.impl("add.out",                            add_out_override);
   m.impl("convolution_overrideable",           fake_convolution);
@@ -61,34 +61,34 @@ TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
 
 // TODO: Extend this to exercise multi-device setting.  In that case,
 // we need to add a thread local variable to track the current device.
-struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
-  static constexpr DeviceType static_type = DeviceType::MSNPU;
-  MSNPUGuardImpl() {}
-  MSNPUGuardImpl(DeviceType t) {
-    AT_ASSERT(t == DeviceType::MSNPU);
+struct ORTGuardImpl final : public c10::impl::DeviceGuardImplInterface {
+  static constexpr DeviceType static_type = DeviceType::ORT;
+  ORTGuardImpl() {}
+  ORTGuardImpl(DeviceType t) {
+    AT_ASSERT(t == DeviceType::ORT);
   }
   DeviceType type() const override {
-    return DeviceType::MSNPU;
+    return DeviceType::ORT;
   }
   Device exchangeDevice(Device d) const override {
-    AT_ASSERT(d.type() == DeviceType::MSNPU);
+    AT_ASSERT(d.type() == DeviceType::ORT);
     AT_ASSERT(d.index() == 0);
     return d;
   }
   Device getDevice() const override {
-    return Device(DeviceType::MSNPU, 0);
+    return Device(DeviceType::ORT, 0);
   }
   void setDevice(Device d) const override {
-    AT_ASSERT(d.type() == DeviceType::MSNPU);
+    AT_ASSERT(d.type() == DeviceType::ORT);
     AT_ASSERT(d.index() == 0);
   }
   void uncheckedSetDevice(Device d) const noexcept override {
   }
   Stream getStream(Device d) const noexcept override {
-    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
+    return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0));
   }
   Stream exchangeStream(Stream s) const noexcept override {
-    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
+    return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0));
   }
   DeviceIndex deviceCount() const noexcept override {
     return 1;
@@ -99,23 +99,23 @@ struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     const Stream& stream,
     const DeviceIndex device_index,
     const EventFlag flag) const override {
-    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+    TORCH_CHECK(false, "ORT backend doesn't support events.");
   }
   void block(
     void* event,
     const Stream& stream) const override {
-    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+    TORCH_CHECK(false, "ORT backend doesn't support events.");
   }
   bool queryEvent(void* event) const override {
-    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+    TORCH_CHECK(false, "ORT backend doesn't support events.");
   }
   void destroyEvent(
     void* event,
     const DeviceIndex device_index) const noexcept override { }
 };
 
-constexpr DeviceType MSNPUGuardImpl::static_type;
-C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);
+constexpr DeviceType ORTGuardImpl::static_type;
+C10_REGISTER_GUARD_IMPL(ORT, ORTGuardImpl);
 
 int get_test_int() {
   return test_int;
index 8f77938..7888d0e 100644 (file)
@@ -21,7 +21,7 @@ ext_modules = [
         'torch_test_cpp_extension.cpp', ['extension.cpp'],
         extra_compile_args=CXX_FLAGS),
     CppExtension(
-        'torch_test_cpp_extension.msnpu', ['msnpu_extension.cpp'],
+        'torch_test_cpp_extension.ort', ['ort_extension.cpp'],
         extra_compile_args=CXX_FLAGS),
     CppExtension(
         'torch_test_cpp_extension.rng', ['rng_extension.cpp'],
index 307df0e..cf35e6b 100644 (file)
@@ -19,11 +19,11 @@ except ImportError as e:
 try:
     if HAS_PYTEST:
         cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
-        msnpu_extension = pytest.importorskip("torch_test_cpp_extension.msnpu")
+        ort_extension = pytest.importorskip("torch_test_cpp_extension.ort")
         rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
     else:
         import torch_test_cpp_extension.cpp as cpp_extension
-        import torch_test_cpp_extension.msnpu as msnpu_extension
+        import torch_test_cpp_extension.ort as ort_extension
         import torch_test_cpp_extension.rng as rng_extension
 except ImportError as e:
     raise RuntimeError(
@@ -100,45 +100,45 @@ class TestCppExtensionAOT(common.TestCase):
         self.assertFalse(has_value)
 
 
-class TestMSNPUTensor(common.TestCase):
+class TestORTTensor(common.TestCase):
     def test_unregistered(self):
         a = torch.arange(0, 10, device='cpu')
         with self.assertRaisesRegex(RuntimeError, "Could not run"):
-            b = torch.arange(0, 10, device='msnpu')
+            b = torch.arange(0, 10, device='ort')
 
     def test_zeros(self):
         a = torch.empty(5, 5, device='cpu')
         self.assertEqual(a.device, torch.device('cpu'))
 
-        b = torch.empty(5, 5, device='msnpu')
-        self.assertEqual(b.device, torch.device('msnpu', 0))
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        b = torch.empty(5, 5, device='ort')
+        self.assertEqual(b.device, torch.device('ort', 0))
+        self.assertEqual(ort_extension.get_test_int(), 0)
         self.assertEqual(torch.get_default_dtype(), b.dtype)
 
-        c = torch.empty((5, 5), dtype=torch.int64, device='msnpu')
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        c = torch.empty((5, 5), dtype=torch.int64, device='ort')
+        self.assertEqual(ort_extension.get_test_int(), 0)
         self.assertEqual(torch.int64, c.dtype)
 
     def test_add(self):
-        a = torch.empty(5, 5, device='msnpu', requires_grad=True)
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        a = torch.empty(5, 5, device='ort', requires_grad=True)
+        self.assertEqual(ort_extension.get_test_int(), 0)
 
-        b = torch.empty(5, 5, device='msnpu')
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        b = torch.empty(5, 5, device='ort')
+        self.assertEqual(ort_extension.get_test_int(), 0)
 
         c = a + b
-        self.assertEqual(msnpu_extension.get_test_int(), 1)
+        self.assertEqual(ort_extension.get_test_int(), 1)
 
     def test_conv_backend_override(self):
         # To simplify tests, we use 4d input here to avoid doing view4d( which
         # needs more overrides) in _convolution.
-        input = torch.empty(2, 4, 10, 2, device='msnpu', requires_grad=True)
-        weight = torch.empty(6, 4, 2, 2, device='msnpu', requires_grad=True)
-        bias = torch.empty(6, device='msnpu')
+        input = torch.empty(2, 4, 10, 2, device='ort', requires_grad=True)
+        weight = torch.empty(6, 4, 2, 2, device='ort', requires_grad=True)
+        bias = torch.empty(6, device='ort')
 
         # Make sure forward is overriden
         out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
-        self.assertEqual(msnpu_extension.get_test_int(), 2)
+        self.assertEqual(ort_extension.get_test_int(), 2)
         self.assertEqual(out.shape[0], input.shape[0])
         self.assertEqual(out.shape[1], weight.shape[0])
 
@@ -146,7 +146,7 @@ class TestMSNPUTensor(common.TestCase):
         # Double backward is dispatched to _convolution_double_backward.
         # It is not tested here as it involves more computation/overrides.
         grad = torch.autograd.grad(out, input, out, create_graph=True)
-        self.assertEqual(msnpu_extension.get_test_int(), 3)
+        self.assertEqual(ort_extension.get_test_int(), 3)
         self.assertEqual(grad[0].shape, input.shape)
 
 
index e1a66c6..f788a8f 100644 (file)
@@ -138,11 +138,11 @@ supported:
         self.assertExpectedInline(output_error, '''Found an invalid operator name: abs_BAD''')
 
     # The backend is valid, but doesn't have a valid autograd key. They can't override autograd kernels in that case.
-    # Only using MSNPU here because it has a valid backend key but not an autograd key- if this changes we can update the test.
+    # Only using Vulkan here because it has a valid backend key but not an autograd key- if this changes we can update the test.
     def test_backend_has_no_autograd_key_but_provides_entries(self):
         yaml_str = '''\
-backend: MSNPU
-cpp_namespace: torch_msnpu
+backend: Vulkan
+cpp_namespace: torch_vulkan
 supported:
 - add
 autograd:
@@ -155,7 +155,7 @@ autograd:
     def test_backend_autograd_kernel_mismatch_out_functional(self):
         yaml_str = '''\
 backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:
@@ -168,7 +168,7 @@ autograd:
     def test_backend_autograd_kernel_mismatch_functional_inplace(self):
         yaml_str = '''\
 backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:
@@ -182,7 +182,7 @@ autograd:
     def test_op_appears_in_supported_and_autograd_lists(self):
         yaml_str = '''\
 backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:
index 515052a..d0f631a 100644 (file)
@@ -221,10 +221,10 @@ class AbstractTestCases:
             # TODO: add torch.* tests when we have proper namespacing on ATen functions
             # test_namespace(torch)
 
-        def test_msnpu_error(self):
+        def test_ort_error(self):
             with self.assertRaisesRegex(RuntimeError,
-                                        "Could not run 'aten::empty.memory_format' with arguments from the 'MSNPU' backend"):
-                torch.zeros(1, device=torch.device('msnpu'))
+                                        "Could not run 'aten::empty.memory_format' with arguments from the 'ORT' backend"):
+                torch.zeros(1, device=torch.device('ort'))
 
         def test_has_storage(self):
             self.assertIsNotNone(torch.tensor([]).storage())
index 89697b4..e20d973 100644 (file)
@@ -829,6 +829,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/detail/CPUGuardImpl.cpp",
     "aten/src/ATen/detail/CUDAHooksInterface.cpp",
     "aten/src/ATen/detail/HIPHooksInterface.cpp",
+    "aten/src/ATen/detail/ORTHooksInterface.cpp",
     "aten/src/ATen/metal/Context.cpp",
     "aten/src/ATen/native/AutogradComposite.cpp",
     "aten/src/ATen/native/BatchLinearAlgebraKernel.cpp",
index d6f02d5..4f82b70 100644 (file)
@@ -56,7 +56,7 @@ class DispatchKey(Enum):
     CUDA = auto()
     HIP = auto()
     FPGA = auto()
-    MSNPU = auto()
+    ORT = auto()
     XLA = auto()
     Lazy = auto()
     Vulkan = auto()
index 4f39fec..882b7f1 100644 (file)
@@ -469,6 +469,7 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
         'is_sparse_csr' : ['is_sparse_csr: _bool'],
         'is_quantized': ['is_quantized: _bool'],
         'is_meta': ['is_meta: _bool'],
+        'is_ort': ['is_ort: _bool'],
         'is_mkldnn': ['is_mkldnn: _bool'],
         'is_vulkan': ['is_vulkan: _bool'],
         'storage_offset': ['def storage_offset(self) -> _int: ...'],
index 6468eb5..7ffb618 100644 (file)
@@ -24,7 +24,7 @@ class DeviceType(Enum):
     IDEEP = ...
     HIP = ...
     FPGA = ...
-    MSNPU = ...
+    ORT = ...
     XLA = ...
     MLC = ...
     HPU = ...
index 2bd617d..b4cee9a 100644 (file)
@@ -90,7 +90,7 @@ class Tensor(torch._C._TensorBase):
             # does accurate alias tracking; however, the code below
             # doesn't work because of
             # https://github.com/pytorch/pytorch/issues/47442
-            if self.is_sparse or self.device.type in ['xla', 'mlc', 'meta']:
+            if self.is_sparse or self.device.type in ['xla', 'mlc', 'ort', 'meta']:
                 new_tensor = self.clone()
             else:
                 new_storage = self.storage().__deepcopy__(memo)
@@ -153,28 +153,21 @@ class Tensor(torch._C._TensorBase):
         # See Note [Don't serialize hooks]
         torch.utils.hooks.warn_if_has_hooks(self)
         backward_hooks: Dict[Any, Any] = OrderedDict()
-        # Note: Numpy array is chosen to be the rebuild component for XLA Tensor.
+        # Note: Numpy array is chosen to be the rebuild component for XLA, ORT, MLC Tensors.
         # We considered a few options:
         # 1. CPU tensor can't be used here.
         #    Otherwise in torch.load CPU storage is reconstructed with randomly
-        #    initialized data, moved onto XLA device, and then storage is updated
-        #    to the serialized content. This works perfectly for CPU/CUDA but not XLA.
-        #    XLA tensor is disconnected with storage so it doesn't get the update.
+        #    initialized data, moved onto backend device, and then storage is updated
+        #    to the serialized content. This works perfectly for CPU/CUDA but not these backends;
+        #    their tensors are disconnected with storage so they don't get the update.
         # 2. Python list is not a good fit due to performance reason.
         #    `tolist()` converts every single element in the tensor into python objects
         #    and serialize them one by one.
-        if self.device.type == 'xla':
-            arg_xla = (self.cpu().numpy(),
-                       self.dtype,
-                       str(self.device),
-                       self.requires_grad)
-            return (torch._utils._rebuild_xla_tensor, arg_xla)
-        if self.device.type == 'mlc':
-            arg_mlc = (self.cpu().numpy(),
-                       self.dtype,
-                       str(self.device),
-                       self.requires_grad)
-            return (torch._utils._rebuild_mlc_tensor, arg_mlc)
+        if self.device.type in ['xla', 'ort', 'mlc']:
+            return (torch._utils._rebuild_device_tensor_from_numpy, (self.cpu().numpy(),
+                                                                     self.dtype,
+                                                                     str(self.device),
+                                                                     self.requires_grad))
         if self.device.type == 'meta':
             # NB: This implementation BREAKS storage sharing.  Current
             # hypothesis is that no one cares for meta tensors.
index 210b0cd..75e9075 100644 (file)
@@ -173,16 +173,15 @@ def _rebuild_sparse_tensor(layout, data):
     raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout))
 
 
-def _rebuild_xla_tensor(data, dtype, device, requires_grad):
+def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
     tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
     tensor.requires_grad = requires_grad
     return tensor
 
 
-def _rebuild_mlc_tensor(data, dtype, device, requires_grad):
-    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
-    tensor.requires_grad = requires_grad
-    return tensor
+# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
+_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy
+_rebuild_mlc_tensor = _rebuild_device_tensor_from_numpy
 
 
 def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
index b1f18dc..3286812 100644 (file)
@@ -17,6 +17,6 @@ inline bool THPDevice_Check(PyObject *obj) {
   return Py_TYPE(obj) == &THPDeviceType;
 }
 
-PyObject * THPDevice_New(const at::Device& device);
+TORCH_API PyObject * THPDevice_New(const at::Device& device);
 
-void THPDevice_init(PyObject *module);
+TORCH_API void THPDevice_init(PyObject *module);
index 2eacbf1..697ca87 100644 (file)
@@ -114,7 +114,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
       .value("IDEEP", c10::DeviceType::IDEEP)
       .value("HIP", c10::DeviceType::HIP)
       .value("FPGA", c10::DeviceType::FPGA)
-      .value("MSNPU", c10::DeviceType::MSNPU)
+      .value("ORT", c10::DeviceType::ORT)
       .value("XLA", c10::DeviceType::XLA)
       .value("Lazy", c10::DeviceType::Lazy)
       .value("MLC", c10::DeviceType::MLC)
index 3035846..50d6eb9 100644 (file)
@@ -834,6 +834,17 @@ PyObject *THPVariable_is_mlc(THPVariable *self, void *unused)
   END_HANDLE_TH_ERRORS
 }
 
+PyObject *THPVariable_is_ort(THPVariable *self, void *unused)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject *)self)) {
+    return handle_torch_function_getter(self, "is_ort");
+  }
+  auto& self_ = THPVariable_Unpack(self);
+  return torch::autograd::utils::wrap(self_.is_ort());
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
 {
   HANDLE_TH_ERRORS
@@ -980,6 +991,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
   {"is_sparse_csr", (getter)THPVariable_is_sparse_csr, nullptr, nullptr, nullptr},
   {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
   {"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr},
+  {"is_ort", (getter)THPVariable_is_ort, nullptr, nullptr, nullptr},
   {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
   {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
   {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},
index ab70d6c..a5f0007 100644 (file)
@@ -119,7 +119,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
            {"layout", "prim"},        {"T", "prim"},
            {"ndim", "prim"},          {"name", "prim"},
            {"real", "aten"},          {"imag", "aten"},
-           {"retains_grad", "aten"},
+           {"retains_grad", "aten"},  {"is_ort", "prim"},
        }},
       {TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
   auto kind = value_->type()->kind();
index a61cb48..984073f 100644 (file)
@@ -2212,6 +2212,14 @@ RegisterOperators reg1(
          },
          aliasAnalysisFromSchema()),
      OperatorGenerator(
+         TORCH_SELECTIVE_SCHEMA("prim::is_ort(Tensor a) -> bool"),
+         [](Stack* stack) {
+           at::Tensor a;
+           pop(stack, a);
+           push(stack, a.is_ort());
+         },
+         aliasAnalysisFromSchema()),
+     OperatorGenerator(
          TORCH_SELECTIVE_SCHEMA("prim::name(Tensor a) -> str?"),
          [](Stack* stack) {
            at::Tensor a;
index ce2bb92..a873b42 100644 (file)
@@ -317,8 +317,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
         return c10::DispatchKey::Meta;
       case c10::DeviceType::HIP:
         return c10::DispatchKey::HIP;
-      case c10::DeviceType::MSNPU:
-        return c10::DispatchKey::MSNPU;
+      case c10::DeviceType::ORT:
+        return c10::DispatchKey::ORT;
       case c10::DeviceType::HPU:
         return c10::DispatchKey::HPU;
       default:
index 5a0ea6c..09748b9 100644 (file)
@@ -1030,6 +1030,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.retains_grad.__get__: lambda self: -1,
         Tensor.is_meta.__get__: lambda self: -1,
         Tensor.is_mlc.__get__: lambda self: -1,
+        Tensor.is_ort.__get__: lambda self: -1,
         Tensor.is_mkldnn.__get__: lambda self: -1,
         Tensor.is_quantized.__get__: lambda self: -1,
         Tensor.is_sparse.__get__: lambda self: -1,