void populateControlFlowToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
+/// Populate the cf.assert to LLVM conversion pattern. If `abortOnFailure` is
+/// set to false, the program execution continues when a condition is
+/// unsatisfied.
+void populateAssertToLLVMConversionPattern(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns,
+ bool abortOnFailure = true);
+
/// Creates a pass to convert the ControlFlow dialect into the LLVMIR dialect.
std::unique_ptr<Pass> createConvertControlFlowToLLVMPass();
} // namespace cf
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
#define PASS_NAME "convert-cf-to-llvm"
+static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
+ std::string prefix = "assert_msg_";
+ int counter = 0;
+ while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
+ ++counter;
+ return prefix + std::to_string(counter);
+}
+
+/// Generate IR that prints the given string to stderr.
+static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+ StringRef msg) {
+ auto ip = builder.saveInsertionPoint();
+ builder.setInsertionPointToStart(moduleOp.getBody());
+ MLIRContext *ctx = builder.getContext();
+
+ // Create a zero-terminated byte representation and allocate global symbol.
+ SmallVector<uint8_t> elementVals;
+ elementVals.append(msg.begin(), msg.end());
+ elementVals.push_back(0);
+ auto dataAttrType = RankedTensorType::get(
+ {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+ auto dataAttr =
+ DenseElementsAttr::get(dataAttrType, llvm::makeArrayRef(elementVals));
+ auto arrayTy =
+ LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+ std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
+ auto globalOp = builder.create<LLVM::GlobalOp>(
+ loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
+ dataAttr);
+
+ // Emit call to `printStr` in runtime library.
+ builder.restoreInsertionPoint(ip);
+ auto msgAddr = builder.create<LLVM::AddressOfOp>(
+ loc, LLVM::LLVMPointerType::get(arrayTy), globalOp.getName());
+ SmallVector<LLVM::GEPArg> indices(1, 0);
+ Value gep = builder.create<LLVM::GEPOp>(
+ loc, LLVM::LLVMPointerType::get(builder.getI8Type()), msgAddr, indices);
+ Operation *printer = LLVM::lookupOrCreatePrintStrFn(moduleOp);
+ builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+ gep);
+}
+
namespace {
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
/// ignored by the default lowering but should be propagated by any custom
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
- using ConvertOpToLLVMPattern<cf::AssertOp>::ConvertOpToLLVMPattern;
+ explicit AssertOpLowering(LLVMTypeConverter &typeConverter,
+ bool abortOnFailedAssert = true)
+ : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1),
+ abortOnFailedAssert(abortOnFailedAssert) {}
LogicalResult
matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
-
- // Insert the `abort` declaration if necessary.
auto module = op->getParentOfType<ModuleOp>();
- auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
- if (!abortFunc) {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(module.getBody());
- auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
- abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
- "abort", abortFuncTy);
- }
// Split block at `assert` operation.
Block *opBlock = rewriter.getInsertionBlock();
auto opPosition = rewriter.getInsertionPoint();
Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);
- // Generate IR to call `abort`.
+ // Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
- rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
- rewriter.create<LLVM::UnreachableOp>(loc);
+ createPrintMsg(rewriter, loc, module, op.getMsg());
+ if (abortOnFailedAssert) {
+ // Insert the `abort` declaration if necessary.
+ auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
+ if (!abortFunc) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointToStart(module.getBody());
+ auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {});
+ abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
+ "abort", abortFuncTy);
+ }
+ rewriter.create<LLVM::CallOp>(loc, abortFunc, std::nullopt);
+ rewriter.create<LLVM::UnreachableOp>(loc);
+ } else {
+ rewriter.create<LLVM::BrOp>(loc, ValueRange(), continuationBlock);
+ }
// Generate assertion test.
rewriter.setInsertionPointToEnd(opBlock);
return success();
}
+
+private:
+ /// If set to `false`, messages are printed but program execution continues.
+ /// This is useful for testing asserts.
+ bool abortOnFailedAssert = true;
};
/// The cf->LLVM lowerings for branching ops require that the blocks they jump
// clang-format on
}
+void mlir::cf::populateAssertToLLVMConversionPattern(
+ LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ bool abortOnFailure) {
+ patterns.add<AssertOpLowering>(converter, abortOnFailure);
+}
+
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
static constexpr llvm::StringRef kPrintU64 = "printU64";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
+static constexpr llvm::StringRef kPrintStr = "puts";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
static constexpr llvm::StringRef kPrintComma = "printComma";
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStrFn(ModuleOp moduleOp) {
+ return lookupOrCreateFn(
+ moduleOp, kPrintStr,
+ LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
+ LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
--- /dev/null
+// RUN: mlir-opt %s -test-cf-assert \
+// RUN: -convert-func-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void | \
+// RUN: FileCheck %s
+
+func.func @main() {
+ %a = arith.constant 0 : i1
+ %b = arith.constant 1 : i1
+ // CHECK: assertion foo
+ cf.assert %a, "assertion foo"
+ // CHECK-NOT: assertion bar
+ cf.assert %b, "assertion bar"
+ return
+}
add_subdirectory(Affine)
add_subdirectory(Arith)
add_subdirectory(Bufferization)
+add_subdirectory(ControlFlow)
add_subdirectory(DLTI)
add_subdirectory(Func)
add_subdirectory(GPU)
--- /dev/null
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRControlFlowTestPasses
+ TestAssert.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRControlFlowToLLVM
+ MLIRFuncDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRPass
+ MLIRTransforms
+)
--- /dev/null
+//===- TestAssert.cpp - Test cf.assert Lowering ----------------*- c++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for integration testing of wide integer
+// emulation patterns. Applies conversion patterns only to functions whose
+// names start with a specified prefix.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+struct TestAssertPass
+ : public PassWrapper<TestAssertPass, OperationPass<ModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAssertPass)
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<cf::ControlFlowDialect, LLVM::LLVMDialect>();
+ }
+ StringRef getArgument() const final { return "test-cf-assert"; }
+ StringRef getDescription() const final {
+ return "Function pass to test cf.assert lowering to LLVM without abort";
+ }
+
+ void runOnOperation() override {
+ LLVMConversionTarget target(getContext());
+ RewritePatternSet patterns(&getContext());
+
+ LLVMTypeConverter converter(&getContext());
+ mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns,
+ /*abortOnFailure=*/false);
+
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
+
+namespace mlir::test {
+void registerTestCfAssertPass() { PassRegistration<TestAssertPass>(); }
+} // namespace mlir::test
MLIRAffineTransformsTestPasses
MLIRArithTestPasses
MLIRBufferizationTestPasses
+ MLIRControlFlowTestPasses
MLIRDLTITestPasses
MLIRFuncTestPasses
MLIRGPUTestPasses
void registerTestAliasAnalysisPass();
void registerTestBuiltinAttributeInterfaces();
void registerTestCallGraphPass();
+void registerTestCfAssertPass();
void registerTestConstantFold();
void registerTestControlFlowSink();
void registerTestGpuSerializeToCubinPass();
mlir::test::registerTestArithEmulateWideIntPass();
mlir::test::registerTestBuiltinAttributeInterfaces();
mlir::test::registerTestCallGraphPass();
+ mlir::test::registerTestCfAssertPass();
mlir::test::registerTestConstantFold();
mlir::test::registerTestControlFlowSink();
mlir::test::registerTestDiagnosticsPass();
"//mlir/test:TestAnalysis",
"//mlir/test:TestArith",
"//mlir/test:TestBufferization",
+ "//mlir/test:TestControlFlow",
"//mlir/test:TestDLTI",
"//mlir/test:TestDialect",
"//mlir/test:TestFunc",
],
)
+cc_library(
+ name = "TestControlFlow",
+ srcs = glob(["lib/Dialect/ControlFlow/*.cpp"]),
+ includes = ["lib/Dialect/Test"],
+ deps = [
+ "//mlir:ControlFlowDialect",
+ "//mlir:ControlFlowToLLVM",
+ "//mlir:FuncDialect",
+ "//mlir:LLVMCommonConversion",
+ "//mlir:LLVMDialect",
+ "//mlir:Pass",
+ "//mlir:Transforms",
+ ],
+)
+
cc_library(
name = "TestShapeDialect",
srcs = [