[MLIR][doc] Improve/fix the doc on mlir.vector.transfer_read (NFC)
authorMehdi Amini <joker.eph@gmail.com>
Wed, 31 May 2023 20:28:44 +0000 (13:28 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 1 Jun 2023 18:45:05 +0000 (11:45 -0700)
This doc was written 4 years ago, some refresh in the example was
overdue I suspect.

Differential Revision: https://reviews.llvm.org/D151037

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

index d783a24..22b60d8 100644 (file)
@@ -1246,8 +1246,9 @@ def Vector_TransferReadOp :
     ```
 
     This operation always reads a slice starting at `%A[%expr1, %expr2, %expr3,
-    %expr4]`. The size of the slice is 3 along d2 and 5 along d0, so the slice
-    is: `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]`
+    %expr4]`. The size of the slice can be inferred from the resulting vector
+    shape and walking back through the permutation map: 3 along d2 and 5 along
+    d0, so the slice is: `%A[%expr1 : %expr1 + 5, %expr2, %expr3:%expr3 + 3, %expr4]`
 
     That slice needs to be read into a `vector<3x4x5xf32>`. Since the
     permutation map is not full rank, there must be a broadcast along vector
@@ -1257,44 +1258,52 @@ def Vector_TransferReadOp :
 
     ```mlir
     // %expr1, %expr2, %expr3, %expr4 defined before this point
-    %tmp = alloc() : vector<3x4x5xf32>
-    %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>>
+    // alloc a temporary buffer for performing the "gather" of the slice.
+    %tmp = memref.alloc() : memref<vector<3x4x5xf32>>
     for %i = 0 to 3 {
       affine.for %j = 0 to 4 {
         affine.for %k = 0 to 5 {
-          %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] :
-            memref<?x?x?x?xf32>
-          store %tmp[%i, %j, %k] : vector<3x4x5xf32>
+          // Note that this load does not involve %j.
+          %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
+          // Update the temporary gathered slice with the individual element
+          %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
+          %updated = vector.insert %a, %slice[%i, %j, %k] : f32 into vector<3x4x5xf32>
+          memref.store %updated, %temp : memref<vector<3x4x5xf32>>
     }}}
-    %c0 = arith.constant 0 : index
-    %vec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
+    // At this point we gathered the elements from the original
+    // memref into the desired vector layout, stored in the `%tmp` allocation.
+    %vec = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
     ```
 
     On a GPU one could then map `i`, `j`, `k` to blocks and threads. Notice that
-    the temporary storage footprint is `3 * 5` values but `3 * 4 * 5` values are
-    actually transferred between `%A` and `%tmp`.
+    the temporary storage footprint could conceptually be only `3 * 5` values but
+    `3 * 4 * 5` values are actually transferred between `%A` and `%tmp`.
 
-    Alternatively, if a notional vector broadcast operation were available, the
-    lowered code would resemble:
+    Alternatively, if a notional vector broadcast operation were available, we
+    could avoid the loop on `%j` and the lowered code would resemble:
 
     ```mlir
     // %expr1, %expr2, %expr3, %expr4 defined before this point
-    %tmp = alloc() : vector<3x4x5xf32>
-    %view_in_tmp = "element_type_cast"(%tmp) : memref<1xvector<3x4x5xf32>>
+    %tmp = memref.alloc() : memref<vector<3x4x5xf32>>
     for %i = 0 to 3 {
       affine.for %k = 0 to 5 {
-        %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] :
-          memref<?x?x?x?xf32>
-        store %tmp[%i, 0, %k] : vector<3x4x5xf32>
+        %a = load %A[%expr1 + %k, %expr2, %expr3 + %i, %expr4] : memref<?x?x?x?xf32>
+        %slice = memref.load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
+        // Here we only store to the first element in dimension one
+        %updated = vector.insert %a, %slice[%i, 0, %k] : f32 into vector<3x4x5xf32>
+        memref.store %updated, %temp : memref<vector<3x4x5xf32>>
     }}
-    %c0 = arith.constant 0 : index
-    %tmpvec = load %view_in_tmp[%c0] : vector<3x4x5xf32>
+    // At this point we gathered the elements from the original
+    // memref into the desired vector layout, stored in the `%tmp` allocation.
+    // However we haven't replicated them alongside the first dimension, we need
+    // to broadcast now.
+    %partialVec = load %tmp : memref<vector<3x4x5xf32>> -> vector<3x4x5xf32>
     %vec = broadcast %tmpvec, 1 : vector<3x4x5xf32>
     ```
 
     where `broadcast` broadcasts from element 0 to all others along the
-    specified dimension. This time, the temporary storage footprint is `3 * 5`
-    values which is the same amount of data as the `3 * 5` values transferred.
+    specified dimension. This time, the number of loaded element is `3 * 5`
+    values.
     An additional `1` broadcast is required. On a GPU this broadcast could be
     implemented using a warp-shuffle if loop `j` were mapped to `threadIdx.x`.
 
@@ -1310,7 +1319,7 @@ def Vector_TransferReadOp :
     // Read the slice `%A[%i0, %i1:%i1+256, %i2:%i2+32]` into vector<32x256xf32>
     // and pad with %f0 to handle the boundary case:
     %f0 = arith.constant 0.0f : f32
-    for %i0 = 0 to %0 {
+    affine.for %i0 = 0 to %0 {
       affine.for %i1 = 0 to %1 step 256 {
         affine.for %i2 = 0 to %2 step 32 {
           %v = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
@@ -1320,7 +1329,7 @@ def Vector_TransferReadOp :
 
     // or equivalently (rewrite with vector.transpose)
     %f0 = arith.constant 0.0f : f32
-    for %i0 = 0 to %0 {
+    affine.for %i0 = 0 to %0 {
       affine.for %i1 = 0 to %1 step 256 {
         affine.for %i2 = 0 to %2 step 32 {
           %v0 = vector.transfer_read %A[%i0, %i1, %i2], (%f0)
@@ -1333,7 +1342,7 @@ def Vector_TransferReadOp :
     // Read the slice `%A[%i0, %i1]` (i.e. the element `%A[%i0, %i1]`) into
     // vector<128xf32>. The underlying implementation will require a 1-D vector
     // broadcast:
-    for %i0 = 0 to %0 {
+    affine.for %i0 = 0 to %0 {
       affine.for %i1 = 0 to %1 {
         %3 = vector.transfer_read %A[%i0, %i1]
              {permutation_map: (d0, d1) -> (0)} :