From 4d27f06f9454a6733c3f801c8b992193702607b3 Mon Sep 17 00:00:00 2001 From: Christopher Bate Date: Fri, 16 Sep 2022 16:11:46 -0600 Subject: [PATCH] [mlir][Tensor] Fix ExtractSliceFromReshape transform edge case The transformation would fail if none of the sliced dimensions were linearized by the producing `tensor.collapse_shape`. This is a trivial edge case but it wasn't correctly tested. Fixes the issue and adds a test. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D134088 --- mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h | 6 ++++-- .../Tensor/Transforms/ExtractSliceFromReshape.cpp | 5 +++-- mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp | 9 +++------ .../Tensor/extract-slice-from-collapse-shape.mlir | 15 +++++++++++++++ mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp | 18 +++++++++++++++--- 5 files changed, 40 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h index e6b6048..f693b35 100644 --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -441,14 +441,16 @@ public: /// only one tiled dimension (D_0) and `arith.delinearize_index` produces the /// multi-index (%3) that would be passed to this function to generate the /// parameters for the `tensor.extract_slice` op (%4). - SmallVector getExtractSliceParams(ArrayRef multiIndices); + SmallVector getExtractSliceParams(MLIRContext *ctx, + ArrayRef multiIndices); /// This function takes indices in the index space of the "tiled dimensions" /// described above and returns a set of Range variables that describe how the /// slice should be inserted into the destination. In the example above, `%iv` /// would be passed to this function to generate the parameters for the /// `tensor.insert_slice` op producing %6. - SmallVector getInsertSliceParams(ValueRange tileIndices); + SmallVector getInsertSliceParams(MLIRContext *ctx, + ValueRange tileIndices); private: SmallVector reassociationIndices; diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp index 4acd548..dcee9de 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp @@ -164,13 +164,14 @@ tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody( } } - auto extractParams = helper.getExtractSliceParams(multiIndices); + SmallVector extractParams = + helper.getExtractSliceParams(builder.getContext(), multiIndices); Value subTileResult = builder.create( loc, collapseShapeOp.getSrc(), extractParams); SmallVector insertParams = - helper.getInsertSliceParams(tileInductionVars); + helper.getInsertSliceParams(builder.getContext(), tileInductionVars); // Collapse the dimensions of the source slice back down. Value collapsedResult = builder.create( diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 7f5b638..9bca50f 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -298,11 +298,8 @@ llvm::SmallBitVector mlir::getLinearizedDimensions( } SmallVector SliceFromCollapseHelper::getExtractSliceParams( - ArrayRef multiIndices) { - assert(!multiIndices.empty() && !multiIndices[0].empty() && - "multiIndices should not be empty"); + MLIRContext *ctx, ArrayRef multiIndices) { unsigned loopIdx = 0; - MLIRContext *ctx = multiIndices[0][0].getContext(); auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1); auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0); SmallVector offsetsSizesAndStrides; @@ -339,8 +336,8 @@ SmallVector SliceFromCollapseHelper::getExtractSliceParams( } SmallVector -SliceFromCollapseHelper::getInsertSliceParams(ValueRange tileIndices) { - MLIRContext *ctx = tileIndices[0].getContext(); +SliceFromCollapseHelper::getInsertSliceParams(MLIRContext *ctx, + ValueRange tileIndices) { auto one = IntegerAttr::get(IndexType::get(ctx), 1); auto zero = IntegerAttr::get(IndexType::get(ctx), 0); SmallVector insertParams; diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir index d8ca129..02e2502 100644 --- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir +++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir @@ -162,3 +162,18 @@ func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32 // CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1] return %slice : tensor } + +// ----- + +// CHECK: @no_sliced_linearized_dims(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index +func.func @no_sliced_linearized_dims(%input: tensor<30x11x100xf32>, %offt: index, %size: index) -> tensor<330x?xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<30x11x100xf32> into tensor<330x100xf32> + %slice = tensor.extract_slice %collapsed [0, %offt] [330, %size] [1, 1] : tensor<330x100xf32> to tensor<330x?xf32> + // CHECK-NOT: scf.for + // CHECK: %[[init:.+]] = linalg.init_tensor [330, %[[arg2]]] + // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, 0, %[[arg1]]] [30, 11, %[[arg2]]] [1, 1, 1] + // CHECK: %[[c:.+]] = tensor.collapse_shape %[[e]] {{\[}}[0, 1], [2]] + // CHECK: %[[res:.+]] = tensor.insert_slice %[[c]] into %[[init]] + // CHECK: return %[[res]] + return %slice : tensor<330x?xf32> +} diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp index f5a7f98..5dd5d76 100644 --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -151,6 +151,13 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor auto one = rewriter.create(loc, 1); SmallVector lbs(numTiledDims, zero); SmallVector steps(numTiledDims, one); + + // Below, we pass out the result of the loop body builder lambda via the + // `insertResult` variable. In certain cases, no loops will be created, but + // the body builder will still execute. In this case, the results will not + // be passed to the LoopNest object. + // TODO: remove this workaround if `scf::buildLoopNest` behavior is updated. + Value insertResult = nullptr; scf::LoopNest nest = scf::buildLoopNest( rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest, [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs, @@ -159,11 +166,16 @@ struct RewriteExtractSliceFromCollapseShapeUsingScfFor helper.emitLoopNestBody(nestedBuilder, loc, outputIvs); // Insert the slice into the destination. - Value result = nestedBuilder.create( + insertResult = nestedBuilder.create( loc, tile, iterArgs[0], insertParams); - return {result}; + return {insertResult}; }); - rewriter.replaceOp(op, nest.getResults()[0]); + + if (!nest.loops.empty()) + rewriter.replaceOp(op, nest.getResults()); + else + rewriter.replaceOp(op, insertResult); + return success(); } }; -- 2.7.4