PyTorch/Caffe2 tensor interop in Python (#17190)
authorDmytro Dzhulgakov <dzhulgakov@fb.com>
Mon, 4 Mar 2019 19:30:43 +0000 (11:30 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 4 Mar 2019 19:34:01 +0000 (11:34 -0800)
Summary:
Because of two separate python extensions with different pybind
instances I have to go through void* conversion. Since it's hidden from
user, it's fine.

New APIs added on C2 side:
- workspace.FetchTorch('blob')
- workspace.Workspace.current.blobs['blob'].to_torch()
- workspace.FeedBlob('blob', pytorch_tensor)

Works on CPU an GPU.

The only glitches are with resizing because of variable/tensor split.
But data sharing works properly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17190

Reviewed By: ezyang

Differential Revision: D14163882

Pulled By: dzhulgakov

fbshipit-source-id: d18e5b8fcae026f393c842a1149e972515732de2

aten/src/ATen/core/Tensor.cpp
c10/test/util/intrusive_ptr_test.cpp
c10/util/intrusive_ptr.h
caffe2/python/pybind_state.cc
caffe2/python/workspace.py
caffe2/python/workspace_test.py
torch/csrc/autograd/python_variable.cpp

index a6489c0..776b4bb 100644 (file)
@@ -14,12 +14,15 @@ void Tensor::enforce_invariants() {
   // supported by ATen
   scalar_type();
   if (defined()) {
-    AT_ASSERTM(
-        impl_->dtype_initialized(),
-        "Partially-initialized tensor not supported by at::Tensor");
-    AT_ASSERTM(
-        impl_->storage_initialized(),
-        "Partially-initialized tensor not supported by at::Tensor");
+    // If it's a variable - we definitely not in C2 land
+    if (!is_variable()) {
+      AT_ASSERTM(
+          impl_->dtype_initialized(),
+          "Partially-initialized tensor not supported by at::Tensor");
+      AT_ASSERTM(
+          impl_->storage_initialized(),
+          "Partially-initialized tensor not supported by at::Tensor");
+    }
     // Ensure LegacyTypeDispatch is initialized. In ATen it's done in tensor
     // factory functions, but when we get a tensor from Caffe2 we might bypass
     // those factory functions.
index 3532cf8..c548cfd 100644 (file)
@@ -1539,6 +1539,7 @@ TEST(IntrusivePtrTest, givenCopyAssignedPtr_whenReassigningCopy_thenIsUnique) {
 TEST(IntrusivePtrTest, givenPtr_whenReleasedAndReclaimed_thenDoesntCrash) {
   intrusive_ptr<SomeClass> obj = make_intrusive<SomeClass>();
   SomeClass* ptr = obj.release();
+  EXPECT_FALSE(obj.defined());
   intrusive_ptr<SomeClass> reclaimed = intrusive_ptr<SomeClass>::reclaim(ptr);
 }
 
@@ -1574,6 +1575,39 @@ TEST(IntrusivePtrTest, givenStackObject_whenReclaimed_thenCrashes) {
   EXPECT_ANY_THROW(ptr = intrusive_ptr<SomeClass>::reclaim(&obj));
 }
 
+TEST(IntrusivePtrTest, givenPtr_whenNonOwningReclaimed_thenDoesntCrash) {
+  intrusive_ptr<SomeClass> obj = make_intrusive<SomeClass>();
+  SomeClass* raw_ptr = obj.get();
+  EXPECT_TRUE(obj.defined());
+  intrusive_ptr<SomeClass> reclaimed =
+      intrusive_ptr<SomeClass>::unsafe_reclaim_from_nonowning(raw_ptr);
+  EXPECT_TRUE(reclaimed.defined());
+  EXPECT_EQ(reclaimed.get(), obj.get());
+}
+
+TEST(
+    IntrusivePtrTest,
+    givenPtr_whenNonOwningReclaimed_thenIsDestructedAtEnd) {
+  bool resourcesReleased = false;
+  bool wasDestructed = false;
+  {
+    intrusive_ptr<DestructableMock> outer;
+    {
+      intrusive_ptr<DestructableMock> inner =
+          make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
+      DestructableMock* raw_ptr = inner.get();
+      outer = intrusive_ptr<DestructableMock>::unsafe_reclaim_from_nonowning(
+          raw_ptr);
+    }
+    // inner is destructed
+    EXPECT_FALSE(resourcesReleased);
+    EXPECT_FALSE(wasDestructed);
+  }
+  // outer is destructed
+  EXPECT_TRUE(resourcesReleased);
+  EXPECT_TRUE(wasDestructed);
+}
+
 namespace {
 template <class T>
 struct IntrusiveAndWeak final {
index 9f5e686..b71cd72 100644 (file)
@@ -73,14 +73,14 @@ class C10_API intrusive_ptr_target {
 // some other compilers don't know about -Wterminate or -Wexceptions and
 // will show a warning about unknown warning options otherwise.
 #ifdef _MSC_VER
-#  pragma warning(push)  
-#  pragma warning(disable: 4297) // function assumed not to throw an exception but does  
-#else  
-#  pragma GCC diagnostic push  
-#  pragma GCC diagnostic ignored "-Wpragmas"  
-#  pragma GCC diagnostic ignored "-Wunknown-warning-option"  
-#  pragma GCC diagnostic ignored "-Wterminate"  
-#  pragma GCC diagnostic ignored "-Wexceptions"  
+#  pragma warning(push)
+#  pragma warning(disable: 4297) // function assumed not to throw an exception but does
+#else
+#  pragma GCC diagnostic push
+#  pragma GCC diagnostic ignored "-Wpragmas"
+#  pragma GCC diagnostic ignored "-Wunknown-warning-option"
+#  pragma GCC diagnostic ignored "-Wterminate"
+#  pragma GCC diagnostic ignored "-Wexceptions"
 #endif
     AT_ASSERTM(
         refcount_.load() == 0,
@@ -89,9 +89,9 @@ class C10_API intrusive_ptr_target {
         weakcount_.load() == 0,
         "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it");
 #ifdef _MSC_VER
-#  pragma warning(pop)  
-#else  
-#  pragma GCC diagnostic pop  
+#  pragma warning(pop)
+#else
+#  pragma GCC diagnostic pop
 #endif
   }
 
@@ -362,6 +362,21 @@ class intrusive_ptr final {
 
     return result;
   }
+
+  /**
+   * Turn a **non-owning raw pointer** to an intrusive_ptr.
+   *
+   * This method is potentially dangerous (as it can mess up refcount).
+   */
+  static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget* raw_ptr) {
+    // See Note [Stack allocated intrusive_ptr_target safety]
+    AT_ASSERTM(
+        raw_ptr == NullType::singleton() || raw_ptr->refcount_.load() > 0,
+        "intrusive_ptr: Can only reclaim pointers that are owned by someone");
+    auto ptr = reclaim(raw_ptr); // doesn't increase refcount
+    ptr.retain_();
+    return ptr;
+  }
 };
 
 template <
index a4756a4..95ea452 100644 (file)
@@ -31,6 +31,7 @@
 #include "caffe2/utils/cpuid.h"
 #include "caffe2/utils/proto_convert.h"
 #include "caffe2/utils/string_utils.h"
+#include "torch/csrc/autograd/variable.h"
 
 namespace caffe2 {
 namespace python {
@@ -333,6 +334,19 @@ void addObjectMethods(py::module& m) {
                 blob.meta().name());
             return fetcher->Fetch(blob);
           })
+      .def("is_tensor", [](Blob* blob) { return blob->IsType<Tensor>(); })
+      // return any device Tensor
+      .def(
+          "as_tensor",
+          [](Blob* blob) {
+            CAFFE_ENFORCE(
+                blob->IsType<Tensor>(),
+                "Passed in blob doesn't contain Tensor and instead has ",
+                blob->meta());
+            return py::cast(&blob->Get<Tensor>());
+          },
+          py::return_value_policy::reference_internal)
+      // legacy API that resets tensor to CPUTensor if it's not already
       .def(
           "tensor",
           [](Blob* blob) { return py::cast(BlobGetMutableTensor(blob, CPU)); },
@@ -371,7 +385,14 @@ void addObjectMethods(py::module& m) {
           },
           "Feed an input array or string, with the (optional) DeviceOption",
           py::arg("arg"),
-          py::arg("device_option") = py::none());
+          py::arg("device_option") = py::none())
+      .def("_wrap_tensor_impl", [](Blob* blob, void* ptr) {
+        auto p = c10::intrusive_ptr<c10::TensorImpl, at::UndefinedTensorImpl>::
+            unsafe_reclaim_from_nonowning(static_cast<c10::TensorImpl*>(ptr));
+        AT_CHECK(p.defined(), "Can't wrap undefined tensor");
+        auto at_tensor = at::Tensor::wrap_tensor_impl(std::move(p));
+        BlobSetTensor(blob, Tensor(std::move(at_tensor)));
+      });
 
   py::class_<DLPackWrapper<CPUContext>>(m, "DLPackTensorCPU")
       .def_property_readonly(
@@ -459,6 +480,15 @@ void addObjectMethods(py::module& m) {
           },
           "Initialize this tensor to given shape and data type. "
           "Fail if the given data type cannot be accessed from python.")
+      .def(
+        "_tensor_impl_raw_handle",
+        [](TensorCPU* t) -> void* {
+          auto p = t->getIntrusivePtr();
+          // We return a raw non-owning pointer here, we rely on surrounding
+          // code to keep the original tensor alive
+          return p.get();
+        }
+      )
       .def_property_readonly(
           "_shape", [](const TensorCPU& t) { return t.sizes().vec(); })
       .def("_reshape", [](TensorCPU* t, std::vector<int64_t> dims) {
index cba1ae6..342bdfc 100644 (file)
@@ -364,6 +364,11 @@ def FetchBlob(name):
     return result
 
 
+def FetchTorch(name):
+    ws = C.Workspace.current
+    return ws.blobs[name].to_torch()
+
+
 Int8Tensor = collections.namedtuple(
     'Int8Tensor', ['data', 'scale', 'zero_point']
 )
@@ -701,13 +706,45 @@ Workspace.run = _Workspace_run
 Workspace.feed_blob = _Workspace_feed_blob
 Workspace.remove_blob = _Workspace_remove_blob
 
-
 # C.Blob methods.
 
+
 def _Blob_feed(blob, arg, device_option=None):
+    # conservative type check to avoid unnecessary import
+    if type(arg).__name__ == 'Tensor' and type(arg).__module__ == 'torch':
+        import torch
+        if isinstance(arg, torch.Tensor):
+            assert device_option is None, \
+                "device_option doesn't make sense with PyTorch tensors"
+            handle = torch._C._tensor_impl_raw_handle(arg)
+            blob._wrap_tensor_impl(handle)
+            return True  # _feed() returns True for some reason
     if device_option is not None:
         device_option = StringifyProto(device_option)
     return blob._feed(arg, device_option)
 
 
 C.Blob.feed = _Blob_feed
+
+
+def _Tensor_to_torch(tensor):
+    """
+    PyTorch tensor interop (TensorCPU methods)
+
+    Can be accessed as:
+      workspace.Workspace.current.blobs['foo'].tensor().to_torch()
+    """
+    # avoiding circular dependency
+    import torch
+    handle = tensor._tensor_impl_raw_handle()
+    return torch._C._wrap_tensor_impl(handle)
+
+C.TensorCPU.to_torch = _Tensor_to_torch
+
+
+def _Blob_to_torch(blob):
+    if not blob.is_tensor():
+        raise RuntimeError("Blob has to be a tensor")
+    return blob.as_tensor().to_torch()
+
+C.Blob.to_torch = _Blob_to_torch
index 12b3b35..80fc70b 100644 (file)
@@ -7,6 +7,7 @@ import numpy as np
 import os
 import unittest
 
+import torch
 from caffe2.proto import caffe2_pb2
 from caffe2.python import core, test_util, workspace, model_helper, brew
 
@@ -289,6 +290,27 @@ class TestWorkspace(unittest.TestCase):
         for key in workspace.blobs:
             self.assertEqual(key, "testblob")
 
+    def testTorchInterop(self):
+        workspace.RunOperatorOnce(core.CreateOperator(
+            "ConstantFill", [], "foo", shape=(4,), value=2, dtype=10))
+        t = workspace.FetchTorch("foo")
+        t.resize_(5)
+        t[4] = t[2] = 777
+        np.testing.assert_array_equal(t.numpy(), np.array([2,2,777,2,777]))
+        # this doesn't work because of variable / tensor confusion
+        # the underlying data tensor is not properly reshaped :(
+        np.testing.assert_array_equal(
+            workspace.FetchBlob("foo"), np.array([2,2,777,2]))
+
+        z = torch.ones((4,), dtype=torch.int64)
+        workspace.FeedBlob('bar', z)
+        workspace.RunOperatorOnce(
+            core.CreateOperator("Reshape", ['bar'], ['bar', '_'], shape=(2,2)))
+        z[0,1] = 123
+        np.testing.assert_array_equal(
+            workspace.FetchBlob("bar"), np.array([[1,123],[1,1]]))
+        np.testing.assert_array_equal(z, np.array([[1,123],[1,1]]))
+
 
 class TestMultiWorkspaces(unittest.TestCase):
     def setUp(self):
@@ -349,6 +371,43 @@ class TestWorkspaceGPU(test_util.TestCase):
         self.assertEqual(pattern.shape[0], pattern.shape[1])
         self.assertEqual(pattern.shape[0], workspace.NumGpuDevices())
 
+    @unittest.skipIf(not workspace.has_cuda_support,
+                     "Tensor interop doesn't yet work on ROCm")
+    def testTorchInterop(self):
+        # CUDA has convenient mem stats, let's use them to make sure we didn't
+        # leak memory
+        initial_mem = torch.cuda.memory_allocated()
+        workspace.RunOperatorOnce(core.CreateOperator(
+            "ConstantFill", [], "foo", shape=(4,), value=2, dtype=10,
+            device_option=core.DeviceOption(workspace.GpuDeviceType)))
+        t = workspace.FetchTorch("foo")
+        t.resize_(5)
+        self.assertTrue(t.is_cuda)
+        t[4] = t[2] = 777
+        np.testing.assert_array_equal(
+            t.cpu().numpy(), np.array([2,2,777,2,777]))
+        # this doesn't work because of variable / tensor confusion
+        # the underlying data tensor is not properly reshaped :(
+        np.testing.assert_array_equal(
+            workspace.FetchBlob("foo"), np.array([2,2,777,2]))
+
+        z = torch.ones((4,), dtype=torch.int64, device="cuda")
+        workspace.FeedBlob('bar', z)
+        workspace.RunOperatorOnce(
+            core.CreateOperator("Reshape", ['bar'], ['bar', '_'], shape=(2,2),
+            device_option=core.DeviceOption(workspace.GpuDeviceType)))
+        z[0,1] = 123
+        np.testing.assert_array_equal(
+            workspace.FetchBlob("bar"), np.array([[1,123],[1,1]]))
+        np.testing.assert_array_equal(z.cpu(), np.array([[1,123],[1,1]]))
+
+        self.assertGreater(torch.cuda.memory_allocated(), initial_mem)
+        # clean up everything
+        del t
+        del z
+        workspace.ResetWorkspace()
+        self.assertEqual(torch.cuda.memory_allocated(), initial_mem)
+
 
 @unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.")
 class TestWorkspaceIDEEP(test_util.TestCase):
index 1d47a0e..fd33033 100644 (file)
 #include <torch/csrc/tensor/python_tensor.h>
 #include <torch/csrc/utils/auto_gil.h>
 #include <torch/csrc/utils/cuda_lazy_init.h>
+#include <torch/csrc/utils/pybind.h>
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/utils/python_arg_parser.h>
 #include <torch/csrc/utils/tensor_new.h>
 #include <torch/csrc/jit/tracer.h>
 
 #include <ATen/ATen.h>
+#include <pybind11/pybind11.h>
 
 #include <structmember.h>
 #include <memory>
@@ -35,6 +37,8 @@ using namespace at;
 using namespace torch;
 using namespace torch::autograd;
 
+namespace py = pybind11;
+
 PyObject *THPVariableClass = nullptr;
 
 static const char* VOLATILE_WARNING =
@@ -489,6 +493,25 @@ namespace torch { namespace autograd {
 extern PyMethodDef variable_methods[];
 extern void initTorchFunctions(PyObject *module);
 
+void initTensorImplConversion(PyObject* module) {
+  auto m = py::handle(module).cast<py::module>();
+  m.def("_wrap_tensor_impl", [](void* ptr) {
+    auto p = c10::intrusive_ptr<c10::TensorImpl, at::UndefinedTensorImpl>::
+        unsafe_reclaim_from_nonowning(static_cast<c10::TensorImpl*>(ptr));
+    AT_CHECK(p.defined(), "Can't wrap undefined tensor");
+    AT_CHECK(!p->is_variable(), "Can wrap only non-variable tensor");
+    auto tensor = at::Tensor::wrap_tensor_impl(std::move(p));
+    return py::cast(torch::autograd::Variable(
+        torch::autograd::make_variable(std::move(tensor), false)));
+  });
+  // set on the module level to avoid mixing pybind and plain CPython extensions
+  m.def("_tensor_impl_raw_handle", [](torch::autograd::Variable* t) -> void* {
+    auto p = t->data().getIntrusivePtr();
+    // We return a raw non-owning pointer here, we rely on surrounding
+    // code to keep the original tensor alive
+    return p.get();
+  });
+}
 }}
 
 bool THPVariable_initModule(PyObject *module)
@@ -502,5 +525,6 @@ bool THPVariable_initModule(PyObject *module)
   Py_INCREF(&THPVariableType);
   PyModule_AddObject(module, "_TensorBase",   (PyObject *)&THPVariableType);
   torch::autograd::initTorchFunctions(module);
+  torch::autograd::initTensorImplConversion(module);
   return true;
 }