[mlir] Handle linalg.index correctly in TilingInterface
authorAlex Zinenko <zinenko@google.com>
Fri, 8 Jul 2022 11:49:47 +0000 (13:49 +0200)
committerAlex Zinenko <zinenko@google.com>
Tue, 12 Jul 2022 12:36:33 +0000 (12:36 +0000)
The existing implementation of the TilingInterface for Linalg ops was not
modifying the `linalg.index` ops contained within other Linalg ops (they need
to be summed up with the values of respective tile loop induction variables),
which led to the interface-based tiling being incorrect for any Linalg op with
index semantics.

In the process, fix the function performing the index offsetting to use the
pattern rewriter API instead of RAUW as it is being called from patterns and
may mess up the internal state of the rewriter. Also rename the function to
clearly catch all uses.

Depends On D129365

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D129366

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Linalg/Transforms/Split.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Utils/Utils.cpp
mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

index 4bb0ead..6905b49 100644 (file)
@@ -243,10 +243,11 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
                                       ArrayRef<Value> sizeBounds,
                                       bool omitPartialTileCheck);
 
-/// Add the tile loop induction variables `ivs` to the IndexOp results found in
-/// the body of the `tiledOp` to account for the tile offset.
-void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
-                                    ArrayRef<Value> ivs);
+/// Add the specified offsets to any `linalg.index` ops contained in the given
+/// `linalgOp`. The offsets are provided in the same order as iteration space
+/// dimensions. Null offests are assumed to be zero.
+void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef<Value> offests);
+void offsetIndices(RewriterBase &b, LinalgOp linalgOp, ArrayRef<Value> offests);
 
 using FusableOpDependencesTy = llvm::MapVector<
     Operation *,
index fa0f3ef..3d8b89b 100644 (file)
@@ -170,7 +170,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   SmallVector<Value> allIvs;
   llvm::transform(loopRanges, std::back_inserter(allIvs),
                   [](Range range) { return range.offset; });
-  addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
+  offsetIndices(b, clonedOp, allIvs);
 
   return clonedOp;
 }
index 0a15c31..d968b37 100644 (file)
@@ -186,7 +186,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
   LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
 
   // Shift all IndexOp results by the tile offset.
-  addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
+  offsetIndices(b, clonedOp, allIvs);
 
   return clonedOp;
 }
index 2325771..875c844 100644 (file)
@@ -139,8 +139,7 @@ std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
   SmallVector<Value> ivAdditions;
   ivAdditions.resize(splitIterationSpace.size());
   ivAdditions[dimension] = splitPointValue;
-  linalg::addTileLoopIvsToIndexOpResults(builder, cast<LinalgOp>(second),
-                                         ivAdditions);
+  linalg::offsetIndices(rewriter, cast<LinalgOp>(second), ivAdditions);
 
   // Replace the original op with the results of the two newly created ops.
   rewriter.replaceOp(op, secondResults);
index d55876f..25eab5b 100644 (file)
@@ -80,7 +80,7 @@ void mlir::linalg::transformIndexOps(
       continue;
     en.value() = ivs[rangeIndex->second];
   }
-  addTileLoopIvsToIndexOpResults(b, op, allIvs);
+  offsetIndices(b, op, allIvs);
 }
 
 /// Asserts that the given index-typed value is strictly positive. If the value
index 88b21f1..06d5f75 100644 (file)
@@ -71,9 +71,10 @@ struct LinalgOpTilingInterface
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
     SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
+    SmallVector<Value> offsetValues =
+        getValueOrCreateConstantIndexOp(b, loc, offsets);
     SmallVector<Value, 4> tiledOperands = makeTiledShapes(
-        b, loc, linalgOp, valuesToTile,
-        getValueOrCreateConstantIndexOp(b, loc, offsets),
+        b, loc, linalgOp, valuesToTile, offsetValues,
         getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true);
 
     SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
@@ -83,6 +84,7 @@ struct LinalgOpTilingInterface
 
     Operation *tiledOp =
         linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
+    offsetIndices(b, cast<LinalgOp>(tiledOp), offsetValues);
 
     return {tiledOp};
   }
index 34b6714..5f54ee4 100644 (file)
@@ -1048,21 +1048,29 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
   return tiledShapes;
 }
 
-void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp,
-                                    ArrayRef<Value> ivs) {
-  if (tiledOp.hasIndexSemantics()) {
-    for (IndexOp indexOp : tiledOp.getBlock()->getOps<IndexOp>()) {
-      if (ivs[indexOp.dim()] == nullptr)
-        continue;
-      OpBuilder::InsertionGuard guard(b);
-      b.setInsertionPointAfter(indexOp);
-      AffineExpr index, offset;
-      bindDims(b.getContext(), index, offset);
-      AffineApplyOp applyOp = makeComposedAffineApply(
-          b, indexOp.getLoc(), index + offset,
-          ValueRange{indexOp.getResult(), ivs[indexOp.dim()]});
-      indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp);
-    }
+void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef<Value> offsets) {
+  IRRewriter rewriter(b);
+  offsetIndices(rewriter, linalgOp, offsets);
+}
+
+void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
+                   ArrayRef<Value> offsets) {
+  if (!linalgOp.hasIndexSemantics())
+    return;
+
+  for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
+    if (indexOp.dim() >= offsets.size() || offsets[indexOp.dim()] == nullptr)
+      continue;
+    OpBuilder::InsertionGuard guard(b);
+    b.setInsertionPointAfter(indexOp);
+    AffineExpr index, offset;
+    bindDims(b.getContext(), index, offset);
+    AffineApplyOp applyOp = makeComposedAffineApply(
+        b, indexOp.getLoc(), index + offset,
+        ValueRange{indexOp.getResult(), offsets[indexOp.dim()]});
+    b.replaceOpWithIf(indexOp, applyOp.getResult(), [&](OpOperand &use) {
+      return use.getOwner() != applyOp;
+    });
   }
 }
 
index a7367a7..d8ec2c5 100644 (file)
@@ -192,3 +192,37 @@ func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
 // CHECK-SAME:             outs(%[[INIT_TILE]] :
 //      CHECK:         tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]]
 // CHECK-SAME:             [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]]
+
+// -----
+
+// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+
+// CHECK-LABEL: @indexed_semantics
+func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  // Check that we correctly amend "linalg.index" results.
+
+  // CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+  // CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+  %0 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]}
+    {__internal_linalg_transform__ = "indexed_semantics"}
+    ins(%arg0: tensor<?x?xf32>)
+    outs(%arg1: tensor<?x?xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):
+    // CHECK: %[[INDEX0:.+]] = linalg.index 0
+    // CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
+    %1 = linalg.index 0 : index
+    // CHECK: %[[INDEX1:.+]] = linalg.index 1
+    // CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
+    %2 = linalg.index 1 : index
+    // CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
+    %3 = arith.addi %1, %2 : index
+    %4 = arith.index_cast %3 : index to i64
+    %5 = arith.uitofp %4 : i64 to f32
+    %6 = arith.addf %5, %arg2 : f32
+    linalg.yield %6 : f32
+  } -> (tensor<?x?xf32>)
+  return %0 : tensor<?x?xf32>
+}
index 6241603..cebe7b1 100644 (file)
@@ -171,6 +171,9 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
     // 4. Tiling 2D conv op.
     addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
         context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns);
+    // 5. Tiling a simple op with `linalg.index` inside.
+    addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
+        context, {10, 20}, "indexed_semantics", patterns);
     return;
   }
   if (testTileConsumerAndFuseProducer) {