Expose backend extensions to python
author: Roy Li <royboy@fb.com>
Fri, 1 Feb 2019 18:55:00 +0000 (10:55 -0800)
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 1 Feb 2019 19:00:18 +0000 (11:00 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/16582

Reviewed By: gchanan

Differential Revision: D13887539

fbshipit-source-id: 8755babf2e3e849af974655f2f3a91740efe977e

c10/core/Device.cpp
test/cpp_extensions/msnpu_extension.cpp [new file with mode: 0644]
test/cpp_extensions/setup.py
test/test_cpp_extensions.py
tools/autograd/gen_python_functions.py
torch/csrc/utils/tensor_layouts.cpp

index ad46ce7..1d2d1ec 100644 (file)
@@ -13,7 +13,7 @@
 namespace c10 {
 namespace {
 DeviceType parse_type(const std::string& device_string) {
-  static const std::array<std::pair<std::string, DeviceType>, 7> types = {{
+  static const std::array<std::pair<std::string, DeviceType>, 8> types = {{
       {"cpu", DeviceType::CPU},
       {"cuda", DeviceType::CUDA},
       {"mkldnn", DeviceType::MKLDNN},
@@ -21,6 +21,7 @@ DeviceType parse_type(const std::string& device_string) {
       {"opencl", DeviceType::OPENCL},
       {"ideep", DeviceType::IDEEP},
       {"hip", DeviceType::HIP},
+      {"msnpu", DeviceType::MSNPU},
   }};
   auto device = std::find_if(
       types.begin(),
@@ -32,7 +33,7 @@ DeviceType parse_type(const std::string& device_string) {
     return device->second;
   }
   AT_ERROR(
-      "Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, or hip device type at start of device string: ", device_string);
+      "Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu device type at start of device string: ", device_string);
 }
 } // namespace
 
diff --git a/test/cpp_extensions/msnpu_extension.cpp b/test/cpp_extensions/msnpu_extension.cpp
new file mode 100644 (file)
index 0000000..7b6d430
--- /dev/null
@@ -0,0 +1,101 @@
+#include <torch/extension.h>
+
+#include <ATen/ExtensionBackendRegistration.h>
+
+using namespace at;
+
+static int test_int;
+
+Tensor get_dummy_tensor() {
+  auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
+      Storage(
+          caffe2::TypeMeta::Make<float>(), 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false),
+      MSNPUTensorId(),
+      false);
+  return Tensor(std::move(tensor_impl));
+}
+
+Tensor zeros_override(IntList size, const TensorOptions & options) {
+  test_int = 0;
+  return get_dummy_tensor();
+}
+
+Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
+  test_int = 1;
+  return get_dummy_tensor();
+}
+
+Tensor sum_override(const Tensor & self) {
+  test_int = 2;
+  return get_dummy_tensor();
+}
+
+// needed for sum backwards
+Tensor expand_override(const Tensor & self, IntList size, bool implicit) {
+  return get_dummy_tensor();
+}
+
+
+Tensor kl_div_override(
+    const Tensor & self, const Tensor & target, int64_t reduction) {
+  test_int = 3;
+  return get_dummy_tensor();
+}
+
+Tensor kl_div_backward_override(
+    const Tensor & grad_output,
+    const Tensor & self,
+    const Tensor & target,
+    int64_t reduction) {
+  test_int = 4;
+  return get_dummy_tensor();
+}
+
+// numel and ones_like are needed for autograd backwards
+int64_t numel_override(const Tensor & self) {
+  return 1;
+}
+
+Tensor ones_like_override(const Tensor & self, const TensorOptions & options) {
+  return get_dummy_tensor();
+}
+
+void init_msnpu_extension() {
+  register_extension_backend_op(
+    Backend::MSNPU,
+    "zeros(IntList size, TensorOptions options) -> Tensor", &zeros_override);
+  register_extension_backend_op(
+    Backend::MSNPU,
+    "add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override);
+  register_extension_backend_op(
+    Backend::MSNPU,
+    "sum(Tensor self) -> Tensor", &sum_override);
+  register_extension_backend_op(
+    Backend::MSNPU,
+    "expand(Tensor self, IntList size, bool implicit) -> Tensor",
+    &expand_override);
+  register_extension_backend_op(
+    Backend::MSNPU,
+    "kl_div(Tensor self, Tensor target, int64_t reduction) -> Tensor",
+    &kl_div_override);
+  register_extension_backend_op(
+    Backend::MSNPU,
+    "kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction) -> Tensor",
+    &kl_div_backward_override);
+  register_extension_backend_op(
+    Backend::MSNPU,
+    "numel(Tensor self) -> int64_t", &numel_override);
+  register_extension_backend_op(
+    Backend::MSNPU,
+    "ones_like(Tensor self, TensorOptions options) -> Tensor",
+    &ones_like_override);
+}
+
+int get_test_int() {
+  return test_int;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("init_msnpu_extension", &init_msnpu_extension);
+  m.def("get_test_int", &get_test_int);
+}
index 82f990a..b1ea352 100644 (file)
@@ -10,6 +10,9 @@ ext_modules = [
     CppExtension(
         'torch_test_cpp_extension.cpp', ['extension.cpp'],
         extra_compile_args=CXX_FLAGS),
+    CppExtension(
+        'torch_test_cpp_extension.msnpu', ['msnpu_extension.cpp'],
+        extra_compile_args=CXX_FLAGS),
 ]
 
 if torch.cuda.is_available() and CUDA_HOME is not None:
index 73e60a0..0f12d65 100755 (executable)
@@ -13,6 +13,7 @@ from torch.utils.cpp_extension import CUDA_HOME
 
 try:
     import torch_test_cpp_extension.cpp as cpp_extension
+    import torch_test_cpp_extension.msnpu as msnpu_extension
 except ImportError:
     warnings.warn(
         "test_cpp_extensions.py cannot be invoked directly. Run "
@@ -622,5 +623,50 @@ class TestCppExtension(common.TestCase):
             torch.set_default_dtype(initial_default)
 
 
+class TestMSNPUTensor(common.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        msnpu_extension.init_msnpu_extension()
+
+    def test_unregistered(self):
+        a = torch.empty(5, 5, device='cpu')
+        with self.assertRaisesRegex(RuntimeError, "No function registered"):
+            b = torch.empty(5, 5, device='msnpu')
+
+    def test_zeros(self):
+        a = torch.zeros(5, 5, device='cpu')
+        self.assertEqual(a.device, torch.device('cpu'))
+        self.assertEqual(a.sum(), 0)
+
+        b = torch.zeros(5, 5, device='msnpu')
+        self.assertEqual(msnpu_extension.get_test_int(), 0)
+
+    def test_add(self):
+        a = torch.zeros(5, 5, device='msnpu')
+        self.assertEqual(msnpu_extension.get_test_int(), 0)
+
+        b = torch.zeros(5, 5, device='msnpu')
+        self.assertEqual(msnpu_extension.get_test_int(), 0)
+
+        c = torch.add(a, b)
+        self.assertEqual(msnpu_extension.get_test_int(), 1)
+
+    def test_backwards(self):
+        a = torch.zeros(5, 5, device='msnpu', requires_grad=True)
+        self.assertEqual(msnpu_extension.get_test_int(), 0)
+
+        b = torch.zeros(5, 5, device='msnpu')
+        self.assertEqual(msnpu_extension.get_test_int(), 0)
+
+        c = torch.kl_div(a, b)
+        self.assertEqual(msnpu_extension.get_test_int(), 3)
+
+        d = c.sum()
+        self.assertEqual(msnpu_extension.get_test_int(), 2)
+
+        d.backward()
+        self.assertEqual(msnpu_extension.get_test_int(), 4)
+
+
 if __name__ == "__main__":
     common.run_tests()
index aaa786e..cb0ce28 100644 (file)
@@ -361,9 +361,6 @@ def create_python_bindings(python_functions, has_self, is_module=False):
                     '`{}` type is not supported in python_default_init'.format(typename)
                 unpack_with_default = unpack_with_default_methods.get(typename)
                 default_expr = arg.get('python_default_init')
-                # TODO: Type currently maps to ScalarType, figure out a cleaner solution
-                if typename == 'const Type &':
-                    default_expr += '.scalarType()'
                 expr = 'r.{}({}, {})'.format(unpack_with_default, arg_index, default_expr)
             else:
                 unpack = unpack_methods.get(typename, typename.lower())
@@ -584,7 +581,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
 
         if (is_factory_function and not has_type_input_arg) or has_options_arg:
             default_type = get_type_default(declaration)
-            py_default_dtype = 'self.type()' if is_like_function_with_options else None
+            py_default_dtype = 'self.scalar_type()' if is_like_function_with_options else None
             dtype_arg = {
                 'default': default_type,
                 'dynamic_type': 'Type',
index 0bea3b1..7d984f5 100644 (file)
@@ -23,6 +23,7 @@ void initializeLayouts() {
   // for now, let's look these up by Backend; we could create our own enum in the future.
   registerLayoutObject((THPLayout*)strided_layout, at::Backend::CPU);
   registerLayoutObject((THPLayout*)strided_layout, at::Backend::CUDA);
+  registerLayoutObject((THPLayout*)strided_layout, at::Backend::MSNPU);
 
   PyObject *sparse_coo_layout = THPLayout_New(at::Layout::Sparse, "torch.sparse_coo");
   Py_INCREF(sparse_coo_layout);