From: Hanhan Wang Date: Wed, 23 Nov 2022 18:46:46 +0000 (-0800) Subject: [mlir][linalg] Add a new pattern to handle folding unit reduction dims. X-Git-Tag: upstream/17.0.6~26105 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9b16d9d27138433ab180241352b3b93200ac7d6c;p=platform%2Fupstream%2Fllvm.git [mlir][linalg] Add a new pattern to handle folding unit reduction dims. The output operands will be added to input operands if the generic op (on tensors) becomes an elementwise operation. The outputs of the generic op is still the same. They will be cleaned up by ReplaceWithEmptyTensorIfUnused pattern. This is https://reviews.llvm.org/D138251, plus a cmake dep fix. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D138843 --- diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 8809b25..ca13f44 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -52,6 +52,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms MLIRInferTypeOpInterface MLIRIR MLIRMemRefDialect + MLIRMemRefTransforms MLIRLinalgDialect MLIRLinalgAnalysis MLIRLinalgUtils diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index d6ddfff..b0e45d1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -19,12 +19,15 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -225,6 +228,125 @@ struct FoldUnitDimLoops : public OpRewritePattern { } }; +/// Pattern to add init operands to ins when all the loops are parallel and +/// blockArgument corresponding to init is used in the region. This is a fix-up +/// when unit reduction dimensions are all folded away. In this context, it +/// becomes a elementwise generic op. E.g., it converts +/// +/// %0 = tensor.empty() : tensor<1x1xf32> +/// %1 = linalg.fill +/// ins(%cst : f32) +/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32> +/// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>, +/// affine_map<(d0) -> (0, d0)>], +/// iterator_types = ["parallel"]} +/// ins(%arg0 : tensor<1x?x1x1xf32>) +/// outs(%1 : tensor<1x1xf32>) { +/// ^bb0(%in: f32, %out: f32): +/// %3 = arith.addf %in, %out : f32 +/// linalg.yield %3 : f32 +/// } -> tensor<1x1xf32> +/// +/// into +/// +/// %0 = tensor.empty() : tensor<1x1xf32> +/// %1 = linalg.fill +/// ins(%cst : f32) +/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32> +/// %2 = tensor.empty() : tensor<1x1xf32> +/// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>, +/// affine_map<(d0) -> (0, d0)>, +/// affine_map<(d0) -> (0, d0)>], +/// iterator_types = ["parallel"]} +/// ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>) +/// outs(%2 : tensor<1x1xf32>) { +/// ^bb0(%in: f32, %in_0: f32, %out: f32): +/// %4 = arith.addf %in, %in_0 : f32 +/// linalg.yield %4 : f32 +/// } -> tensor<1x1xf32> +struct AddInitOperandsToInput : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + if (!genericOp.hasTensorSemantics()) + return failure(); + if (genericOp.getNumParallelLoops() != genericOp.getNumLoops()) + return failure(); + + auto outputOperands = genericOp.getDpsInitOperands(); + SetVector candidates; + for (OpOperand *op : outputOperands) { + if (genericOp.getMatchingBlockArgument(op).use_empty()) + continue; + candidates.insert(op); + } + + if (candidates.empty()) + return failure(); + + // Compute the modified indexing maps. + int64_t origNumInput = genericOp.getNumDpsInputs(); + SmallVector newInputOperands = genericOp.getDpsInputOperands(); + SmallVector indexingMaps = genericOp.getIndexingMapsArray(); + SmallVector newIndexingMaps; + newIndexingMaps.append(indexingMaps.begin(), + std::next(indexingMaps.begin(), origNumInput)); + for (OpOperand *op : candidates) { + newInputOperands.push_back(op->get()); + newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op)); + } + newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput), + indexingMaps.end()); + + Location loc = genericOp.getLoc(); + SmallVector newOutputOperands = outputOperands; + for (OpOperand *op : candidates) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfterValue(op->get()); + auto elemType = op->get().getType().cast().getElementType(); + auto empty = rewriter.create( + loc, tensor::createDimValues(rewriter, loc, op->get()), elemType); + + auto [start, end] = genericOp.getDpsInitsPositionRange(); + newOutputOperands[op->getOperandNumber() - start] = empty.getResult(); + } + + auto newOp = rewriter.create( + loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands, + newIndexingMaps, genericOp.getIteratorTypesArray(), + /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); + + Region ®ion = newOp.getRegion(); + Block *block = new Block(); + region.push_back(block); + BlockAndValueMapping mapper; + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(block); + for (auto bbarg : genericOp.getRegionInputArgs()) + mapper.map(bbarg, block->addArgument(bbarg.getType(), loc)); + + for (OpOperand *op : candidates) { + BlockArgument bbarg = genericOp.getMatchingBlockArgument(op); + mapper.map(bbarg, block->addArgument(bbarg.getType(), loc)); + } + + for (OpOperand *op : outputOperands) { + BlockArgument bbarg = genericOp.getMatchingBlockArgument(op); + if (candidates.count(op)) + block->addArgument(bbarg.getType(), loc); + else + mapper.map(bbarg, block->addArgument(bbarg.getType(), loc)); + } + + for (auto &op : genericOp.getBody()->getOperations()) { + rewriter.clone(op, mapper); + } + rewriter.replaceOp(genericOp, newOp.getResults()); + + return success(); + } +}; + struct UnitExtentReplacementInfo { Type type; AffineMap indexMap; @@ -536,7 +658,8 @@ struct RankReducedInsertSliceOp : public OpRewritePattern { void mlir::linalg::populateFoldUnitExtentDimsPatterns( RewritePatternSet &patterns) { auto *context = patterns.getContext(); - patterns.add, RankReducedInsertSliceOp>( context); @@ -544,6 +667,8 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns( tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); tensor::EmptyOp::getCanonicalizationPatterns(patterns, context); tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); + memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns); + memref::populateResolveShapedTypeResultDimsPatterns(patterns); } namespace { @@ -555,7 +680,7 @@ struct LinalgFoldUnitExtentDimsPass MLIRContext *context = op->getContext(); RewritePatternSet patterns(context); if (foldOneTripLoopsOnly) - patterns.add(context); + patterns.add(context); else populateFoldUnitExtentDimsPatterns(patterns); (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir index 4ff1f19..ffa9563 100644 --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -384,11 +384,12 @@ func.func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1 // CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3] // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1xf32> // CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]] +// CHECK: %[[INIT2:.+]] = tensor.empty() : tensor<1xf32> // CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]] +// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]] // CHECK-SAME: iterator_types = ["parallel"] -// CHECK-SAME: ins(%[[RESHAPE]] : tensor) -// CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>) +// CHECK-SAME: ins(%[[RESHAPE]], %[[FILL]] : tensor, tensor<1xf32>) +// CHECK-SAME: outs(%[[INIT2]] : tensor<1xf32>) // CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] // CHECK: return %[[RESULT_RESHAPE]] diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 9f85689..87412e6 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8331,6 +8331,7 @@ cc_library( ":LinalgUtils", ":MathDialect", ":MemRefDialect", + ":MemRefTransforms", ":Pass", ":SCFDialect", ":SCFTransforms",