From dec116e96f4af069c8c6568e8f01202a3cafc412 Mon Sep 17 00:00:00 2001
From: Dmytro Dzhulgakov
Date: Mon, 4 Mar 2019 11:30:43 -0800
Subject: [PATCH] PyTorch/Caffe2 tensor interop in Python (#17190)

Summary:
Because there are two separate Python extensions with different pybind instances, I have to go through a void* conversion. Since it's hidden from the user, it's fine.

New APIs added on the Caffe2 side:
- workspace.FetchTorch('blob')
- workspace.Workspace.current.blobs['blob'].to_torch()
- workspace.FeedBlob('blob', pytorch_tensor)

Works on CPU and GPU. The only glitches are with resizing, because of the variable/tensor split, but data sharing works properly. (A usage sketch is included below, after the Tensor.cpp diff.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/17190

Reviewed By: ezyang

Differential Revision: D14163882

Pulled By: dzhulgakov

fbshipit-source-id: d18e5b8fcae026f393c842a1149e972515732de2
---
 aten/src/ATen/core/Tensor.cpp           | 15 +++++----
 c10/test/util/intrusive_ptr_test.cpp    | 34 +++++++++++++++++++
 c10/util/intrusive_ptr.h                | 37 +++++++++++++++------
 caffe2/python/pybind_state.cc           | 32 +++++++++++++++++-
 caffe2/python/workspace.py              | 39 +++++++++++++++++++++-
 caffe2/python/workspace_test.py         | 59 +++++++++++++++++++++++++++++++++
 torch/csrc/autograd/python_variable.cpp | 24 ++++++++++++++
 7 files changed, 221 insertions(+), 19 deletions(-)

diff --git a/aten/src/ATen/core/Tensor.cpp b/aten/src/ATen/core/Tensor.cpp
index a6489c0..776b4bb 100644
--- a/aten/src/ATen/core/Tensor.cpp
+++ b/aten/src/ATen/core/Tensor.cpp
@@ -14,12 +14,15 @@ void Tensor::enforce_invariants() {
   // supported by ATen
   scalar_type();
   if (defined()) {
-    AT_ASSERTM(
-        impl_->dtype_initialized(),
-        "Partially-initialized tensor not supported by at::Tensor");
-    AT_ASSERTM(
-        impl_->storage_initialized(),
-        "Partially-initialized tensor not supported by at::Tensor");
+    // If it's a variable, we are definitely not in Caffe2 land
+    if (!is_variable()) {
+      AT_ASSERTM(
+          impl_->dtype_initialized(),
+          "Partially-initialized tensor not supported by at::Tensor");
+      AT_ASSERTM(
+          impl_->storage_initialized(),
+          "Partially-initialized tensor not supported by at::Tensor");
+    }
     // Ensure LegacyTypeDispatch is initialized. In ATen it's done in tensor
     // factory functions, but when we get a tensor from Caffe2 we might bypass
     // those factory functions.
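For illustration, here is a minimal sketch of how the Python-side APIs listed in the
summary are intended to be used. The blob names and shapes are arbitrary placeholders,
and the snippet assumes caffe2.python and torch are importable from the same build so
that the two extensions can exchange the raw TensorImpl handle:

    import numpy as np
    import torch

    from caffe2.python import core, workspace

    # Create a Caffe2 blob holding a float tensor ("my_blob" is an arbitrary name).
    workspace.RunOperatorOnce(core.CreateOperator(
        "ConstantFill", [], "my_blob", shape=(3,), value=1.0))

    # View the Caffe2 tensor as a torch.Tensor; storage is shared, not copied.
    t = workspace.FetchTorch("my_blob")
    # equivalent: workspace.Workspace.current.blobs["my_blob"].to_torch()
    t[0] = 42.0
    np.testing.assert_array_equal(
        workspace.FetchBlob("my_blob"),
        np.array([42.0, 1.0, 1.0], dtype=np.float32))

    # Feed an existing PyTorch tensor into the workspace; this also shares storage.
    z = torch.zeros(2, 2)
    workspace.FeedBlob("torch_blob", z)
    z[0, 0] = 7.0
    assert workspace.FetchBlob("torch_blob")[0, 0] == 7.0

Both directions share storage rather than copy, so writes through the torch view are
visible through FetchBlob and vice versa; the unit tests further down exercise exactly
this behavior on CPU and GPU.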
diff --git a/c10/test/util/intrusive_ptr_test.cpp b/c10/test/util/intrusive_ptr_test.cpp
index 3532cf8..c548cfd 100644
--- a/c10/test/util/intrusive_ptr_test.cpp
+++ b/c10/test/util/intrusive_ptr_test.cpp
@@ -1539,6 +1539,7 @@ TEST(IntrusivePtrTest, givenCopyAssignedPtr_whenReassigningCopy_thenIsUnique) {
 TEST(IntrusivePtrTest, givenPtr_whenReleasedAndReclaimed_thenDoesntCrash) {
   intrusive_ptr<SomeClass> obj = make_intrusive<SomeClass>();
   SomeClass* ptr = obj.release();
+  EXPECT_FALSE(obj.defined());
   intrusive_ptr<SomeClass> reclaimed = intrusive_ptr<SomeClass>::reclaim(ptr);
 }
 
@@ -1574,6 +1575,39 @@ TEST(IntrusivePtrTest, givenStackObject_whenReclaimed_thenCrashes) {
   EXPECT_ANY_THROW(ptr = intrusive_ptr<SomeClass>::reclaim(&obj));
 }
 
+TEST(IntrusivePtrTest, givenPtr_whenNonOwningReclaimed_thenDoesntCrash) {
+  intrusive_ptr<SomeClass> obj = make_intrusive<SomeClass>();
+  SomeClass* raw_ptr = obj.get();
+  EXPECT_TRUE(obj.defined());
+  intrusive_ptr<SomeClass> reclaimed =
+      intrusive_ptr<SomeClass>::unsafe_reclaim_from_nonowning(raw_ptr);
+  EXPECT_TRUE(reclaimed.defined());
+  EXPECT_EQ(reclaimed.get(), obj.get());
+}
+
+TEST(
+    IntrusivePtrTest,
+    givenPtr_whenNonOwningReclaimed_thenIsDestructedAtEnd) {
+  bool resourcesReleased = false;
+  bool wasDestructed = false;
+  {
+    intrusive_ptr<DestructableMock> outer;
+    {
+      intrusive_ptr<DestructableMock> inner =
+          make_intrusive<DestructableMock>(&resourcesReleased, &wasDestructed);
+      DestructableMock* raw_ptr = inner.get();
+      outer = intrusive_ptr<DestructableMock>::unsafe_reclaim_from_nonowning(
+          raw_ptr);
+    }
+    // inner is destructed
+    EXPECT_FALSE(resourcesReleased);
+    EXPECT_FALSE(wasDestructed);
+  }
+  // outer is destructed
+  EXPECT_TRUE(resourcesReleased);
+  EXPECT_TRUE(wasDestructed);
+}
+
 namespace {
 template <class T>
 struct IntrusiveAndWeak final {
diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h
index 9f5e686..b71cd72 100644
--- a/c10/util/intrusive_ptr.h
+++ b/c10/util/intrusive_ptr.h
@@ -73,14 +73,14 @@ class C10_API intrusive_ptr_target {
 // some other compilers don't know about -Wterminate or -Wexceptions and
 // will show a warning about unknown warning options otherwise.
 #ifdef _MSC_VER
-# pragma warning(push)
-# pragma warning(disable: 4297) // function assumed not to throw an exception but does
-#else
-# pragma GCC diagnostic push
-# pragma GCC diagnostic ignored "-Wpragmas"
-# pragma GCC diagnostic ignored "-Wunknown-warning-option"
-# pragma GCC diagnostic ignored "-Wterminate"
-# pragma GCC diagnostic ignored "-Wexceptions"
+# pragma warning(push)
+# pragma warning(disable: 4297) // function assumed not to throw an exception but does
+#else
+# pragma GCC diagnostic push
+# pragma GCC diagnostic ignored "-Wpragmas"
+# pragma GCC diagnostic ignored "-Wunknown-warning-option"
+# pragma GCC diagnostic ignored "-Wterminate"
+# pragma GCC diagnostic ignored "-Wexceptions"
 #endif
   AT_ASSERTM(
       refcount_.load() == 0,
@@ -89,9 +89,9 @@ class C10_API intrusive_ptr_target {
       weakcount_.load() == 0,
       "Tried to destruct an intrusive_ptr_target that still has weak_intrusive_ptr to it");
 #ifdef _MSC_VER
-# pragma warning(pop)
-#else
-# pragma GCC diagnostic pop
+# pragma warning(pop)
+#else
+# pragma GCC diagnostic pop
 #endif
   }
 
@@ -362,6 +362,21 @@ class intrusive_ptr final {
     return result;
   }
 
+  /**
+   * Turn a **non-owning raw pointer** into an intrusive_ptr.
+   *
+   * This method is potentially dangerous (as it can mess up refcount).
+   */
+  static intrusive_ptr unsafe_reclaim_from_nonowning(TTarget* raw_ptr) {
+    // See Note [Stack allocated intrusive_ptr_target safety]
+    AT_ASSERTM(
+        raw_ptr == NullType::singleton() || raw_ptr->refcount_.load() > 0,
+        "intrusive_ptr: Can only reclaim pointers that are owned by someone");
+    auto ptr = reclaim(raw_ptr); // doesn't increase refcount
+    ptr.retain_();
+    return ptr;
+  }
 };
 
 template <
diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc
index a4756a4..95ea452 100644
--- a/caffe2/python/pybind_state.cc
+++ b/caffe2/python/pybind_state.cc
@@ -31,6 +31,7 @@
 #include "caffe2/utils/cpuid.h"
 #include "caffe2/utils/proto_convert.h"
 #include "caffe2/utils/string_utils.h"
+#include "torch/csrc/autograd/variable.h"
 
 namespace caffe2 {
 namespace python {
@@ -333,6 +334,19 @@ void addObjectMethods(py::module& m) {
             blob.meta().name());
         return fetcher->Fetch(blob);
       })
+      .def("is_tensor", [](Blob* blob) { return blob->IsType<Tensor>(); })
+      // return any device Tensor
+      .def(
+          "as_tensor",
+          [](Blob* blob) {
+            CAFFE_ENFORCE(
+                blob->IsType<Tensor>(),
+                "Passed in blob doesn't contain Tensor and instead has ",
+                blob->meta());
+            return py::cast(&blob->Get<Tensor>());
+          },
+          py::return_value_policy::reference_internal)
+      // legacy API that resets tensor to CPUTensor if it's not already
       .def(
          "tensor",
          [](Blob* blob) { return py::cast(BlobGetMutableTensor(blob, CPU)); },
@@ -371,7 +385,14 @@
          },
          "Feed an input array or string, with the (optional) DeviceOption",
          py::arg("arg"),
-         py::arg("device_option") = py::none());
+         py::arg("device_option") = py::none())
+      .def("_wrap_tensor_impl", [](Blob* blob, void* ptr) {
+        auto p = c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>::
+            unsafe_reclaim_from_nonowning(static_cast<c10::TensorImpl*>(ptr));
+        AT_CHECK(p.defined(), "Can't wrap undefined tensor");
+        auto at_tensor = at::Tensor::wrap_tensor_impl(std::move(p));
+        BlobSetTensor(blob, Tensor(std::move(at_tensor)));
+      });
 
   py::class_<DLPackWrapper<CPUContext>>(m, "DLPackTensorCPU")
       .def_property_readonly(
@@ -459,6 +480,15 @@
          },
          "Initialize this tensor to given shape and data type. "
          "Fail if the given data type cannot be accessed from python.")
+      .def(
+          "_tensor_impl_raw_handle",
+          [](TensorCPU* t) -> void* {
+            auto p = t->getIntrusivePtr();
+            // We return a raw non-owning pointer here; we rely on surrounding
+            // code to keep the original tensor alive
+            return p.get();
+          }
+      )
       .def_property_readonly(
           "_shape", [](const TensorCPU& t) { return t.sizes().vec(); })
       .def("_reshape", [](TensorCPU* t, std::vector<int64_t> dims) {
diff --git a/caffe2/python/workspace.py b/caffe2/python/workspace.py
index cba1ae6..342bdfc 100644
--- a/caffe2/python/workspace.py
+++ b/caffe2/python/workspace.py
@@ -364,6 +364,11 @@ def FetchBlob(name):
     return result
 
 
+def FetchTorch(name):
+    ws = C.Workspace.current
+    return ws.blobs[name].to_torch()
+
+
 Int8Tensor = collections.namedtuple(
     'Int8Tensor', ['data', 'scale', 'zero_point']
 )
@@ -701,13 +706,45 @@
 Workspace.run = _Workspace_run
 Workspace.feed_blob = _Workspace_feed_blob
 Workspace.remove_blob = _Workspace_remove_blob
-
 # C.Blob methods.
+
 def _Blob_feed(blob, arg, device_option=None):
+    # conservative type check to avoid unnecessary import
+    if type(arg).__name__ == 'Tensor' and type(arg).__module__ == 'torch':
+        import torch
+        if isinstance(arg, torch.Tensor):
+            assert device_option is None, \
+                "device_option doesn't make sense with PyTorch tensors"
+            handle = torch._C._tensor_impl_raw_handle(arg)
+            blob._wrap_tensor_impl(handle)
+            return True  # _feed() returns True for some reason
     if device_option is not None:
         device_option = StringifyProto(device_option)
     return blob._feed(arg, device_option)
 
 
 C.Blob.feed = _Blob_feed
+
+
+def _Tensor_to_torch(tensor):
+    """
+    PyTorch tensor interop (TensorCPU methods)
+
+    Can be accessed as:
+      workspace.Workspace.current.blobs['foo'].tensor().to_torch()
+    """
+    # avoiding circular dependency
+    import torch
+    handle = tensor._tensor_impl_raw_handle()
+    return torch._C._wrap_tensor_impl(handle)
+
+C.TensorCPU.to_torch = _Tensor_to_torch
+
+
+def _Blob_to_torch(blob):
+    if not blob.is_tensor():
+        raise RuntimeError("Blob has to be a tensor")
+    return blob.as_tensor().to_torch()
+
+C.Blob.to_torch = _Blob_to_torch
diff --git a/caffe2/python/workspace_test.py b/caffe2/python/workspace_test.py
index 12b3b35..80fc70b 100644
--- a/caffe2/python/workspace_test.py
+++ b/caffe2/python/workspace_test.py
@@ -7,6 +7,7 @@
 import numpy as np
 import os
 import unittest
+import torch
 
 from caffe2.proto import caffe2_pb2
 from caffe2.python import core, test_util, workspace, model_helper, brew
@@ -289,6 +290,27 @@ class TestWorkspace(unittest.TestCase):
         for key in workspace.blobs:
             self.assertEqual(key, "testblob")
 
+    def testTorchInterop(self):
+        workspace.RunOperatorOnce(core.CreateOperator(
+            "ConstantFill", [], "foo", shape=(4,), value=2, dtype=10))
+        t = workspace.FetchTorch("foo")
+        t.resize_(5)
+        t[4] = t[2] = 777
+        np.testing.assert_array_equal(t.numpy(), np.array([2, 2, 777, 2, 777]))
+        # this doesn't work because of variable / tensor confusion
+        # the underlying data tensor is not properly reshaped :(
+        np.testing.assert_array_equal(
+            workspace.FetchBlob("foo"), np.array([2, 2, 777, 2]))
+
+        z = torch.ones((4,), dtype=torch.int64)
+        workspace.FeedBlob('bar', z)
+        workspace.RunOperatorOnce(
+            core.CreateOperator("Reshape", ['bar'], ['bar', '_'], shape=(2, 2)))
+        z[0, 1] = 123
+        np.testing.assert_array_equal(
+            workspace.FetchBlob("bar"), np.array([[1, 123], [1, 1]]))
+        np.testing.assert_array_equal(z, np.array([[1, 123], [1, 1]]))
+
 
 class TestMultiWorkspaces(unittest.TestCase):
     def setUp(self):
@@ -349,6 +371,43 @@ class TestWorkspaceGPU(test_util.TestCase):
         self.assertEqual(pattern.shape[0], pattern.shape[1])
         self.assertEqual(pattern.shape[0], workspace.NumGpuDevices())
 
+    @unittest.skipIf(not workspace.has_cuda_support,
+                     "Tensor interop doesn't yet work on ROCm")
+    def testTorchInterop(self):
+        # CUDA has convenient mem stats, let's use them to make sure we didn't
+        # leak memory
+        initial_mem = torch.cuda.memory_allocated()
+        workspace.RunOperatorOnce(core.CreateOperator(
+            "ConstantFill", [], "foo", shape=(4,), value=2, dtype=10,
+            device_option=core.DeviceOption(workspace.GpuDeviceType)))
+        t = workspace.FetchTorch("foo")
+        t.resize_(5)
+        self.assertTrue(t.is_cuda)
+        t[4] = t[2] = 777
+        np.testing.assert_array_equal(
+            t.cpu().numpy(), np.array([2, 2, 777, 2, 777]))
+        # this doesn't work because of variable / tensor confusion
+        # the underlying data tensor is not properly reshaped :(
+        np.testing.assert_array_equal(
+            workspace.FetchBlob("foo"), np.array([2, 2, 777, 2]))
+
+        z = torch.ones((4,), dtype=torch.int64, device="cuda")
device="cuda") + workspace.FeedBlob('bar', z) + workspace.RunOperatorOnce( + core.CreateOperator("Reshape", ['bar'], ['bar', '_'], shape=(2,2), + device_option=core.DeviceOption(workspace.GpuDeviceType))) + z[0,1] = 123 + np.testing.assert_array_equal( + workspace.FetchBlob("bar"), np.array([[1,123],[1,1]])) + np.testing.assert_array_equal(z.cpu(), np.array([[1,123],[1,1]])) + + self.assertGreater(torch.cuda.memory_allocated(), initial_mem) + # clean up everything + del t + del z + workspace.ResetWorkspace() + self.assertEqual(torch.cuda.memory_allocated(), initial_mem) + @unittest.skipIf(not workspace.C.use_mkldnn, "No MKLDNN support.") class TestWorkspaceIDEEP(test_util.TestCase): diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 1d47a0e..fd33033 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -19,12 +19,14 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -35,6 +37,8 @@ using namespace at; using namespace torch; using namespace torch::autograd; +namespace py = pybind11; + PyObject *THPVariableClass = nullptr; static const char* VOLATILE_WARNING = @@ -489,6 +493,25 @@ namespace torch { namespace autograd { extern PyMethodDef variable_methods[]; extern void initTorchFunctions(PyObject *module); +void initTensorImplConversion(PyObject* module) { + auto m = py::handle(module).cast(); + m.def("_wrap_tensor_impl", [](void* ptr) { + auto p = c10::intrusive_ptr:: + unsafe_reclaim_from_nonowning(static_cast(ptr)); + AT_CHECK(p.defined(), "Can't wrap undefined tensor"); + AT_CHECK(!p->is_variable(), "Can wrap only non-variable tensor"); + auto tensor = at::Tensor::wrap_tensor_impl(std::move(p)); + return py::cast(torch::autograd::Variable( + torch::autograd::make_variable(std::move(tensor), false))); + }); + // set on the module level to avoid mixing pybind and plain CPython extensions + m.def("_tensor_impl_raw_handle", [](torch::autograd::Variable* t) -> void* { + auto p = t->data().getIntrusivePtr(); + // We return a raw non-owning pointer here, we rely on surrounding + // code to keep the original tensor alive + return p.get(); + }); +} }} bool THPVariable_initModule(PyObject *module) @@ -502,5 +525,6 @@ bool THPVariable_initModule(PyObject *module) Py_INCREF(&THPVariableType); PyModule_AddObject(module, "_TensorBase", (PyObject *)&THPVariableType); torch::autograd::initTorchFunctions(module); + torch::autograd::initTensorImplConversion(module); return true; } -- 2.7.4