[mlir] Handle strided 1D vector transfer ops in ProgressiveVectorToSCF
authorMatthias Springer <springerm@google.com>
Fri, 23 Apr 2021 08:18:26 +0000 (17:18 +0900)
committerMatthias Springer <springerm@google.com>
Fri, 23 Apr 2021 08:19:22 +0000 (17:19 +0900)
Strided 1D vector transfer ops are 1D transfers operating on a memref dimension different from the last one. Such transfer ops do not accesses contiguous memory blocks (vectors), but access memory in a strided fashion. In the absence of a mask, strided 1D vector transfer ops can also be lowered using matrix.column.major.* LLVM instructions (in a later commit).

Subsequent commits will extend the pass to handle the remaining missing permutation maps (broadcasts, transposes, etc.).

Differential Revision: https://reviews.llvm.org/D100946

mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir [new file with mode: 0644]

index d8d4690..785da3b 100644 (file)
@@ -93,36 +93,90 @@ static void getXferIndices(OpTy xferOp, Value iv,
   indices[dim] = adaptor.indices()[dim] + iv;
 }
 
-/// Generate an in-bounds check if the transfer op on the to-be-unpacked
-/// dimension may go out-of-bounds.
-template <typename OpTy>
-static void generateInBoundsCheck(
-    OpTy xferOp, Value iv, PatternRewriter &rewriter,
-    function_ref<void(OpBuilder &, Location)> inBoundsCase,
-    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
-  // Corresponding memref dim of the vector dim that is unpacked.
-  auto dim = unpackedDim(xferOp);
+static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
+                            Value value) {
+  if (hasRetVal) {
+    builder.create<scf::YieldOp>(loc, value);
+  } else {
+    builder.create<scf::YieldOp>(loc);
+  }
+}
 
