Fix to handle null strides in DLPack tensor (#18510)
authorBram Wasti <bwasti@fb.com>
Thu, 4 Apr 2019 07:24:16 +0000 (00:24 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 07:28:13 +0000 (00:28 -0700)
Summary:
DLPack can have non-strided tensors, which is represented by a nullptr in the place of dl_tensor.strides.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18510

Differential Revision: D14647328

Pulled By: bwasti

fbshipit-source-id: 5364282810a5772cfc2319fc8133fe86fdd84dd1

aten/src/ATen/DLConvertor.cpp
aten/src/ATen/test/dlconvertor_test.cpp

index e2983eb..a40e872 100644 (file)
@@ -178,6 +178,13 @@ Tensor fromDLPack(const DLManagedTensor* src) {
   auto deleter = [src](void* self) {
     src->deleter(const_cast<DLManagedTensor*>(src));
   };
+  if (!src->dl_tensor.strides) {
+    return at::from_blob(src->dl_tensor.data,
+        IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
+        deleter,
+        at::device(device_type).dtype(stype));
+  }
+
   return at::from_blob(
       src->dl_tensor.data,
       IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
index e2c7a62..4e5eb93 100644 (file)
@@ -18,3 +18,15 @@ TEST(TestDlconvertor, TestDlconvertor) {
 
   ASSERT_TRUE(a.equal(b));
 }
+
+TEST(TestDlconvertor, TestDlconvertorNoStrides) {
+  manual_seed(123);
+
+  Tensor a = rand({3, 4});
+  DLManagedTensor* dlMTensor = toDLPack(a);
+  dlMTensor->dl_tensor.strides = nullptr;
+
+  Tensor b = fromDLPack(dlMTensor);
+
+  ASSERT_TRUE(a.equal(b));
+}