From 0676ba0c5cdc51d882337a46dcac3d60c885d4c4 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Cl=C3=A9ment=20Pinard?= Date: Fri, 19 Apr 2019 07:17:09 -0700 Subject: [PATCH] Mention packed accessors in tensor basics doc (#19464) Summary: This is a continuation of efforts into packed accessor awareness. A very simple example is added, along with the mention that the template can hold more arguments. Pull Request resolved: https://github.com/pytorch/pytorch/pull/19464 Differential Revision: D15012564 Pulled By: soumith fbshipit-source-id: a19ed536e016fae519b062d847cc58aef01b1b92 --- docs/cpp/source/notes/tensor_basics.rst | 44 ++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/docs/cpp/source/notes/tensor_basics.rst b/docs/cpp/source/notes/tensor_basics.rst index 5872852..8c9ff46 100644 --- a/docs/cpp/source/notes/tensor_basics.rst +++ b/docs/cpp/source/notes/tensor_basics.rst @@ -40,7 +40,21 @@ small. However, there are cases, especially in your own kernels, where efficient element-wise access is needed, and the cost of dynamic dispatch inside the element-wise loop is very high. ATen provides *accessors* that are created with a single dynamic check that a Tensor is the type and number of dimensions. -Accessors then expose an API for accessing the Tensor elements efficiently: +Accessors then expose an API for accessing the Tensor elements efficiently. + +Accessors are temporary views of a Tensor. They are only valid for the lifetime +of the tensor that they view and hence should only be used locally in a +function, like iterators. + +Note that accessors are not compatible with CUDA tensors inside kernel functions. +Instead, you will have to use a *packed accessor* which behaves the same way but +copies tensor metadata instead of pointing to it. + +It is thus recommended to use *accessors* for CPU tensors and *packed accessors* +for CUDA tensors. + +CPU accessors +************* .. code-block:: cpp @@ -55,9 +69,31 @@ Accessors then expose an API for accessing the Tensor elements efficiently: trace += foo_a[i][i]; } -Accessors are temporary views of a Tensor. They are only valid for the lifetime -of the tensor that they view and hence should only be used locally in a -function, like iterators. +CUDA accessors +************** + + +.. code-block:: cpp + + __global__ void packed_accessor_kernel( + PackedTensorAccessor foo, + float* trace) { + int i=threadIdx.x + atomicAdd(trace, foo[i](i]) + } + + torch::Tensor foo = torch::rand({12, 12}); + + // assert foo is 2-dimensional and holds floats. + auto foo_a = foo.packed_accessor(); + float trace = 0; + + packed_accessor_kernel<<<1, 12>>>(foo_a, &trace); + +Note that the template can hold other parameters such as the pointer restriction +and the integer type for indexing. See documentation for a thorough template +description of *accessors* and *packed accessors*. + Using Externally Created Data ----------------------------- -- 2.7.4