From 81b62f7feb5de2fb37261974ffa0b2a43a2d83ee Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 8 Jul 2022 13:49:47 +0200 Subject: [PATCH] [mlir] Handle linalg.index correctly in TilingInterface The existing implementation of the TilingInterface for Linalg ops was not modifying the `linalg.index` ops contained within other Linalg ops (they need to be summed up with the values of respective tile loop induction variables), which led to the interface-based tiling being incorrect for any Linalg op with index semantics. In the process, fix the function performing the index offsetting to use the pattern rewriter API instead of RAUW as it is being called from patterns and may mess up the internal state of the rewriter. Also rename the function to clearly catch all uses. Depends On D129365 Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D129366 --- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h | 9 ++--- mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 2 +- .../Dialect/Linalg/Transforms/FusionOnTensors.cpp | 2 +- mlir/lib/Dialect/Linalg/Transforms/Split.cpp | 3 +- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 6 ++-- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 38 +++++++++++++--------- .../TilingInterface/tile-using-interface.mlir | 34 +++++++++++++++++++ .../TilingInterface/TestTilingInterface.cpp | 3 ++ 9 files changed, 73 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 4bb0ead..6905b49 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -243,10 +243,11 @@ SmallVector makeTiledShapes(OpBuilder &builder, Location loc, ArrayRef sizeBounds, bool omitPartialTileCheck); -/// Add the tile loop induction variables `ivs` to the IndexOp results found in -/// the body of the `tiledOp` to account for the tile offset. -void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp, - ArrayRef ivs); +/// Add the specified offsets to any `linalg.index` ops contained in the given +/// `linalgOp`. The offsets are provided in the same order as iteration space +/// dimensions. Null offests are assumed to be zero. +void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef offests); +void offsetIndices(RewriterBase &b, LinalgOp linalgOp, ArrayRef offests); using FusableOpDependencesTy = llvm::MapVector< Operation *, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index fa0f3ef..3d8b89b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -170,7 +170,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer, SmallVector allIvs; llvm::transform(loopRanges, std::back_inserter(allIvs), [](Range range) { return range.offset; }); - addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs); + offsetIndices(b, clonedOp, allIvs); return clonedOp; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp index 0a15c31..d968b37 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -186,7 +186,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands); // Shift all IndexOp results by the tile offset. - addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs); + offsetIndices(b, clonedOp, allIvs); return clonedOp; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp index 2325771..875c844 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -139,8 +139,7 @@ std::pair linalg::splitOp(RewriterBase &rewriter, SmallVector ivAdditions; ivAdditions.resize(splitIterationSpace.size()); ivAdditions[dimension] = splitPointValue; - linalg::addTileLoopIvsToIndexOpResults(builder, cast(second), - ivAdditions); + linalg::offsetIndices(rewriter, cast(second), ivAdditions); // Replace the original op with the results of the two newly created ops. rewriter.replaceOp(op, secondResults); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index d55876f..25eab5b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -80,7 +80,7 @@ void mlir::linalg::transformIndexOps( continue; en.value() = ivs[rangeIndex->second]; } - addTileLoopIvsToIndexOpResults(b, op, allIvs); + offsetIndices(b, op, allIvs); } /// Asserts that the given index-typed value is strictly positive. If the value diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 88b21f1..06d5f75 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -71,9 +71,10 @@ struct LinalgOpTilingInterface Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); SmallVector valuesToTile = linalgOp.getInputAndOutputOperands(); + SmallVector offsetValues = + getValueOrCreateConstantIndexOp(b, loc, offsets); SmallVector tiledOperands = makeTiledShapes( - b, loc, linalgOp, valuesToTile, - getValueOrCreateConstantIndexOp(b, loc, offsets), + b, loc, linalgOp, valuesToTile, offsetValues, getValueOrCreateConstantIndexOp(b, loc, sizes), {}, true); SmallVector resultTensorTypes = llvm::to_vector(llvm::map_range( @@ -83,6 +84,7 @@ struct LinalgOpTilingInterface Operation *tiledOp = linalgOp.clone(b, loc, resultTensorTypes, tiledOperands); + offsetIndices(b, cast(tiledOp), offsetValues); return {tiledOp}; } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 34b6714..5f54ee4 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -1048,21 +1048,29 @@ SmallVector makeTiledShapes(OpBuilder &b, Location loc, return tiledShapes; } -void addTileLoopIvsToIndexOpResults(OpBuilder &b, LinalgOp tiledOp, - ArrayRef ivs) { - if (tiledOp.hasIndexSemantics()) { - for (IndexOp indexOp : tiledOp.getBlock()->getOps()) { - if (ivs[indexOp.dim()] == nullptr) - continue; - OpBuilder::InsertionGuard guard(b); - b.setInsertionPointAfter(indexOp); - AffineExpr index, offset; - bindDims(b.getContext(), index, offset); - AffineApplyOp applyOp = makeComposedAffineApply( - b, indexOp.getLoc(), index + offset, - ValueRange{indexOp.getResult(), ivs[indexOp.dim()]}); - indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp); - } +void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef offsets) { + IRRewriter rewriter(b); + offsetIndices(rewriter, linalgOp, offsets); +} + +void offsetIndices(RewriterBase &b, LinalgOp linalgOp, + ArrayRef offsets) { + if (!linalgOp.hasIndexSemantics()) + return; + + for (IndexOp indexOp : linalgOp.getBlock()->getOps()) { + if (indexOp.dim() >= offsets.size() || offsets[indexOp.dim()] == nullptr) + continue; + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointAfter(indexOp); + AffineExpr index, offset; + bindDims(b.getContext(), index, offset); + AffineApplyOp applyOp = makeComposedAffineApply( + b, indexOp.getLoc(), index + offset, + ValueRange{indexOp.getResult(), offsets[indexOp.dim()]}); + b.replaceOpWithIf(indexOp, applyOp.getResult(), [&](OpOperand &use) { + return use.getOwner() != applyOp; + }); } } diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir index a7367a7..d8ec2c5 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir @@ -192,3 +192,37 @@ func.func @conv2D(%arg0 : tensor, %arg1 : tensor, // CHECK-SAME: outs(%[[INIT_TILE]] : // CHECK: tensor.insert_slice %[[CONV_TILE]] into %[[INIT2]] // CHECK-SAME: [0, 0, 0, 0] [%[[N]], %[[R]], %[[S]], %[[F]]] + +// ----- + +// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)> + +// CHECK-LABEL: @indexed_semantics +func.func @indexed_semantics(%arg0: tensor, %arg1: tensor) -> tensor { + // Check that we correctly amend "linalg.index" results. + + // CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}} + // CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + {__internal_linalg_transform__ = "indexed_semantics"} + ins(%arg0: tensor) + outs(%arg1: tensor) { + ^bb0(%arg2: f32, %arg3: f32): + // CHECK: %[[INDEX0:.+]] = linalg.index 0 + // CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]]) + %1 = linalg.index 0 : index + // CHECK: %[[INDEX1:.+]] = linalg.index 1 + // CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]]) + %2 = linalg.index 1 : index + // CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]] + %3 = arith.addi %1, %2 : index + %4 = arith.index_cast %3 : index to i64 + %5 = arith.uitofp %4 : i64 to f32 + %6 = arith.addf %5, %arg2 : f32 + linalg.yield %6 : f32 + } -> (tensor) + return %0 : tensor +} diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp index 6241603..cebe7b1 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -171,6 +171,9 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, // 4. Tiling 2D conv op. addPatternForTiling( context, {0, 0, 0, 0, 10, 20, 30}, "simple_conv", patterns); + // 5. Tiling a simple op with `linalg.index` inside. + addPatternForTiling( + context, {10, 20}, "indexed_semantics", patterns); return; } if (testTileConsumerAndFuseProducer) { -- 2.7.4