test basic tensor interop
author	Roy Li <royboy@fb.com>
Fri, 28 Dec 2018 01:01:19 +0000 (17:01 -0800)
committer	Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 28 Dec 2018 01:04:00 +0000 (17:04 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/12249

Differential Revision: D13469356

Pulled By: li-roy

fbshipit-source-id: b49748462aa44ac34b8ce79783f2c895a537a232

aten/src/ATen/test/CMakeLists.txt
aten/src/ATen/test/tensor_interop_test.cpp [new file with mode: 0644]
aten/tools/run_tests.sh
caffe2/core/tensor.h

index 37f8cf3..d3f0d01 100644 (file)
@@ -15,6 +15,7 @@ list(APPEND ATen_CPU_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/dlconvertor_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/native_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/scalar_tensor_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/tensor_interop_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/test_parallel.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
diff --git a/aten/src/ATen/test/tensor_interop_test.cpp b/aten/src/ATen/test/tensor_interop_test.cpp
new file mode 100644 (file)
index 0000000..ec3886b
--- /dev/null
@@ -0,0 +1,141 @@
+#include "gtest/gtest.h"
+
+#include "ATen/ATen.h"
+#include <caffe2/core/init.h>
+#include <caffe2/core/operator.h>
+
+TEST(TestTensorInterop, Caffe2ToPytorchSimpleLegacy) {
+  caffe2::Tensor c2_tensor(caffe2::CPU);
+  c2_tensor.Resize(4, 4);
+  auto data = c2_tensor.mutable_data<int64_t>();
+  for (int64_t i = 0; i < 16; i++) {
+    data[i] = i;
+  }
+
+  // TODO: find out why calling data on tensor doesn't work
+  at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
+  at::TensorImpl* impl = at_tensor.unsafeGetTensorImpl();
+
+  auto it = impl->data<int64_t>();
+  for (int64_t i = 0; i < 16; i++) {
+    ASSERT_EQ(it[i], i);
+  }
+}
+
+TEST(TestTensorInterop, Caffe2ToPytorchSimple) {
+  caffe2::Tensor c2_tensor = caffe2::empty({4, 4}, at::kLong);
+  auto data = c2_tensor.mutable_data<int64_t>();
+  for (int64_t i = 0; i < 16; i++) {
+    data[i] = i;
+  }
+  at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
+  at::TensorImpl* impl = at_tensor.unsafeGetTensorImpl();
+
+  auto it = impl->data<int64_t>();
+  for (int64_t i = 0; i < 16; i++) {
+    ASSERT_EQ(it[i], i);
+  }
+}
+
+TEST(TestTensorInterop, Caffe2ToPytorchOp) {
+  caffe2::Tensor c2_tensor(caffe2::CPU);
+  c2_tensor.Resize(3, 3);
+  auto data = c2_tensor.mutable_data<int64_t>();
+  for (int64_t i = 0; i < 9; i++) {
+    data[i] = i;
+  }
+  at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
+
+  ASSERT_EQ(at::sum(at_tensor).item<int64_t>(), 36);
+}
+
+TEST(TestTensorInterop, Caffe2ToPytorchUnsupportedDevice) {
+  caffe2::Tensor c2_tensor(caffe2::IDEEP);
+  at::Tensor at_tensor(c2_tensor.getIntrusivePtr());
+  ASSERT_ANY_THROW(at::sum(at_tensor));
+}
+
+TEST(TestTensorInterop, PytorchToCaffe2Op) {
+  caffe2::Workspace workspace;
+  caffe2::NetDef net;
+
+  auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat));
+  auto at_tensor_b = at::ones({5, 5}, at::dtype(at::kFloat));
+  auto at_tensor_c = at::ones({5, 5}, at::dtype(at::kFloat));
+
+  auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
+  auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());
+
+  // Test ShareData as well
+  {
+    auto c2_tensor_c = XBlobGetMutableTensor(workspace.CreateBlob("c"), {0}, at::kCPU);
+    c2_tensor_c.ResizeLike(at_tensor_c.getIntrusivePtr());
+    c2_tensor_c.ShareData(at_tensor_c.getIntrusivePtr());
+  }
+
+  {
+    auto op = net.add_op();
+    op->set_type("Sum");
+    op->add_input("a");
+    op->add_input("b");
+    op->add_input("c");
+    op->add_output("d");
+  }
+
+  workspace.RunNetOnce(net);
+
+  auto result = XBlobGetMutableTensor(workspace.CreateBlob("d"), {5, 5}, at::kCPU);
+
+  auto it = result.data<float>();
+  for (int64_t i = 0; i < 25; i++) {
+    ASSERT_EQ(it[i], 3.0);
+  }
+  at::Tensor at_result(result.getIntrusivePtr());
+  ASSERT_EQ(at::sum(at_result).item<float>(), 75);
+}
+
+TEST(TestTensorInterop, PytorchToCaffe2SharedStorage) {
+  caffe2::Workspace workspace;
+  caffe2::NetDef net;
+
+  auto at_tensor_a = at::ones({5, 5}, at::dtype(at::kFloat));
+  auto at_tensor_b = at_tensor_a.view({5, 5});
+
+  auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
+  auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());
+
+  {
+    auto op = net.add_op();
+    op->set_type("Add");
+    op->add_input("a");
+    op->add_input("b");
+    op->add_output("c");
+  }
+
+  workspace.RunNetOnce(net);
+
+  auto result = XBlobGetMutableTensor(workspace.CreateBlob("c"), {5, 5}, at::kCPU);
+  auto it = result.data<float>();
+  for (int64_t i = 0; i < 25; i++) {
+    ASSERT_EQ(it[i], 2.0);
+  }
+  at::Tensor at_result(result.getIntrusivePtr());
+  ASSERT_EQ(at::sum(at_result).item<float>(), 50);
+}
+
+TEST(TestTensorInterop, PytorchToCaffe2Strided) {
+  caffe2::Workspace workspace;
+  caffe2::NetDef net;
+
+  auto at_tensor = at::ones({5, 5}, at::dtype(at::kFloat)).t();
+  auto* c2_tensor = BlobSetTensor(workspace.CreateBlob("blob"), at_tensor.getIntrusivePtr());
+
+  {
+    auto op = net.add_op();
+    op->set_type("Sum");
+    op->add_input("blob");
+    op->add_output("out");
+  }
+
+  ASSERT_ANY_THROW(workspace.RunNetOnce(net));
+}
index 720afe4..c2a0d2f 100755 (executable)
@@ -15,6 +15,7 @@ VALGRIND=${VALGRIND:=ON}
 ./dlconvertor_test
 ./native_test
 ./scalar_tensor_test
+./tensor_interop_test
 ./undefined_tensor_test
 if [[ -x ./cudnn_test ]]; then
   ./cudnn_test
@@ -37,6 +38,7 @@ fi
 if [ "$VALGRIND" == "ON" ]
 then
   valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic "[cpu]"
+  valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./tensor_interop_test
 fi
 
 popd
index a91e854..feacc64 100644 (file)
@@ -28,6 +28,12 @@ class CAFFE2_API Tensor final {
 
  public:
   Tensor() : impl_() {}
+  Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
+      : impl_(std::move(tensor_impl)) {
+    if (impl_.get() == nullptr) {
+      throw std::runtime_error("TensorBaseImpl with nullptr not supported");
+    }
+  }
 
   operator bool() const {
     return impl_.defined();