Static op verification cannot detect cases where an op is valid at compile time but may be invalid at runtime.
An example of such an op is `memref::ExpandShapeOp`.
Invalid at compile time: `memref.expand_shape %m [[0, 1]] : memref<11xf32> into memref<2x5xf32>`
Valid at compile time (because we do not know any better): `memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>`. This op may or may not be valid at runtime depending on the runtime shape of `%m`.
Invalid runtime ops such as the one above are hard to debug because they can crash the program execution at a seemingly unrelated position or (even worse) compute an invalid result without crashing.
This revision adds a new op interface `RuntimeVerifiableOpInterface` that can be implemented by ops that provide additional runtime verification. Such runtime verification can be computationally expensive, so it is only generated on an opt-in basis by running `-generate-runtime-verification`. A simple runtime verifier for `memref::ExpandShapeOp` is provided as an example.
Differential Revision: https://reviews.llvm.org/D138576
--- /dev/null
+//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
+#define MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerRuntimeVerifiableOpInterfaceExternalModels(
+ DialectRegistry ®istry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
registry);
linalg::registerBufferizableOpInterfaceExternalModels(registry);
linalg::registerTilingInterfaceExternalModels(registry);
+ memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
scf::registerBufferizableOpInterfaceExternalModels(registry);
shape::registerBufferizableOpInterfaceExternalModels(registry);
sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
add_mlir_interface(InferTypeOpInterface)
add_mlir_interface(LoopLikeInterface)
add_mlir_interface(ParallelCombiningOpInterface)
+add_mlir_interface(RuntimeVerifiableOpInterface)
add_mlir_interface(ShapedOpInterfaces)
add_mlir_interface(SideEffectInterfaces)
add_mlir_interface(TilingInterface)
--- /dev/null
+//===- RuntimeVerifiableOpInterface.h - Op Verification ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
+#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
--- /dev/null
+//===- RuntimeVerifiableOpInterface.td - Op Verification ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
+#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
+ let description = [{
+ Implementations of this interface generate IR for runtime op verification.
+
+ Incorrect op usage can often be caught by op verifiers based on static
+ program information. However, in the absence of static program information,
+ it can remain undetected at compile time (e.g., in case of dynamic memref
+ strides instead of static memref strides). Such cases can be checked at
+ runtime. The op-specific checks are generated by this interface.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Generate IR to verify this op at runtime, aborting runtime execution if
+ verification fails.
+ }],
+ /*retTy=*/"void",
+ /*methodName=*/"generateRuntimeVerification",
+ /*args=*/(ins "::mlir::OpBuilder &":$builder,
+ "::mlir::Location":$loc)
+ >,
+ ];
+}
+
+#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
/// Creates a pass to perform common sub expression elimination.
std::unique_ptr<Pass> createCSEPass();
+/// Creates a pass that generates IR to verify ops at runtime.
+std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
+
/// Creates a loop invariant code motion pass that hoists loop invariant
/// instructions out of the loop.
std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
];
}
+def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
+ let summary = "Generate additional runtime op verification checks";
+ let description = [{
+ This pass generates op-specific runtime checks using the
+ `RuntimeVerifiableOpInterface`. It can be run for debugging purposes after
+ passes that are suspected to introduce faulty IR.
+ }];
+ let constructor = "mlir::createGenerateRuntimeVerificationPass()";
+}
+
def Inliner : Pass<"inline"> {
let summary = "Inline function calls";
let constructor = "mlir::createInlinerPass()";
MultiBuffer.cpp
NormalizeMemRefs.cpp
ResolveShapedTypeResultDims.cpp
+ RuntimeOpVerification.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
--- /dev/null
+//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+namespace memref {
+namespace {
+struct ExpandShapeOpInterface
+ : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
+ ExpandShapeOp> {
+ void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+ Location loc) const {
+ auto expandShapeOp = cast<ExpandShapeOp>(op);
+
+ // Verify that the expanded dim sizes are a product of the collapsed dim
+ // size.
+ for (auto it : llvm::enumerate(expandShapeOp.getReassociationIndices())) {
+ Value srcDimSz =
+ builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
+ int64_t groupSz = 1;
+ bool foundDynamicDim = false;
+ for (int64_t resultDim : it.value()) {
+ if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
+ // Keep this assert here in case the op is extended in the future.
+ assert(!foundDynamicDim &&
+ "more than one dynamic dim found in reassoc group");
+ foundDynamicDim = true;
+ continue;
+ }
+ groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
+ }
+ Value staticResultDimSz =
+ builder.create<arith::ConstantIndexOp>(loc, groupSz);
+ // staticResultDimSz must divide srcDimSz evenly.
+ Value mod =
+ builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
+ Value isModZero = builder.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, mod,
+ builder.create<arith::ConstantIndexOp>(loc, 0));
+ builder.create<cf::AssertOp>(
+ loc, isModZero,
+ "static result dims in reassoc group do not divide src dim evenly");
+ }
+ }
+};
+} // namespace
+} // namespace memref
+} // namespace mlir
+
+void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
+ DialectRegistry ®istry) {
+ registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+ ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
+
+ // Load additional dialects of which ops may get created.
+ ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
+ });
+}
InferTypeOpInterface.cpp
LoopLikeInterface.cpp
ParallelCombiningOpInterface.cpp
+ RuntimeVerifiableOpInterface.cpp
ShapedOpInterfaces.cpp
SideEffectInterfaces.cpp
TilingInterface.cpp
add_mlir_interface_library(InferTypeOpInterface)
add_mlir_interface_library(LoopLikeInterface)
add_mlir_interface_library(ParallelCombiningOpInterface)
+add_mlir_interface_library(RuntimeVerifiableOpInterface)
add_mlir_interface_library(ShapedOpInterfaces)
add_mlir_interface_library(SideEffectInterfaces)
add_mlir_interface_library(TilingInterface)
--- /dev/null
+//===- RuntimeVerifiableOpInterface.cpp - Op Verification -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+class Location;
+class OpBuilder;
+} // namespace mlir
+
+/// Include the definitions of the interface.
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc"
Canonicalizer.cpp
ControlFlowSink.cpp
CSE.cpp
+ GenerateRuntimeVerification.cpp
Inliner.cpp
LocationSnapshot.cpp
LoopInvariantCodeMotion.cpp
MLIRCopyOpInterface
MLIRLoopLikeInterface
MLIRPass
+ MLIRRuntimeVerifiableOpInterface
MLIRSideEffectInterfaces
MLIRSupport
MLIRTransformUtils
--- /dev/null
+//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_GENERATERUNTIMEVERIFICATION
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct GenerateRuntimeVerificationPass
+ : public impl::GenerateRuntimeVerificationBase<
+ GenerateRuntimeVerificationPass> {
+ void runOnOperation() override;
+};
+} // namespace
+
+void GenerateRuntimeVerificationPass::runOnOperation() {
+ getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
+ OpBuilder builder(getOperation()->getContext());
+ builder.setInsertionPoint(verifiableOp);
+ verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
+ });
+}
+
+std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
+ return std::make_unique<GenerateRuntimeVerificationPass>();
+}
--- /dev/null
+// RUN: mlir-opt %s -generate-runtime-verification -cse | FileCheck %s
+
+// CHECK-LABEL: func @expand_shape(
+// CHECK-SAME: %[[m:.*]]: memref<?xf32>
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[dim:.*]] = memref.dim %[[m]], %[[c0]]
+// CHECK: %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]]
+// CHECK: %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]]
+// CHECK: cf.assert %[[cmpi]], "static result dims in reassoc group do not divide src dim evenly"
+func.func @expand_shape(%m: memref<?xf32>) -> memref<?x5xf32> {
+ %0 = memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>
+ return %0 : memref<?x5xf32>
+}
)
td_library(
+ name = "RuntimeVerifiableOpInterfaceTdFiles",
+ srcs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.td"],
+ includes = ["include"],
+ deps = [":OpBaseTdFiles"],
+)
+
+td_library(
name = "SideEffectInterfacesTdFiles",
srcs = [
"include/mlir/Interfaces/SideEffectInterfaceBase.td",
)
cc_library(
+ name = "RuntimeVerifiableOpInterface",
+ srcs = ["lib/Interfaces/RuntimeVerifiableOpInterface.cpp"],
+ hdrs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.h"],
+ includes = ["include"],
+ deps = [
+ ":IR",
+ ":RuntimeVerifiableOpInterfaceIncGen",
+ "//llvm:Support",
+ ],
+)
+
+cc_library(
name = "VectorInterfaces",
srcs = ["lib/Interfaces/VectorInterfaces.cpp"],
hdrs = ["include/mlir/Interfaces/VectorInterfaces.h"],
)
gentbl_cc_library(
+ name = "RuntimeVerifiableOpInterfaceIncGen",
+ strip_include_prefix = "include",
+ tbl_outs = [
+ (
+ ["-gen-op-interface-decls"],
+ "include/mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc",
+ ),
+ (
+ ["-gen-op-interface-defs"],
+ "include/mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc",
+ ),
+ ],
+ tblgen = ":mlir-tblgen",
+ td_file = "include/mlir/Interfaces/RuntimeVerifiableOpInterface.td",
+ deps = [":RuntimeVerifiableOpInterfaceTdFiles"],
+)
+
+gentbl_cc_library(
name = "VectorInterfacesIncGen",
strip_include_prefix = "include",
tbl_outs = [
":LoopLikeInterface",
":Pass",
":Rewrite",
+ ":RuntimeVerifiableOpInterface",
":SideEffectInterfaces",
":Support",
":TransformUtils",
":ArithDialect",
":ArithTransforms",
":ArithUtils",
+ ":ControlFlowDialect",
":DialectUtils",
":FuncDialect",
":IR",
":MemRefDialect",
":MemRefPassIncGen",
":Pass",
+ ":RuntimeVerifiableOpInterface",
":TensorDialect",
":Transforms",
":VectorDialect",