From: Emilio Castillo
Date: Mon, 13 Sep 2021 02:45:57 +0000 (-0700)
Subject: Adds DLPack support (#57110)
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~274
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1cb3507ed38330af43bf23e255dc64e9215384c4;p=platform%2Fupstream%2Fpytorch.git

Adds DLPack support (#57110)

Summary:
Partially Fixes https://github.com/pytorch/pytorch/issues/55090
Depends on https://github.com/pytorch/pytorch/issues/55365

Inspired by https://github.com/dmlc/dlpack/issues/57#issuecomment-774482973

Question: in PyTorch we can't create streams or easily synchronize them from just an integer. Should we add an [`ExternalStream`](https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.ExternalStream.html) object like the one we have in CuPy?

TODO: Add tests

Would like some feedback, as this design needs quite a few iterations.

rgommers leofang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57110

Reviewed By: saketh-are

Differential Revision: D30761481

Pulled By: mruberry

fbshipit-source-id: e85d78df3c1f8defc2a698878da89cd843cb1209
---

diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp
index 73d0ec0..ed0fa7e 100644
--- a/aten/src/ATen/DLConvertor.cpp
+++ b/aten/src/ATen/DLConvertor.cpp
@@ -39,7 +39,7 @@ DLDataType getDLDataType(const Tensor& t) {
       dtype.code = DLDataTypeCode::kDLFloat;
       break;
     case ScalarType::Bool:
-      dtype.code = DLDataTypeCode::kDLUInt;
+      TORCH_CHECK(false, "Bool type is not supported by dlpack");
       break;
     case ScalarType::ComplexHalf:
       dtype.code = DLDataTypeCode::kDLComplex;
diff --git a/test/test_torch.py b/test/test_torch.py
index ef76fc4..79a28ee 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -7104,16 +7104,103 @@ else:
             _test_helper(x, op, unary=True)

     @skipMeta
-    @dtypes(*get_all_dtypes())
-    def test_dlpack_conversion(self, device, dtype):
-        # DLpack does not explicitly support bool
-        # It does it through uint8 type
-        if dtype is torch.bool:
-            return
-        x = make_tensor((5,), device, dtype, low=-9, high=9)
+    @onlyOnCPUAndCUDA
+    @dtypes(*get_all_dtypes(include_bool=False))
+    def test_dlpack_capsule_conversion(self, device, dtype):
+        # DLPack does not explicitly support bool (xref dmlc/dlpack#75)
+        x = make_tensor((5,), device, dtype)
         z = from_dlpack(to_dlpack(x))
         self.assertEqual(z, x)

+    @skipMeta
+    @onlyOnCPUAndCUDA
+    @dtypes(*get_all_dtypes(include_bool=False))
+    def test_dlpack_protocol_conversion(self, device, dtype):
+        x = make_tensor((5,), device, dtype)
+        z = from_dlpack(x)
+        self.assertEqual(z, x)
+
+    @skipMeta
+    @onlyOnCPUAndCUDA
+    def test_dlpack_shared_storage(self, device):
+        x = make_tensor((5,), device, torch.float64)
+        z = from_dlpack(to_dlpack(x))
+        z[0] = z[0] + 20.0
+        self.assertEqual(z, x)
+
+    @skipMeta
+    @onlyCUDA
+    @dtypes(*get_all_dtypes(include_bool=False))
+    def test_dlpack_conversion_with_streams(self, device, dtype):
+        # Create a stream where the tensor will reside
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            # Do an operation in the actual stream
+            x = make_tensor((5,), device, dtype) + 1
+        # DLPack protocol helps establish a correct stream order
+        # (hence data dependency) at the exchange boundary.
+        # DLPack manages this synchronization for us, so we don't need to
+        # explicitly wait until x is populated
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            z = from_dlpack(x)
+        stream.synchronize()
+        self.assertEqual(z, x)
+
+    @skipMeta
+    @onlyCUDA
+    @dtypes(*get_all_dtypes(include_bool=False))
+    def test_dlpack_conversion_with_diff_streams(self, device, dtype):
+        from torch._C import _from_dlpack
+        stream_a = torch.cuda.Stream()
+        stream_b = torch.cuda.Stream()
+        # DLPack protocol helps establish a correct stream order
+        # (hence data dependency) at the exchange boundary.
+        # The `tensor.__dlpack__` method will insert a synchronization event
+        # in the current stream to make sure that it was correctly populated.
+        with torch.cuda.stream(stream_a):
+            x = make_tensor((5,), device, dtype) + 1
+            z = _from_dlpack(x.__dlpack__(stream_b.cuda_stream))
+        stream_a.synchronize()
+        stream_b.synchronize()
+        self.assertEqual(z, x)
+
+    @skipMeta
+    @onlyOnCPUAndCUDA
+    @dtypes(*get_all_dtypes(include_bool=False))
+    def test_dlpack_tensor_invalid_stream(self, device, dtype):
+        with self.assertRaises(TypeError):
+            x = make_tensor((5,), device, dtype)
+            x.__dlpack__(stream=object())
+
+    @skipMeta
+    def test_dlpack_error_on_bool_tensor(self):
+        x = torch.tensor([True], dtype=torch.bool)
+        with self.assertRaises(RuntimeError):
+            to_dlpack(x)
+
+    # TODO: increase tests once NumPy supports the `__dlpack__` protocol
+
+    @skipMeta
+    def test_dlpack_export_requires_grad(self):
+        x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
+        with self.assertRaisesRegex(RuntimeError, r"require gradient"):
+            x.__dlpack__()
+
+    @skipMeta
+    def test_dlpack_export_is_conj(self):
+        x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
+        y = torch.conj(x)
+        with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
+            y.__dlpack__()
+
+    @skipMeta
+    def test_dlpack_export_non_strided(self):
+        x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
+        y = torch.conj(x)
+        with self.assertRaisesRegex(RuntimeError, r"strided"):
+            y.__dlpack__()
+
     @onlyCUDA
     @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
     def test_pin_memory_from_constructor(self, device):
diff --git a/torch/__init__.py b/torch/__init__.py
index c0bd2db..a910882 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -759,6 +759,8 @@ from ._vmap_internals import vmap as vmap
 quantized_lstm = torch.ops.aten.quantized_lstm
 quantized_gru = torch.ops.aten.quantized_gru

+from torch.utils.dlpack import from_dlpack, to_dlpack
+
 def _register_device_module(device_type, module):
     r"""Register an external runtime module of the specific :attr:`device_type`
diff --git a/torch/_tensor.py b/torch/_tensor.py
index e7bc4ed..2e4b7a3 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -1,4 +1,5 @@
 from collections import OrderedDict
+import enum
 import functools
 from numbers import Number
 from typing import Any, Dict, Optional, Tuple, Union
@@ -1053,6 +1054,67 @@ class Tensor(torch._C._TensorBase):
         else:
             return _convert(ret, cls)

+    def __dlpack__(self, stream=None):
+        """
+        Creates a DLPack `capsule <https://data-apis.org/array-api/latest/design_topics/data_interchange.html#data-interchange>`_
+        of the current tensor to be exported to other libraries.
+
+        This function will be called from the `from_dlpack` function
+        of the library that will consume the capsule. `from_dlpack` passes the current
+        stream to this method as part of the specification.
+
+        Args:
+            stream (integer or None): An optional Python integer representing a
+            pointer to a CUDA stream. The current stream is synchronized with
+            this stream before the capsule is created, and since the capsule
+            shares its storage with the tensor this makes it safe to access from
+            both streams. If None or -1 is passed then no synchronization is performed.
+        """
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__dlpack__, (self,), self, stream)
+
+        # DLPack capsules can't capture all of PyTorch's semantics,
+        # so we prohibit exporting tensors that would lose their properties like
+        # requires_grad and having the conjugate bit set.
+        if self.requires_grad:
+            raise RuntimeError('Can\'t export tensors that require gradient, use tensor.detach()')
+        if self.is_conj():
+            raise RuntimeError('Can\'t export tensors with the conjugate bit set')
+        if self.layout != torch.strided:
+            raise RuntimeError('Can\'t export tensors with layout other than torch.strided')
+
+        if stream is not None and type(stream) is not int:
+            # Stream pointers in CUDA/ROCm are uniquely numbered and can
+            # be retrieved from their integer value.
+            raise TypeError('stream must be ``int`` or ``None``')
+        elif stream is not None and stream != -1:
+            if self.device.type == 'cuda':
+                stream = torch.cuda.streams.ExternalStream(stream)
+                # Only synchronize on different streams
+                if stream != torch.cuda.current_stream():
+                    event = torch.cuda.Event()
+                    event.record(torch.cuda.current_stream())
+                    stream.wait_event(event)
+        return torch.to_dlpack(self)
+
+    def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]:
+        # Avoid circular import
+        from torch.utils.dlpack import DLDeviceType
+        if has_torch_function_unary(self):
+            return handle_torch_function(Tensor.__dlpack_device__, (self,), self)
+        idx = self.device.index if self.device.index is not None else 0
+        if self.device.type == 'cuda' and torch.version.hip is not None:
+            device_type = DLDeviceType.kDLROCM
+        elif self.device.type == 'cpu' and self.is_pinned():
+            device_type = DLDeviceType.kDLCPUPinned
+        elif self.device.type == 'cuda':
+            device_type = DLDeviceType.kDLGPU
+        elif self.device.type == 'cpu':
+            device_type = DLDeviceType.kDLCPU
+        else:
+            raise ValueError('Unknown device type {} for DLPack'.format(self.device.type))
+        return (device_type, idx)
+
     __module__ = 'torch'

 def _convert(ret, cls):
diff --git a/torch/overrides.py b/torch/overrides.py
index c574109..6c545a0 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -1161,6 +1161,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.view: lambda self, shape: -1,
         Tensor.view_as: lambda self, other: -1,
         Tensor.zero_: lambda self: -1,
+        Tensor.__dlpack__: lambda self, stream=None: -1,
+        Tensor.__dlpack_device__: lambda self: -1,
         torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
     }

diff --git a/torch/utils/dlpack.py b/torch/utils/dlpack.py
index af516ec..06424be 100644
--- a/torch/utils/dlpack.py
+++ b/torch/utils/dlpack.py
@@ -1,19 +1,24 @@
+from typing import Any
+
 import torch
+import enum

-from torch._C import _from_dlpack as from_dlpack
+from torch._C import _from_dlpack
 from torch._C import _to_dlpack as to_dlpack

-torch._C._add_docstr(from_dlpack, r"""from_dlpack(dlpack) -> Tensor
-
-Decodes a DLPack to a tensor.
+class DLDeviceType(enum.IntEnum):
+    # Enums as in DLPack specification (aten/src/ATen/dlpack.h)
+    kDLCPU = 1,
+    kDLGPU = 2,
+    kDLCPUPinned = 3,
+    kDLOpenCL = 4,
+    kDLVulkan = 7,
+    kDLMetal = 8,
+    kDLVPI = 9,
+    kDLROCM = 10,
+    kDLExtDev = 12,

-Args:
-  dlpack: a PyCapsule object with the dltensor
-
-The tensor will share the memory with the object represented
-in the dlpack.
-Note that each dlpack can only be consumed once.
-""")

 torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule

@@ -22,6 +27,41 @@ Returns a DLPack representing the tensor.
 Args:
     tensor: a tensor to be exported

-The dlpack shares the tensors memory.
-Note that each dlpack can only be consumed once.
+The DLPack shares the tensor's memory.
+Note that each DLPack can only be consumed once.
 """)
+
+# TODO: add a typing.Protocol to be able to tell Mypy that only objects with
+# __dlpack__ and __dlpack_device__ methods are accepted.
+def from_dlpack(ext_tensor: Any) -> torch.Tensor:
+    """from_dlpack(ext_tensor) -> Tensor
+
+    Converts a tensor from an external library into a ``torch.Tensor``
+    by means of the ``__dlpack__`` protocol.
+
+    The tensor will share the memory with the object represented
+    in the DLPack.
+
+    .. warning::
+      Only call from_dlpack once per capsule. Its behavior when used
+      on the same capsule multiple times is undefined.
+
+    Args:
+        ext_tensor (object with __dlpack__ attribute or DLPack capsule):
+            The tensor or DLPack capsule to convert.
+    """
+    if hasattr(ext_tensor, '__dlpack__'):
+        device = ext_tensor.__dlpack_device__()
+        # The device is either CUDA or ROCm; we need to pass the current
+        # stream
+        if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM):
+            stream = torch.cuda.current_stream('cuda:{}'.format(device[1]))
+            # cuda_stream is the pointer to the stream and it is a public
+            # attribute, but it is not documented
+            dlpack = ext_tensor.__dlpack__(stream=stream.cuda_stream)
+        else:
+            dlpack = ext_tensor.__dlpack__()
+    else:
+        # Old versions just pass a DLPack capsule directly to the converter
+        dlpack = ext_tensor
+    return _from_dlpack(dlpack)
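
For reference, a minimal sketch of the two exchange paths this patch enables (not part of the diff; assumes a build containing this patch, and the variable names are illustrative):

    import torch
    from torch.utils.dlpack import DLDeviceType, from_dlpack, to_dlpack

    x = torch.arange(5, dtype=torch.float64)
    # CPU tensors report the kDLCPU device type and index 0.
    assert x.__dlpack_device__() == (DLDeviceType.kDLCPU, 0)

    # Capsule route: to_dlpack produces a capsule that may be consumed once.
    y = from_dlpack(to_dlpack(x))

    # Protocol route: any object exposing __dlpack__/__dlpack_device__
    # (after this patch, a torch.Tensor itself) is accepted directly.
    z = from_dlpack(x)

    # All three alias the same storage, so an in-place write through one
    # view is visible through the others.
    z[0] += 1.0
    assert x[0] == 1.0 and y[0] == 1.0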
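
The stream handshake in `Tensor.__dlpack__` can also be exercised directly, mirroring test_dlpack_conversion_with_diff_streams above. A minimal sketch, assuming a CUDA device (`_from_dlpack` is the private capsule converter the tests use):

    import torch
    from torch._C import _from_dlpack

    producer = torch.cuda.Stream()
    consumer = torch.cuda.Stream()

    with torch.cuda.stream(producer):
        x = torch.randn(5, device='cuda')
        # Handing __dlpack__ the consumer's raw stream pointer makes it
        # record an event on the current (producer) stream and have the
        # consumer stream wait on it, so the import is ordered after the
        # write to x.
        capsule = x.__dlpack__(stream=consumer.cuda_stream)

    with torch.cuda.stream(consumer):
        z = _from_dlpack(capsule)

    producer.synchronize()
    consumer.synchronize()
    assert torch.equal(z, x)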