+/// Helper function TransferOpConversion and Strided1dTransferOpConversion.
+/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
+/// specified dimension `dim` with the loop iteration variable `iv`.
+/// E.g., when unpacking dimension 0 from:
+/// ```
+/// %vec = vector.transfer_read %A[%a, %b] %cst
+///     : vector<5x4xf32>, memref<?x?xf32>
+/// ```
+/// An if check similar to this will be generated inside the loop:
+/// ```
+/// %d = memref.dim %A, %c0 : memref<?x?xf32>
+/// if (%a + iv < %d) {
+///   (in-bounds case)
+/// } else {
+///   (out-of-bounds case)
+/// }
+/// ```
+/// This function variant returns the value returned by `inBoundsCase` or
+/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
+/// `resultTypes`.
+template <typename OpTy>
+static Value generateInBoundsCheck(
+    OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
+    TypeRange resultTypes,
+    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
+    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
+  bool hasRetVal = !resultTypes.empty();
   if (!xferOp.isDimInBounds(0)) {
     auto memrefDim = memref_dim(xferOp.source(), std_constant_index(dim));
     using edsc::op::operator+;
     auto memrefIdx = xferOp.indices()[dim] + iv;
     auto cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
-    rewriter.create<scf::IfOp>(
-        xferOp.getLoc(), cond,
+    auto check = builder.create<scf::IfOp>(
+        xferOp.getLoc(), resultTypes, cond,
+        /*thenBuilder=*/
         [&](OpBuilder &builder, Location loc) {
-          inBoundsCase(builder, loc);
-          builder.create<scf::YieldOp>(xferOp.getLoc());
+          maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
         },
+        /*elseBuilder=*/
         [&](OpBuilder &builder, Location loc) {
-          if (outOfBoundsCase)
-            outOfBoundsCase(builder, loc);
-          builder.create<scf::YieldOp>(xferOp.getLoc());
+          if (outOfBoundsCase) {
+            maybeYieldValue(hasRetVal, builder, loc,
+                            outOfBoundsCase(builder, loc));
+          } else {
+            builder.create<scf::YieldOp>(loc);
+          }
         });
-  } else {
-    // No runtime check needed if dim is guaranteed to be in-bounds.
-    inBoundsCase(rewriter, xferOp.getLoc());
+
+    return hasRetVal ? check.getResult(0) : Value();
   }
+
+  // No runtime check needed if dim is guaranteed to be in-bounds.
+  return inBoundsCase(builder, xferOp.getLoc());
+}
+
+/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
+/// a return value. Consequently, this function does not have a return value.
+template <typename OpTy>
+static void generateInBoundsCheck(
+    OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
+    function_ref<void(OpBuilder &, Location)> inBoundsCase,
+    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
+  generateInBoundsCheck(
+      xferOp, iv, builder, dim, /*resultTypes=*/TypeRange(),
+      /*inBoundsCase=*/
+      [&](OpBuilder &builder, Location loc) {
+        inBoundsCase(builder, loc);
+        return Value();
+      },
+      /*outOfBoundsCase=*/
+      [&](OpBuilder &builder, Location loc) {
+        if (outOfBoundsCase)
+          outOfBoundsCase(builder, loc);
+        return Value();
+      });
 }
 
 /// Given an ArrayAttr, return a copy where the first element is dropped.
@@ -442,7 +496,7 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
             .value;
     affineLoopBuilder(lb, ub, 1, [&](Value iv) {
       generateInBoundsCheck(
-          xferOp, iv, rewriter,
+          xferOp, iv, rewriter, unpackedDim(xferOp),
           /*inBoundsCase=*/
           [&](OpBuilder & /*b*/, Location loc) {
             Strategy<OpTy>::rewriteOp(rewriter, xferOp, casted, iv);
@@ -458,6 +512,143 @@ struct TransferOpConversion : public OpRewritePattern<OpTy> {
   }
 };
 
+/// Compute the indices into the memref for the LoadOp/StoreOp generated as
+/// part of Strided1dTransferOpConversion. Return the memref dimension on which
+/// the transfer is operating.
+template <typename OpTy>
+static unsigned get1dMemrefIndices(OpTy xferOp, Value iv,
+                                   SmallVector<Value, 8> &memrefIndices) {
+  auto indices = xferOp.indices();
+  auto map = xferOp.permutation_map();
+
+  memrefIndices.append(indices.begin(), indices.end());
+  assert(map.getNumResults() == 1 &&
+         "Expected 1 permutation map result for 1D transfer");
+  // TODO: Handle broadcast
+  auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>();
+  assert(expr && "Expected AffineDimExpr in permutation map result");
+  auto dim = expr.getPosition();
+  using edsc::op::operator+;
+  memrefIndices[dim] = memrefIndices[dim] + iv;
+  return dim;
+}
+
+/// Codegen strategy for Strided1dTransferOpConversion, depending on the
+/// operation.
+template <typename OpTy>
+struct Strategy1d;
+
+/// Codegen strategy for TransferReadOp.
+template <>
+struct Strategy1d<TransferReadOp> {
+  static void generateForLoopBody(OpBuilder &builder, Location loc,
+                                  TransferReadOp xferOp, Value iv,
+                                  ValueRange loopState) {
+    SmallVector<Value, 8> indices;
+    auto dim = get1dMemrefIndices(xferOp, iv, indices);
+    auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+    auto vec = loopState[0];
+
+    // In case of out-of-bounds access, leave `vec` as is (was initialized with
+    // padding value).
+    auto nextVec = generateInBoundsCheck(
+        xferOp, iv, builder, dim, TypeRange(xferOp.getVectorType()),
+        /*inBoundsCase=*/
+        [&](OpBuilder & /*b*/, Location loc) {
+          auto val = memref_load(xferOp.source(), indices);
+          return vector_insert_element(val, vec, ivI32.value).value;
+        },
+        /*outOfBoundsCase=*/
+        [&](OpBuilder & /*b*/, Location loc) { return vec; });
+    builder.create<scf::YieldOp>(loc, nextVec);
+  }
+
+  static Value initialLoopState(TransferReadOp xferOp) {
+    // Inititalize vector with padding value.
+    return std_splat(xferOp.getVectorType(), xferOp.padding()).value;
+  }
+};
+
+/// Codegen strategy for TransferWriteOp.
+template <>
+struct Strategy1d<TransferWriteOp> {
+  static void generateForLoopBody(OpBuilder &builder, Location loc,
+                                  TransferWriteOp xferOp, Value iv,
+                                  ValueRange /*loopState*/) {
+    SmallVector<Value, 8> indices;
+    auto dim = get1dMemrefIndices(xferOp, iv, indices);
+    auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
+
+    // Nothing to do in case of out-of-bounds access.
+    generateInBoundsCheck(
+        xferOp, iv, builder, dim,
+        /*inBoundsCase=*/[&](OpBuilder & /*b*/, Location loc) {
+          auto val = vector_extract_element(xferOp.vector(), ivI32.value);
+          memref_store(val, xferOp.source(), indices);
+        });
+    builder.create<scf::YieldOp>(loc);
+  }
+
+  static Value initialLoopState(TransferWriteOp xferOp) { return Value(); }
+};
+
+/// Lower a 1D vector transfer op that operates on a dimension different from
+/// the last one. Instead of accessing contiguous chunks (vectors) of memory,
+/// such ops access memory in a strided fashion.
+///
+/// 1. Generate a for loop iterating over each vector element.
+/// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
+///    depending on OpTy.
+///
+/// E.g.:
+/// ```
+/// vector.transfer_write %vec, %A[%a, %b]
+///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
+///    : vector<9xf32>, memref<?x?xf32>
+/// ```
+/// Is rewritten to approximately the following pseudo-IR:
+/// ```
+/// for i = 0 to 9 {
+///   %t = vector.extractelement %vec[i] : vector<9xf32>
+///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
+/// }
+/// ```
+template <typename OpTy>
+struct Strided1dTransferOpConversion : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy xferOp,
+                                PatternRewriter &rewriter) const override {
+    ScopedContext scope(rewriter, xferOp.getLoc());
+    auto map = xferOp.permutation_map();
+
+    if (xferOp.getVectorType().getRank() != 1)
+      return failure();
+    if (map.isMinorIdentity()) // Handled by ConvertVectorToLLVM
+      return failure();
+    if (xferOp.mask())
+      return failure();
+
+    // Loop bounds, step, state...
+    auto vecType = xferOp.getVectorType();
+    auto lb = std_constant_index(0);
+    auto ub = std_constant_index(vecType.getDimSize(0));
+    auto step = std_constant_index(1);
+    auto loopState = Strategy1d<OpTy>::initialLoopState(xferOp);
+
+    // Generate for loop.
+    rewriter.replaceOpWithNewOp<scf::ForOp>(
+        xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
+        [&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) {
+          ScopedContext nestedScope(builder, loc);
+          Strategy1d<OpTy>::generateForLoopBody(builder, loc, xferOp, iv,
+                                                loopState);
+        });
+
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -466,7 +657,10 @@ void populateProgressiveVectorToSCFConversionPatterns(
     RewritePatternSet &patterns) {
   patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
                TransferOpConversion<TransferReadOp>,
-               TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+               TransferOpConversion<TransferWriteOp>,
+               Strided1dTransferOpConversion<TransferReadOp>,
+               Strided1dTransferOpConversion<TransferWriteOp>>(
+      patterns.getContext());
 }
 
 struct ConvertProgressiveVectorToSCFPass
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-1d.mlir
new file mode 100644 (file)
index 0000000..17f635f
--- /dev/null
@@ -0,0 +1,60 @@
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// Test for special cases of 1D vector transfer ops.
+
+func @transfer_read_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %fm42
+      {permutation_map = affine_map<(d0, d1) -> (d0)>}
+      : memref<?x?xf32>, vector<9xf32>
+  vector.print %f: vector<9xf32>
+  return
+}
+
+func @transfer_write_1d(%A : memref<?x?xf32>, %base1 : index, %base2 : index) {
+  %fn1 = constant -1.0 : f32
+  %vf0 = splat %fn1 : vector<7xf32>
+  vector.transfer_write %vf0, %A[%base1, %base2]
+    {permutation_map = affine_map<(d0, d1) -> (d0)>}
+    : vector<7xf32>, memref<?x?xf32>
+  return
+}
+
+func @entry() {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c2 = constant 2: index
+  %c3 = constant 3: index
+  %f10 = constant 10.0: f32
+  // work with dims of 4, not of 3
+  %first = constant 5: index
+  %second = constant 6: index
+  %A = memref.alloc(%first, %second) : memref<?x?xf32>
+  scf.for %i = %c0 to %first step %c1 {
+    %i32 = index_cast %i : index to i32
+    %fi = sitofp %i32 : i32 to f32
+    %fi10 = mulf %fi, %f10 : f32
+    scf.for %j = %c0 to %second step %c1 {
+        %j32 = index_cast %j : index to i32
+        %fj = sitofp %j32 : i32 to f32
+        %fres = addf %fi10, %fj : f32
+        memref.store %fres, %A[%i, %j] : memref<?x?xf32>
+    }
+  }
+
+  call @transfer_read_1d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+  call @transfer_write_1d(%A, %c3, %c2) : (memref<?x?xf32>, index, index) -> ()
+  call @transfer_read_1d(%A, %c0, %c2) : (memref<?x?xf32>, index, index) -> ()
+  return
+}
+
+// CHECK: ( 12, 22, 32, 42, -42, -42, -42, -42, -42 )
+// CHECK: ( 2, 12, 22, -1, -1, -42, -42, -42, -42 )