Summary:
DLPack can have non-strided tensors, which is represented by a nullptr in the place of dl_tensor.strides.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18510
Differential Revision:
D14647328
Pulled By: bwasti
fbshipit-source-id:
5364282810a5772cfc2319fc8133fe86fdd84dd1
auto deleter = [src](void* self) {
src->deleter(const_cast<DLManagedTensor*>(src));
};
+ if (!src->dl_tensor.strides) {
+ return at::from_blob(src->dl_tensor.data,
+ IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
+ deleter,
+ at::device(device_type).dtype(stype));
+ }
+
return at::from_blob(
src->dl_tensor.data,
IntArrayRef(src->dl_tensor.shape, src->dl_tensor.ndim),
ASSERT_TRUE(a.equal(b));
}
+
+TEST(TestDlconvertor, TestDlconvertorNoStrides) {
+ manual_seed(123);
+
+ Tensor a = rand({3, 4});
+ DLManagedTensor* dlMTensor = toDLPack(a);
+ dlMTensor->dl_tensor.strides = nullptr;
+
+ Tensor b = fromDLPack(dlMTensor);
+
+ ASSERT_TRUE(a.equal(b));
+}