} // end anonymous namespace
LogicalResult linalg::convertToLLVM(mlir::ModuleOp module) {
- for (auto func : module.getOps<FuncOp>()) {
- if (failed(mlir::lowerAffineConstructs(func)))
- return failure();
- if (failed(mlir::lowerControlFlow(func)))
- return failure();
- }
-
// Convert Linalg ops to the LLVM IR dialect using the converter defined
// above.
LinalgTypeConverter converter(module.getContext());
OwningRewritePatternList patterns;
+ populateAffineToStdConversionPatterns(patterns, module.getContext());
+ populateLoopToStdConversionPatterns(patterns, module.getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
ConversionTarget target(*module.getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
- if (failed(applyConversionPatterns(module, target, converter,
- std::move(patterns))))
- return failure();
-
- return success();
+ return applyConversionPatterns(module, target, converter,
+ std::move(patterns));
}
namespace {
}
LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) {
- // Remove affine constructs.
- for (auto func : module.getOps<FuncOp>()) {
- if (failed(lowerAffineConstructs(func)))
- return failure();
- if (failed(mlir::lowerControlFlow(func)))
- return failure();
- }
-
// Convert Linalg ops to the LLVM IR dialect using the converter defined
// above.
LinalgTypeConverter converter(module.getContext());
OwningRewritePatternList patterns;
+ populateAffineToStdConversionPatterns(patterns, module.getContext());
+ populateLoopToStdConversionPatterns(patterns, module.getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateLinalg1ToLLVMConversionPatterns(patterns, module.getContext());
populateLinalg3ToLLVMConversionPatterns(patterns, module.getContext());
#ifndef MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
#define MLIR_CONVERSION_CONTROLFLOWTOCFG_CONVERTCONTROLFLOWTOCFG_H_
+#include <memory>
+#include <vector>
+
namespace mlir {
class FuncOp;
+class FunctionPassBase;
struct LogicalResult;
-class ModulePassBase;
+class MLIRContext;
+class RewritePattern;
+
+// Owning list of rewriting patterns.
+using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
-/// Lowers loop.for, loop.if and loop.terminator ops to CFG.
-LogicalResult lowerControlFlow(FuncOp func);
+/// Collect a set of patterns to lower from loop.for, loop.if, and
+/// loop.terminator to CFG operations within the Standard dialect, in particular
+/// convert structured control flow into CFG branch-based control flow.
+void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *ctx);
/// Creates a pass to convert loop.for, loop.if and loop.terminator ops to CFG.
-ModulePassBase *createConvertToCFGPass();
+FunctionPassBase *createConvertToCFGPass();
} // namespace mlir
#define MLIR_TRANSFORMS_LOWERAFFINE_H
#include "mlir/Support/LLVM.h"
+#include <vector>
namespace mlir {
class AffineExpr;
class AffineForOp;
-class FuncOp;
class Location;
struct LogicalResult;
+class MLIRContext;
class OpBuilder;
+class RewritePattern;
class Value;
+// Owning list of rewriting patterns.
+using OwningRewritePatternList = std::vector<std::unique_ptr<RewritePattern>>;
+
/// Emit code that computes the given affine expression using standard
/// arithmetic operations applied to the provided dimension and symbol values.
Value *expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
ArrayRef<Value *> dimValues,
ArrayRef<Value *> symbolValues);
-/// Convert from the Affine dialect to the Standard dialect, in particular
-/// convert structured affine control flow into CFG branch-based control flow.
-LogicalResult lowerAffineConstructs(FuncOp function);
+/// Collect a set of patterns to convert from the Affine dialect to the Standard
+/// dialect, in particular convert structured affine control flow into CFG
+/// branch-based control flow.
+void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *ctx);
/// Emit code that computes the lower bound of the given affine loop using
/// standard arithmetic operations.
namespace {
-struct ControlFlowToCFGPass : public ModulePass<ControlFlowToCFGPass> {
- void runOnModule() override;
+struct ControlFlowToCFGPass : public FunctionPass<ControlFlowToCFGPass> {
+ void runOnFunction() override;
};
// Create a CFG subgraph for the loop around its body blocks (if the body
return matchSuccess();
}
-LogicalResult mlir::lowerControlFlow(FuncOp func) {
- OwningRewritePatternList patterns;
+void mlir::populateLoopToStdConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx) {
RewriteListBuilder<ForLowering, IfLowering, TerminatorLowering>::build(
- patterns, func.getContext());
- ConversionTarget target(*func.getContext());
- target.addLegalDialect<StandardOpsDialect>();
- return applyConversionPatterns(func, target, std::move(patterns));
+ patterns, ctx);
}
-void ControlFlowToCFGPass::runOnModule() {
- for (auto func : getModule().getOps<FuncOp>())
- if (failed(mlir::lowerControlFlow(func)))
- return signalPassFailure();
+void ControlFlowToCFGPass::runOnFunction() {
+ OwningRewritePatternList patterns;
+ populateLoopToStdConversionPatterns(patterns, &getContext());
+ ConversionTarget target(getContext());
+ target.addLegalDialect<StandardOpsDialect>();
+ if (failed(
+ applyConversionPatterns(getFunction(), target, std::move(patterns))))
+ signalPassFailure();
}
-ModulePassBase *mlir::createConvertToCFGPass() {
+FunctionPassBase *mlir::createConvertToCFGPass() {
return new ControlFlowToCFGPass();
}
return signalPassFailure();
ModuleOp m = getModule();
- for (auto func : m.getOps<FuncOp>())
- if (failed(mlir::lowerControlFlow(func)))
- signalPassFailure();
-
LLVM::ensureDistinctSuccessors(m);
std::unique_ptr<LLVMTypeConverter> typeConverter =
typeConverterMaker(&getContext());
return signalPassFailure();
OwningRewritePatternList patterns;
+ populateLoopToStdConversionPatterns(patterns, m.getContext());
patternListFiller(*typeConverter, patterns);
ConversionTarget target(getContext());
for (auto f : module.getOps<FuncOp>()) {
lowerLinalgSubViewOps(f);
lowerLinalgForToCFG(f);
- if (failed(lowerAffineConstructs(f)))
- signalPassFailure();
}
// Convert to the LLVM IR dialect using the converter defined above.
OwningRewritePatternList patterns;
LinalgTypeConverter converter(&getContext());
+ populateAffineToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext());
} // end namespace
-LogicalResult mlir::lowerAffineConstructs(FuncOp function) {
- OwningRewritePatternList patterns;
+void mlir::populateAffineToStdConversionPatterns(
+ OwningRewritePatternList &patterns, MLIRContext *ctx) {
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
AffineDmaWaitLowering, AffineLoadLowering,
AffineStoreLowering, AffineForLowering, AffineIfLowering,
- AffineTerminatorLowering>::build(patterns,
- function.getContext());
- ConversionTarget target(*function.getContext());
- target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
- return applyConversionPatterns(function, target, std::move(patterns));
+ AffineTerminatorLowering>::build(patterns, ctx);
}
namespace {
class LowerAffinePass : public FunctionPass<LowerAffinePass> {
- void runOnFunction() override { lowerAffineConstructs(getFunction()); }
+ void runOnFunction() override {
+ OwningRewritePatternList patterns;
+ populateAffineToStdConversionPatterns(patterns, &getContext());
+ ConversionTarget target(getContext());
+ target.addLegalDialect<loop::LoopOpsDialect, StandardOpsDialect>();
+ if (failed(applyConversionPatterns(getFunction(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
};
} // namespace