Remove lowerAffineConstructs and lowerControlFlow in favor of providing patterns.
authorRiver Riddle <riverriddle@google.com>
Mon, 15 Jul 2019 19:52:44 +0000 (12:52 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 16 Jul 2019 20:44:45 +0000 (13:44 -0700)
These methods don't compose well with the rest of conversion framework, and create artificial breaks in conversion. Replace these methods with two(populateAffineToStdConversionPatterns and populateLoopToStdConversionPatterns respectively) that populate a list of patterns to perform the same behavior.

PiperOrigin-RevId: 258219277

mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/include/mlir/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.h
mlir/include/mlir/Transforms/LowerAffine.h
mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Transforms/LowerAffine.cpp

index 9adbcc5..392d9be 100644 (file)
@@ -410,27 +410,19 @@ struct LinalgTypeConverter : public LLVMTypeConverter {
 } // end anonymous namespace
 
 LogicalResult linalg::convertToLLVM(mlir::ModuleOp module) {
-  for (auto func : module.getOps<FuncOp>()) {
-    if (failed(mlir::lowerAffineConstructs(func)))
-      return failure();
-    if (failed(mlir::lowerControlFlow(func)))
-      return failure();
-  }
-
   // Convert Linalg ops to the LLVM IR dialect using the converter defined
   // above.
   LinalgTypeConverter converter(module.getContext());
   OwningRewritePatternList patterns;
+  populateAffineToStdConversionPatterns(patterns, module.getContext());
+  populateLoopToStdConversionPatterns(patterns, module.getContext());
   populateStdToLLVMConversionPatterns(converter, patterns);
   populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
 
   ConversionTarget target(*module.getContext());
   target.addLegalDialect<LLVM::LLVMDialect>();
-  if (failed(applyConversionPatterns(module, target, converter,
-                                     std::move(patterns))))
-    return failure();
-
-  return success();
+  return applyConversionPatterns(module, target, converter,
+                                 std::move(patterns));
 }
 
 namespace {
index adf26c6..54d1f55 100644 (file)
@@ -148,18 +148,12 @@ static void populateLinalg3ToLLVMConversionPatterns(
 }
 
 LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
-  // Remove affine constructs.
-  for (auto func : module.getOps<FuncOp>()) {
-    if (failed(lowerAffineConstructs(func)))
-      return failure();
-    if (failed(mlir::lowerControlFlow(func)))
-      return failure();
-  }
-
   // Convert Linalg ops to the LLVM IR dialect using the converter defined
   // above.
   LinalgTypeConverter converter(module.getContext());
   OwningRewritePatternList patterns;
+  populateAffineToStdConversionPatterns(patterns, module.getContext());
+  populateLoopToStdConversionPatterns(patterns, module.getContext());
   populateStdToLLVMConversionPatterns(converter, patterns);
   populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
   populateLinalg3ToLLVMConversionPatterns(patterns, module.getContext());
index dae459b..e8ab273 100644 (file)
 #ifndef MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
 #define MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
 
+#include <memory>
+#include <vector>
+
 namespace mlir {
 class FuncOp;
+class FunctionPassBase;
 struct LogicalResult;
-class ModulePassBase;
+class MLIRContext;
+class RewritePattern;
+
+// Owning list of rewriting patterns.
+using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
 
-/// Lowers loop.for, loop.if and loop.terminator ops to CFG.
-LogicalResult lowerControlFlow(FuncOp func);
+/// Collect a set of patterns to lower from loop.for, loop.if, and
+/// loop.terminator to CFG operations within the Standard dialect, in particular
+/// convert structured control flow into CFG branch-based control flow.
+void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns,
+                                         MLIRContext *ctx);
 
 /// Creates a pass to convert loop.for, loop.if and loop.terminator ops to CFG.
-ModulePassBase *createConvertToCFGPass();
+FunctionPassBase *createConvertToCFGPass();
 
 } // namespace mlir
 
index 1711bdd..9ad3f66 100644 (file)
 #define MLIR_TRANSFORMS_LOWERAFFINE_H
 
 #include "mlir/Support/LLVM.h"
+#include <vector>
 
 namespace mlir {
 class AffineExpr;
 class AffineForOp;
-class FuncOp;
 class Location;
 struct LogicalResult;
+class MLIRContext;
 class OpBuilder;
+class RewritePattern;
 class Value;
 
+// Owning list of rewriting patterns.
+using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+
 /// Emit code that computes the given affine expression using standard
 /// arithmetic operations applied to the provided dimension and symbol values.
 Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
                         ArrayRef<Value *> dimValues,
                         ArrayRef<Value *> symbolValues);
 
-/// Convert from the Affine dialect to the Standard dialect, in particular
-/// convert structured affine control flow into CFG branch-based control flow.
-LogicalResult lowerAffineConstructs(FuncOp function);
+/// Collect a set of patterns to convert from the Affine dialect to the Standard
+/// dialect, in particular convert structured affine control flow into CFG
+/// branch-based control flow.
+void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
+                                           MLIRContext *ctx);
 
 /// Emit code that computes the lower bound of the given affine loop using
 /// standard arithmetic operations.
