From d18ffd61d4f2500dc4ae267f4705102abb2cf02f Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 30 Aug 2021 01:12:14 +0000 Subject: [PATCH] [mlir][SCF] Canonicalize dim(x) where x is an iter_arg * Add `DimOfIterArgFolder`. * Move existing cross-dialect canonicalization patterns to `LoopCanonicalization.cpp`. * Rename `SCFAffineOpCanonicalization` pass to `SCFForLoopCanonicalization`. * Expand documentaton of scf.for: The type of loop-carried variables may not change with iterations. (Not even the dynamic type.) Differential Revision: https://reviews.llvm.org/D108806 --- mlir/include/mlir/Dialect/SCF/Passes.h | 2 +- mlir/include/mlir/Dialect/SCF/Passes.td | 16 +-- mlir/include/mlir/Dialect/SCF/SCFOps.td | 22 ++-- mlir/include/mlir/Dialect/SCF/Transforms.h | 2 +- .../Dialect/Linalg/Transforms/CodegenStrategy.cpp | 2 +- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 2 +- mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt | 2 + .../SCF/Transforms/LoopCanonicalization.cpp | 127 +++++++++++++++++++++ .../Dialect/SCF/Transforms/LoopSpecialization.cpp | 59 ---------- mlir/lib/Dialect/SCF/Transforms/PassDetail.h | 4 + ...fine-op.mlir => for-loop-canonicalization.mlir} | 20 +++- .../lib/Dialect/Linalg/TestConvVectorization.cpp | 2 +- .../Dialect/Linalg/TestLinalgFusionTransforms.cpp | 2 +- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 1 + 14 files changed, 180 insertions(+), 83 deletions(-) create mode 100644 mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp rename mlir/test/Dialect/SCF/{canonicalize-affine-op.mlir => for-loop-canonicalization.mlir} (91%) diff --git a/mlir/include/mlir/Dialect/SCF/Passes.h b/mlir/include/mlir/Dialect/SCF/Passes.h index 18a2798..34fe3c0 100644 --- a/mlir/include/mlir/Dialect/SCF/Passes.h +++ b/mlir/include/mlir/Dialect/SCF/Passes.h @@ -30,7 +30,7 @@ std::unique_ptr createForLoopPeelingPass(); /// Creates a pass that canonicalizes affine.min and affine.max operations /// inside of scf.for loops with known lower and upper bounds. -std::unique_ptr createSCFAffineOpCanonicalizationPass(); +std::unique_ptr createSCFForLoopCanonicalizationPass(); /// Creates a loop fusion pass which fuses parallel loops. std::unique_ptr createParallelLoopFusionPass(); diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td index 2f2e7de..ef4df99 100644 --- a/mlir/include/mlir/Dialect/SCF/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Passes.td @@ -17,14 +17,14 @@ def SCFBufferize : FunctionPass<"scf-bufferize"> { let dependentDialects = ["memref::MemRefDialect"]; } -// Note: Making this a canonicalization pattern would require a dependency -// of the SCF dialect on the Affine dialect or vice versa. -def SCFAffineOpCanonicalization - : FunctionPass<"canonicalize-scf-affine-op"> { - let summary = "Canonicalize affine.min and affine.max ops in the context of " - "SCF loops with known bounds"; - let constructor = "mlir::createSCFAffineOpCanonicalizationPass()"; - let dependentDialects = ["AffineDialect"]; +// Note: Making these canonicalization patterns would require a dependency +// of the SCF dialect on the Affine/Tensor/MemRef dialects or vice versa. +def SCFForLoopCanonicalization + : FunctionPass<"for-loop-canonicalization"> { + let summary = "Canonicalize operations within scf.for loop bodies"; + let constructor = "mlir::createSCFForLoopCanonicalizationPass()"; + let dependentDialects = ["AffineDialect", "tensor::TensorDialect", + "memref::MemRefDialect"]; } def SCFForLoopPeeling diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td index 9e23c30..d04266a 100644 --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -122,7 +122,7 @@ def ForOp : SCF_Op<"for", let summary = "for operation"; let description = [{ The "scf.for" operation represents a loop taking 3 SSA value as operands - that represent the lower bound, upper bound and step respectively. The + that represent the lower bound, upper bound and step respectively. The operation defines an SSA value for its induction variable. It has one region capturing the loop body. The induction variable is represented as an argument of this region. This SSA value always has type index, which is the @@ -146,14 +146,18 @@ def ForOp : SCF_Op<"for", values after loop termination. The initial values of the variables are passed as additional SSA operands to the "scf.for" following the 3 loop control SSA values mentioned above (lower bound, upper bound and step). The - operation region has equivalent arguments for each variable representing - the value of the variable at the current iteration. - - The region must terminate with a "scf.yield" that passes all the current - iteration variables to the next iteration, or to the "scf.for" result, if - at the last iteration. Note, that when the loop-carried variables are - present, calling ForOp::build will not insert the terminator implicitly. - The caller must insert "scf.yield" in that case. + operation region has an argument for the induction variable, followed by + one argument for each loop-carried variable, representing he value of the + variable at the current iteration. + + The region must terminate with a "scf.yield" that passes the current + values of loop-carried variables to the next iteration, or to the "scf.for" + result, if at the last iteration. The type (static or dynamic) of a + loop-carried variable may not change with iterations. E.g., it is illegal + to pass a tensor of larger size to the next iteration; even if the tensor's + dimensions are dynamic (i.e., same static type). Note, that when the + loop-carried variables are present, calling ForOp::build will not insert the + terminator implicitly. The caller must insert "scf.yield" in that case. "scf.for" results hold the final values after the last iteration. For example, to sum-reduce a memref: diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h index a3fd6b7..a9b8937 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms.h @@ -179,7 +179,7 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, /// Populate patterns for canonicalizing operations inside SCF loop bodies. /// At the moment, only affine.min/max computations with iteration variables, /// loop bounds and loop steps are canonicalized. -void populateSCFLoopBodyCanonicalizationPatterns(RewritePatternSet &patterns); +void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns); } // namespace scf } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp index 42f5a53..8192551 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp @@ -48,7 +48,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const { RewritePatternSet stage2Patterns = linalg::getLinalgTilingCanonicalizationPatterns(context); - scf::populateSCFLoopBodyCanonicalizationPatterns(stage2Patterns); + scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns); auto stage3Transforms = [&](Operation *op) { // Some of these may be too aggressive as a stage 3 that is applied on each diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 5b6eb6a..db8b89d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -537,7 +537,7 @@ applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp, MLIRContext *ctx = funcOp.getContext(); RewritePatternSet patterns(ctx); insertTilingPatterns(patterns, options); - scf::populateSCFLoopBodyCanonicalizationPatterns(patterns); + scf::populateSCFForLoopCanonicalizationPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); (void)applyPatternsAndFoldGreedily( funcOp, getLinalgTilingCanonicalizationPatterns(ctx)); diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt index 81d3058..759f02b 100644 --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRSCFTransforms Bufferize.cpp + LoopCanonicalization.cpp LoopPipelining.cpp LoopRangeFolding.cpp LoopSpecialization.cpp @@ -22,6 +23,7 @@ add_mlir_dialect_library(MLIRSCFTransforms MLIRSCF MLIRStandard MLIRSupport + MLIRTensor MLIRTransforms MLIRTransformUtils ) diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp new file mode 100644 index 0000000..3e6fe03 --- /dev/null +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -0,0 +1,127 @@ +//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains cross-dialect canonicalization patterns that cannot be +// actual canonicalization patterns due to undesired additional dependencies. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/Passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::scf; + +namespace { +/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.: +/// +/// ``` +/// %0 = ... : tensor +/// scf.for ... iter_args(%arg0 = %0) -> (tensor) { +/// %1 = tensor.dim %arg0, %c0 : tensor +/// ... +/// } +/// ``` +/// +/// is folded to: +/// +/// ``` +/// %0 = ... : tensor +/// scf.for ... iter_args(%arg0 = %0) -> (tensor) { +/// %1 = tensor.dim %0, %c0 : tensor +/// ... +/// } +/// ``` +template +struct DimOfIterArgFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy dimOp, + PatternRewriter &rewriter) const override { + auto blockArg = dimOp.source().template dyn_cast(); + if (!blockArg) + return failure(); + auto forOp = dyn_cast(blockArg.getParentBlock()->getParentOp()); + if (!forOp) + return failure(); + + Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get(); + rewriter.updateRootInPlace( + dimOp, [&]() { dimOp.sourceMutable().assign(initArg); }); + + return success(); + }; +}; + +/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for +/// and scf.parallel loops with a known range. +template +struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) { + if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { + lb = forOp.lowerBound(); + ub = forOp.upperBound(); + step = forOp.step(); + return success(); + } + if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { + for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { + if (parOp.getInductionVars()[idx] == iv) { + lb = parOp.lowerBound()[idx]; + ub = parOp.upperBound()[idx]; + step = parOp.step()[idx]; + return success(); + } + } + return failure(); + } + return failure(); + }; + + return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(), + op.operands(), IsMin, loopMatcher); + } +}; + +struct SCFForLoopCanonicalization + : public SCFForLoopCanonicalizationBase { + void runOnFunction() override { + FuncOp funcOp = getFunction(); + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(ctx); + scf::populateSCFForLoopCanonicalizationPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +void mlir::scf::populateSCFForLoopCanonicalizationPatterns( + RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + patterns + .insert, + AffineOpSCFCanonicalizationPattern, + DimOfIterArgFolder, + DimOfIterArgFolder>(ctx); +} + +std::unique_ptr mlir::createSCFForLoopCanonicalizationPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index dd8ae01..b5e9543 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -516,40 +516,6 @@ struct ForLoopPeelingPattern : public OpRewritePattern { /// the direct parent. bool skipPartial; }; - -/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for -/// and scf.parallel loops with a known range. -template -struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) { - if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) { - lb = forOp.lowerBound(); - ub = forOp.upperBound(); - step = forOp.step(); - return success(); - } - if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) { - for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) { - if (parOp.getInductionVars()[idx] == iv) { - lb = parOp.lowerBound()[idx]; - ub = parOp.upperBound()[idx]; - step = parOp.step()[idx]; - return success(); - } - } - return failure(); - } - return failure(); - }; - - return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(), - op.operands(), IsMin, loopMatcher); - } -}; } // namespace namespace { @@ -583,24 +549,8 @@ struct ForLoopPeeling : public SCFForLoopPeelingBase { }); } }; - -struct SCFAffineOpCanonicalization - : public SCFAffineOpCanonicalizationBase { - void runOnFunction() override { - FuncOp funcOp = getFunction(); - MLIRContext *ctx = funcOp.getContext(); - RewritePatternSet patterns(ctx); - scf::populateSCFLoopBodyCanonicalizationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) - signalPassFailure(); - } -}; } // namespace -std::unique_ptr mlir::createSCFAffineOpCanonicalizationPass() { - return std::make_unique(); -} - std::unique_ptr mlir::createParallelLoopSpecializationPass() { return std::make_unique(); } @@ -612,12 +562,3 @@ std::unique_ptr mlir::createForLoopSpecializationPass() { std::unique_ptr mlir::createForLoopPeelingPass() { return std::make_unique(); } - -void mlir::scf::populateSCFLoopBodyCanonicalizationPatterns( - RewritePatternSet &patterns) { - MLIRContext *ctx = patterns.getContext(); - patterns - .insert, - AffineOpSCFCanonicalizationPattern>( - ctx); -} diff --git a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h index 2e27ed2..8485511 100644 --- a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h +++ b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h @@ -22,6 +22,10 @@ namespace memref { class MemRefDialect; } // end namespace memref +namespace tensor { +class TensorDialect; +} // end namespace tensor + #define GEN_PASS_CLASSES #include "mlir/Dialect/SCF/Passes.h.inc" diff --git a/mlir/test/Dialect/SCF/canonicalize-affine-op.mlir b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir similarity index 91% rename from mlir/test/Dialect/SCF/canonicalize-affine-op.mlir rename to mlir/test/Dialect/SCF/for-loop-canonicalization.mlir index 05b41e7..9f80b31 100644 --- a/mlir/test/Dialect/SCF/canonicalize-affine-op.mlir +++ b/mlir/test/Dialect/SCF/for-loop-canonicalization.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -canonicalize-scf-affine-op -split-input-file | FileCheck %s +// RUN: mlir-opt %s -for-loop-canonicalization -split-input-file | FileCheck %s // CHECK-LABEL: func @scf_for_canonicalize_min // CHECK: %[[C2:.*]] = constant 2 : i64 @@ -224,3 +224,21 @@ func @scf_parallel_canonicalize_min_2(%A : memref) { } return } + +// ----- + +// CHECK-LABEL: func @tensor_dim_of_iter_arg( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK: scf.for +// CHECK: tensor.dim %[[t]] +func @tensor_dim_of_iter_arg(%t : tensor) -> index { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c10 = constant 10 : index + %0, %1 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg0 = %t, %arg1 = %c0) + -> (tensor, index) { + %dim = tensor.dim %arg0, %c0 : tensor + scf.yield %arg0, %dim : tensor, index + } + return %1 : index +} diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp index eafb575..aed6394 100644 --- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp @@ -71,7 +71,7 @@ void TestConvVectorization::runOnOperation() { RewritePatternSet stage2Patterns = linalg::getLinalgTilingCanonicalizationPatterns(context); - scf::populateSCFLoopBodyCanonicalizationPatterns(stage2Patterns); + scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns); auto stage3Transforms = [](Operation *op) { PassManager pm(op->getContext()); diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp index 2192986..7b1ef87 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -237,7 +237,7 @@ struct TestLinalgGreedyFusion RewritePatternSet patterns = linalg::getLinalgTilingCanonicalizationPatterns(context); patterns.add(context); - scf::populateSCFLoopBodyCanonicalizationPatterns(patterns); + scf::populateSCFForLoopCanonicalizationPatterns(patterns); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); do { (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index d32dee1..33fc904 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1491,6 +1491,7 @@ cc_library( ":SCFPassIncGen", ":StandardOps", ":Support", + ":TensorDialect", ":Transforms", "//llvm:Support", ], -- 2.7.4