Adds DLPack support (#57110)
author Emilio Castillo <ecastill@preferred.jp>
Mon, 13 Sep 2021 02:45:57 +0000 (19:45 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 13 Sep 2021 02:47:15 +0000 (19:47 -0700)
Summary:
Partially Fixes https://github.com/pytorch/pytorch/issues/55090
Depends on https://github.com/pytorch/pytorch/issues/55365

Inspired by https://github.com/dmlc/dlpack/issues/57#issuecomment-774482973
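For reference, a short usage sketch of the exchange surface this PR adds (`torch.utils.dlpack.from_dlpack`, `to_dlpack`, and `Tensor.__dlpack__`, all from the diff below); it mirrors the new tests rather than prescribing the final API:

```python
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

x = torch.arange(5, dtype=torch.float64)

# Legacy path: an explicit capsule, which may be consumed exactly once.
y = from_dlpack(to_dlpack(x))

# New path: any object implementing __dlpack__/__dlpack_device__.
z = from_dlpack(x)

# No copy is made; all three tensors alias the same storage.
assert x.data_ptr() == y.data_ptr() == z.data_ptr()
```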

Question: in PyTorch we can't create streams or easily synchronize with them from just an integer. Should we add an [`ExternalStream`](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.ExternalStream.html) object like the one we have in CuPy?
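For context, a minimal sketch of that pattern, assuming a valid integer `cudaStream_t` pointer from a producer library (fabricated here from a PyTorch stream purely for illustration; requires a CUDA device):

```python
import torch

# Stand-in for an integer stream pointer received from another library.
producer = torch.cuda.Stream()
raw_ptr = producer.cuda_stream  # integer value of the underlying cudaStream_t

# Wrap the bare pointer so PyTorch can order work against it, as this
# PR does inside Tensor.__dlpack__.
ext = torch.cuda.streams.ExternalStream(raw_ptr)
event = torch.cuda.Event()
event.record(torch.cuda.current_stream())
ext.wait_event(event)  # the external stream now waits on current work
```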

TODO: Add tests

Would like some feedback, as this design needs quite a few iterations.

cc rgommers, leofang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57110

Reviewed By: saketh-are

Differential Revision: D30761481

Pulled By: mruberry

fbshipit-source-id: e85d78df3c1f8defc2a698878da89cd843cb1209

aten/src/ATen/DLConvertor.cpp
test/test_torch.py
torch/__init__.py
torch/_tensor.py
torch/overrides.py
torch/utils/dlpack.py

diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp
index 73d0ec0..ed0fa7e 100644
@@ -39,7 +39,7 @@ DLDataType getDLDataType(const Tensor& t) {
       dtype.code = DLDataTypeCode::kDLFloat;
       break;
     case ScalarType::Bool:
-      dtype.code = DLDataTypeCode::kDLUInt;
+      TORCH_CHECK(false, "Bool type is not supported by dlpack");
       break;
     case ScalarType::ComplexHalf:
       dtype.code = DLDataTypeCode::kDLComplex;
diff --git a/test/test_torch.py b/test/test_torch.py
index ef76fc4..79a28ee 100644
@@ -7104,16 +7104,103 @@ else:
                 _test_helper(x, op, unary=True)
 
     @skipMeta
-    @dtypes(*get_all_dtypes())
-    def test_dlpack_conversion(self, device, dtype):
-        # DLpack does not explicitly support bool
-        # It does it through uint8 type
-        if dtype is torch.bool:
-            return
-        x = make_tensor((5,), device, dtype, low=-9, high=9)
+    @onlyOnCPUAndCUDA
+    @dtypes(*get_all_dtypes(include_bool=False))
+    def test_dlpack_capsule_conversion(self, device, dtype):
+        # DLPack does not explicitly support bool (xref dmlc/dlpack#75)
+        x = make_tensor((5,), device, dtype)
         z = from_dlpack(to_dlpack(x))
         self.assertEqual(z, x)
 
+    @skipMeta
+    @onlyOnCPUAndCUDA
+    @dtypes(*get_all_dtypes(include_bool=False))
+    def test_dlpack_protocol_conversion(self, device, dtype):
+        x = make_tensor((5,), device, dtype)
+        z = from_dlpack(x)
+        self.assertEqual(z, x)
+
+    @skipMeta
+    @onlyOnCPUAndCUDA
+    def test_dlpack_shared_storage(self, device):
+        x = make_tensor((5,), device, torch.float64)
+        z = from_dlpack(to_dlpack(x))
+        z[0] = z[0] + 20.0
+        self.assertEqual(z, x)
+
+    @skipMeta
+    @onlyCUDA
+    @dtypes(*get_all_dtypes(include_bool=False))
+    def test_dlpack_conversion_with_streams(self, device, dtype):
+        # Create a stream on which the tensor will be populated
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            # Do an operation in the current stream
+            x = make_tensor((5,), device, dtype) + 1
+        # The DLPack protocol establishes a correct stream order (and
+        # hence a data dependency) at the exchange boundary. It manages
+        # this synchronization for us, so we don't need to wait
+        # explicitly until x is populated.
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            z = from_dlpack(x)
+        stream.synchronize()
+        self.assertEqual(z, x)
+
+    @skipMeta
+    @onlyCUDA
+    @dtypes(*get_all_dtypes(include_bool=False))
+    def test_dlpack_conversion_with_diff_streams(self, device, dtype):
+        from torch._C import _from_dlpack
+        stream_a = torch.cuda.Stream()
+        stream_b = torch.cuda.Stream()
+        # The DLPack protocol establishes a correct stream order (and
+        # hence a data dependency) at the exchange boundary. The
+        # `tensor.__dlpack__` method records a synchronization event on
+        # the current stream to make sure its data is fully populated
+        # before the consumer reads it.
+        with torch.cuda.stream(stream_a):
+            x = make_tensor((5,), device, dtype) + 1
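+            # Passing stream_b's raw pointer asks the producer to record
+            # an event on stream_a (the current stream) and make stream_b
+            # wait on it before the capsule is consumed.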
+            z = _from_dlpack(x.__dlpack__(stream_b.cuda_stream))
+            stream_a.synchronize()
+        stream_b.synchronize()
+        self.assertEqual(z, x)
+
+    @skipMeta
+    @onlyOnCPUAndCUDA
+    @dtypes(*get_all_dtypes(include_bool=False))
+    def test_dlpack_tensor_invalid_stream(self, device, dtype):
+        with self.assertRaises(TypeError):
+            x = make_tensor((5,), device, dtype)
+            x.__dlpack__(stream=object())
+
+    @skipMeta
+    def test_dlpack_error_on_bool_tensor(self):
+        x = torch.tensor([True], dtype=torch.bool)
+        with self.assertRaises(RuntimeError):
+            to_dlpack(x)
+
+    # TODO: add more tests once NumPy supports the `__dlpack__` protocol
+
+    @skipMeta
+    def test_dlpack_export_requires_grad(self):
+        x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
+        with self.assertRaisesRegex(RuntimeError, r"require gradient"):
+            x.__dlpack__()
+
+    @skipMeta
+    def test_dlpack_export_is_conj(self):
+        x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
+        y = torch.conj(x)
+        with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
+            y.__dlpack__()
+
+    @skipMeta
+    def test_dlpack_export_non_strided(self):
+        x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
+        with self.assertRaisesRegex(RuntimeError, r"strided"):
+            x.__dlpack__()
+
     @onlyCUDA
     @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
     def test_pin_memory_from_constructor(self, device):
diff --git a/torch/__init__.py b/torch/__init__.py
index c0bd2db..a910882 100644
@@ -759,6 +759,8 @@ from ._vmap_internals import vmap as vmap
 quantized_lstm = torch.ops.aten.quantized_lstm
 quantized_gru = torch.ops.aten.quantized_gru
 
+from torch.utils.dlpack import from_dlpack, to_dlpack
+
 
 def _register_device_module(device_type, module):
     r"""Register an external runtime module of the specific :attr:`device_type`
diff --git a/torch/_tensor.py b/torch/_tensor.py
index e7bc4ed..2e4b7a3 100644
@@ -1,4 +1,5 @@
 from collections import OrderedDict
+import enum
 import functools
 from numbers import Number
 from typing import Any, Dict, Optional, Tuple, Union
@@ -1053,6 +1054,67 @@ class Tensor(torch._C._TensorBase):
             else:
                 return _convert(ret, cls)
 
+    def __dlpack__(self, stream=None):
+        """
+        Creates a DLPack `capsule <https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange>`_
+        of the current tensor to be exported to other libraries.
+
+        This function will be called from the `from_dlpack` method
+        of the library that will consume the capsule. `from_dlpack` passes the current
+        stream to this method as part of the specification.
+
+        Args:
+            stream (integer or None): An optional Python integer representing a
+                pointer to a CUDA stream. The current stream is synchronized with
+                this stream before the capsule is created, and since the capsule
+                shares its storage with the tensor, this makes it safe to access
+                from both streams. If ``None`` or ``-1`` is passed, no
+                synchronization is performed.
+        """
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__dlpack__, (self,), self, stream)
+
+        # DLPack capsules can't capture all of PyTorch's semantics,
+        # so we prohibit exporting tensors whose properties, such as
+        # requires_grad or the conjugate bit, would be silently lost.
+        if self.requires_grad:
+            raise RuntimeError('Can\'t export tensors that require gradient, use tensor.detach()')
+        if self.is_conj():
+            raise RuntimeError('Can\'t export tensors with the conjugate bit set')
+        if self.layout != torch.strided:
+            raise RuntimeError('Can\'t export tensors with layout other than torch.strided')
+
+        if stream is not None and type(stream) is not int:
+            # Stream pointers in CUDA/ROCm are uniquely numbered and can
+            # be retrieved from their integer value.
+            raise TypeError('stream must be ``int`` or ``None``')
+        elif stream is not None and stream != -1:
+            if self.device.type == 'cuda':
+                stream = torch.cuda.streams.ExternalStream(stream)
+                # Only synchronize on different streams
+                if stream != torch.cuda.current_stream():
+                    event = torch.cuda.Event()
+                    event.record(torch.cuda.current_stream())
+                    stream.wait_event(event)
+        return torch.to_dlpack(self)
+
+    def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
+        # Avoid circular import
+        from torch.utils.dlpack import DLDeviceType
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__dlpack_device__, (self,), self)
+        idx = self.device.index if self.device.index is not None else 0
+        if self.device.type == 'cuda' and torch.version.hip is not None:
+            device_type = DLDeviceType.kDLROCM
+        elif self.device.type == 'cpu' and self.is_pinned():
+            device_type = DLDeviceType.kDLCPUPinned
+        elif self.device.type == 'cuda':
+            device_type = DLDeviceType.kDLGPU
+        elif self.device.type == 'cpu':
+            device_type = DLDeviceType.kDLCPU
+        else:
+            raise ValueError('Unknown device type {} for DLPack'.format(self.device.type))
+        return (device_type, idx)
+
     __module__ = 'torch'
 
 def _convert(ret, cls):
diff --git a/torch/overrides.py b/torch/overrides.py
index c574109..6c545a0 100644
@@ -1161,6 +1161,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.view: lambda self, shape: -1,
         Tensor.view_as: lambda self, other: -1,
         Tensor.zero_: lambda self: -1,
+        Tensor.__dlpack__: lambda self, stream=None: -1,
+        Tensor.__dlpack_device__: lambda self: -1,
         torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
     }
 
diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py
index af516ec..06424be 100644
@@ -1,19 +1,24 @@
+from typing import Any
+
 import torch
+import enum
 
-from torch._C import _from_dlpack as from_dlpack
+from torch._C import _from_dlpack
 from torch._C import _to_dlpack as to_dlpack
 
-torch._C._add_docstr(from_dlpack, r"""from_dlpack(dlpack) -> Tensor
 
-Decodes a DLPack to a tensor.
+class DLDeviceType(enum.IntEnum):
+    # Enums as in DLPack specification (aten/src/ATen/dlpack.h)
+    kDLCPU = 1
+    kDLGPU = 2
+    kDLCPUPinned = 3
+    kDLOpenCL = 4
+    kDLVulkan = 7
+    kDLMetal = 8
+    kDLVPI = 9
+    kDLROCM = 10
+    kDLExtDev = 12
 
-Args:
-    dlpack: a PyCapsule object with the dltensor
-
-The tensor will share the memory with the object represented
-in the dlpack.
-Note that each dlpack can only be consumed once.
-""")
 
 torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule
 
@@ -22,6 +27,41 @@ Returns a DLPack representing the tensor.
 Args:
     tensor: a tensor to be exported
 
-The dlpack shares the tensors memory.
-Note that each dlpack can only be consumed once.
+The DLPack capsule shares the tensor's memory.
+Note that each DLPack capsule can only be consumed once.
 """)
+
+# TODO: add a typing.Protocol to be able to tell Mypy that only objects with
+# __dlpack__ and __dlpack_device__ methods are accepted.
+def from_dlpack(ext_tensor: Any) -> torch.Tensor:
+    """from_dlpack(ext_tensor) -> Tensor
+
+    Converts a tensor from an external library into a ``torch.Tensor``
+    by means of the ``__dlpack__`` protocol.
+
+    The tensor will share the memory with the object represented
+    in the DLPack.
+
+    .. warning::
+      Only call ``from_dlpack`` once per capsule. Its behavior when used
+      on the same capsule multiple times is undefined.
+
+    Args:
+        ext_tensor (object with __dlpack__ attribute or DLPack capsule):
+            The tensor or DLPack capsule to convert.
+    """
+    if hasattr(ext_tensor, '__dlpack__'):
+        device = ext_tensor.__dlpack_device__()
+        # If the device is CUDA or ROCm, we need to pass the current
+        # stream so the producer can synchronize against it.
+        if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM):
+            stream = torch.cuda.current_stream('cuda:{}'.format(device[1]))
+            # cuda_stream is the pointer to the stream and it is a public
+            # attribute, but it is not documented
+            dlpack = ext_tensor.__dlpack__(stream=stream.cuda_stream)
+        else:
+            dlpack = ext_tensor.__dlpack__()
+    else:
+        # No __dlpack__ attribute: assume ext_tensor is a legacy DLPack capsule
+        dlpack = ext_tensor
+    return _from_dlpack(dlpack)