index 7fd70c2..f5d0cce 100644 (file)
@@ -43,8 +43,8 @@ using namespace mlir::loop;
 
 namespace {
 
-struct ControlFlowToCFGPass : public ModulePass<ControlFlowToCFGPass> {
-  void runOnModule() override;
+struct ControlFlowToCFGPass : public FunctionPass<ControlFlowToCFGPass> {
+  void runOnFunction() override;
 };
 
 // Create a CFG subgraph for the loop around its body blocks (if the body
@@ -270,22 +270,23 @@ IfLowering::matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
   return matchSuccess();
 }
 
-LogicalResult mlir::lowerControlFlow(FuncOp func) {
-  OwningRewritePatternList patterns;
+void mlir::populateLoopToStdConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
   RewriteListBuilder<ForLowering, IfLowering, TerminatorLowering>::build(
-      patterns, func.getContext());
-  ConversionTarget target(*func.getContext());
-  target.addLegalDialect<StandardOpsDialect>();
-  return applyConversionPatterns(func, target, std::move(patterns));
+      patterns, ctx);
 }
 
-void ControlFlowToCFGPass::runOnModule() {
-  for (auto func : getModule().getOps<FuncOp>())
-    if (failed(mlir::lowerControlFlow(func)))
-      return signalPassFailure();
+void ControlFlowToCFGPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  populateLoopToStdConversionPatterns(patterns, &getContext());
+  ConversionTarget target(getContext());
+  target.addLegalDialect<StandardOpsDialect>();
+  if (failed(
+          applyConversionPatterns(getFunction(), target, std::move(patterns))))
+    signalPassFailure();
 }
 
-ModulePassBase *mlir::createConvertToCFGPass() {
+FunctionPassBase *mlir::createConvertToCFGPass() {
   return new ControlFlowToCFGPass();
 }
 
index 48efa00..2e8e313 100644 (file)
@@ -1052,10 +1052,6 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
       return signalPassFailure();
 
     ModuleOp m = getModule();
-    for (auto func : m.getOps<FuncOp>())
-      if (failed(mlir::lowerControlFlow(func)))
-        signalPassFailure();
-
     LLVM::ensureDistinctSuccessors(m);
     std::unique_ptr<LLVMTypeConverter> typeConverter =
         typeConverterMaker(&getContext());
@@ -1063,6 +1059,7 @@ struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
       return signalPassFailure();
 
     OwningRewritePatternList patterns;
+    populateLoopToStdConversionPatterns(patterns, m.getContext());
     patternListFiller(*typeConverter, patterns);
 
     ConversionTarget target(getContext());
index 0ff7a6c..b2e964e 100644 (file)
@@ -806,13 +806,12 @@ void LowerLinalgToLLVMPass::runOnModule() {
   for (auto f : module.getOps<FuncOp>()) {
     lowerLinalgSubViewOps(f);
     lowerLinalgForToCFG(f);
-    if (failed(lowerAffineConstructs(f)))
-      signalPassFailure();
   }
 
   // Convert to the LLVM IR dialect using the converter defined above.
   OwningRewritePatternList patterns;
   LinalgTypeConverter converter(&getContext());
+  populateAffineToStdConversionPatterns(patterns, &getContext());
   populateStdToLLVMConversionPatterns(converter, patterns);
   populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
 
index 0ff80ca..82b7074 100644 (file)
@@ -506,21 +506,25 @@ public:
 
 } // end namespace
 
-LogicalResult mlir::lowerAffineConstructs(FuncOp function) {
-  OwningRewritePatternList patterns;
+void mlir::populateAffineToStdConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
   RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
                      AffineDmaWaitLowering, AffineLoadLowering,
                      AffineStoreLowering, AffineForLowering, AffineIfLowering,
-                     AffineTerminatorLowering>::build(patterns,
-                                                      function.getContext());
-  ConversionTarget target(*function.getContext());
-  target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
-  return applyConversionPatterns(function, target, std::move(patterns));
+                     AffineTerminatorLowering>::build(patterns, ctx);
 }
 
 namespace {
 class LowerAffinePass : public FunctionPass<LowerAffinePass> {
-  void runOnFunction() override { lowerAffineConstructs(getFunction()); }
+  void runOnFunction() override {
+    OwningRewritePatternList patterns;
+    populateAffineToStdConversionPatterns(patterns, &getContext());
+    ConversionTarget target(getContext());
+    target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
+    if (failed(applyConversionPatterns(getFunction(), target,
+                                       std::move(patterns))))
+      signalPassFailure();
+  }
 };
 } // namespace