From 108b08f2a91272f82d524616a337a8ce52edeed5 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 21 Dec 2022 10:51:10 +0100 Subject: [PATCH] [mlir] Add RuntimeVerifiableOpInterface and transform Static op verification cannot detect cases where an op is valid at compile time but may be invalid at runtime. An example of such an op is `memref::ExpandShapeOp`. Invalid at compile time: `memref.expand_shape %m [[0, 1]] : memref<11xf32> into memref<2x5xf32>` Valid at compile time (because we do not know any better): `memref.expand_shape %m [[0, 1]] : memref into memref`. This op may or may not be valid at runtime depending on the runtime shape of `%m`. Invalid runtime ops such as the one above are hard to debug because they can crash the program execution at a seemingly unrelated position or (even worse) compute an invalid result without crashing. This revision adds a new op interface `RuntimeVerifiableOpInterface` that can be implemented by ops that provide additional runtime verification. Such runtime verification can be computationally expensive, so it is only generated on an opt-in basis by running `-generate-runtime-verification`. A simple runtime verifier for `memref::ExpandShapeOp` is provided as an example. Differential Revision: https://reviews.llvm.org/D138576 --- .../MemRef/Transforms/RuntimeOpVerification.h | 21 +++++++ mlir/include/mlir/InitAllDialects.h | 2 + mlir/include/mlir/Interfaces/CMakeLists.txt | 1 + .../mlir/Interfaces/RuntimeVerifiableOpInterface.h | 17 ++++++ .../Interfaces/RuntimeVerifiableOpInterface.td | 40 +++++++++++++ mlir/include/mlir/Transforms/Passes.h | 3 + mlir/include/mlir/Transforms/Passes.td | 10 ++++ mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt | 1 + .../MemRef/Transforms/RuntimeOpVerification.cpp | 70 ++++++++++++++++++++++ mlir/lib/Interfaces/CMakeLists.txt | 2 + .../Interfaces/RuntimeVerifiableOpInterface.cpp | 17 ++++++ mlir/lib/Transforms/CMakeLists.txt | 2 + .../lib/Transforms/GenerateRuntimeVerification.cpp | 40 +++++++++++++ mlir/test/Dialect/MemRef/runtime-verification.mlir | 14 +++++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 40 +++++++++++++ 15 files changed, 280 insertions(+) create mode 100644 mlir/include/mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h create mode 100644 mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h create mode 100644 mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td create mode 100644 mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp create mode 100644 mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp create mode 100644 mlir/lib/Transforms/GenerateRuntimeVerification.cpp create mode 100644 mlir/test/Dialect/MemRef/runtime-verification.mlir diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h b/mlir/include/mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h new file mode 100644 index 0000000..df6ce9f --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h @@ -0,0 +1,21 @@ +//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H +#define MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H + +namespace mlir { +class DialectRegistry; + +namespace memref { +void registerRuntimeVerifiableOpInterfaceExternalModels( + DialectRegistry ®istry); +} // namespace memref +} // namespace mlir + +#endif // MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index a51f15f..adbbb84 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -45,6 +45,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" +#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" @@ -130,6 +131,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerTilingInterfaceExternalModels(registry); + memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); scf::registerBufferizableOpInterfaceExternalModels(registry); shape::registerBufferizableOpInterfaceExternalModels(registry); sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt index 721a9de..f24d340 100644 --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(ParallelCombiningOpInterface) +add_mlir_interface(RuntimeVerifiableOpInterface) add_mlir_interface(ShapedOpInterfaces) add_mlir_interface(SideEffectInterfaces) add_mlir_interface(TilingInterface) diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h new file mode 100644 index 0000000..2995e4a --- /dev/null +++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h @@ -0,0 +1,17 @@ +//===- RuntimeVerifiableOpInterface.h - Op Verification ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_ +#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +/// Include the generated interface declarations. +#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc" + +#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td new file mode 100644 index 0000000..d5f11d0 --- /dev/null +++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td @@ -0,0 +1,40 @@ +//===- RuntimeVerifiableOpInterface.td - Op Verification ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE +#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE + +include "mlir/IR/OpBase.td" + +def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> { + let description = [{ + Implementations of this interface generate IR for runtime op verification. + + Incorrect op usage can often be caught by op verifiers based on static + program information. However, in the absence of static program information, + it can remain undetected at compile time (e.g., in case of dynamic memref + strides instead of static memref strides). Such cases can be checked at + runtime. The op-specific checks are generated by this interface. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Generate IR to verify this op at runtime, aborting runtime execution if + verification fails. + }], + /*retTy=*/"void", + /*methodName=*/"generateRuntimeVerification", + /*args=*/(ins "::mlir::OpBuilder &":$builder, + "::mlir::Location":$loc) + >, + ]; +} + +#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 9d800c3..2ff8a4c 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -64,6 +64,9 @@ std::unique_ptr createControlFlowSinkPass(); /// Creates a pass to perform common sub expression elimination. std::unique_ptr createCSEPass(); +/// Creates a pass that generates IR to verify ops at runtime. +std::unique_ptr createGenerateRuntimeVerificationPass(); + /// Creates a loop invariant code motion pass that hoists loop invariant /// instructions out of the loop. std::unique_ptr createLoopInvariantCodeMotionPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 8b8e6a1..d45f5f0 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -77,6 +77,16 @@ def CSE : Pass<"cse"> { ]; } +def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> { + let summary = "Generate additional runtime op verification checks"; + let description = [{ + This pass generates op-specific runtime checks using the + `RuntimeVerifiableOpInterface`. It can be run for debugging purposes after + passes that are suspected to introduce faulty IR. + }]; + let constructor = "mlir::createGenerateRuntimeVerificationPass()"; +} + def Inliner : Pass<"inline"> { let summary = "Inline function calls"; let constructor = "mlir::createInlinerPass()"; diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt index 7c65605..ceccc51 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms MultiBuffer.cpp NormalizeMemRefs.cpp ResolveShapedTypeResultDims.cpp + RuntimeOpVerification.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp new file mode 100644 index 0000000..af36855 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -0,0 +1,70 @@ +//===- RuntimeOpVerification.cpp - Op Verification ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" + +namespace mlir { +namespace memref { +namespace { +struct ExpandShapeOpInterface + : public RuntimeVerifiableOpInterface::ExternalModel { + void generateRuntimeVerification(Operation *op, OpBuilder &builder, + Location loc) const { + auto expandShapeOp = cast(op); + + // Verify that the expanded dim sizes are a product of the collapsed dim + // size. + for (auto it : llvm::enumerate(expandShapeOp.getReassociationIndices())) { + Value srcDimSz = + builder.create(loc, expandShapeOp.getSrc(), it.index()); + int64_t groupSz = 1; + bool foundDynamicDim = false; + for (int64_t resultDim : it.value()) { + if (expandShapeOp.getResultType().isDynamicDim(resultDim)) { + // Keep this assert here in case the op is extended in the future. + assert(!foundDynamicDim && + "more than one dynamic dim found in reassoc group"); + foundDynamicDim = true; + continue; + } + groupSz *= expandShapeOp.getResultType().getDimSize(resultDim); + } + Value staticResultDimSz = + builder.create(loc, groupSz); + // staticResultDimSz must divide srcDimSz evenly. + Value mod = + builder.create(loc, srcDimSz, staticResultDimSz); + Value isModZero = builder.create( + loc, arith::CmpIPredicate::eq, mod, + builder.create(loc, 0)); + builder.create( + loc, isModZero, + "static result dims in reassoc group do not divide src dim evenly"); + } + } +}; +} // namespace +} // namespace memref +} // namespace mlir + +void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { + ExpandShapeOp::attachInterface(*ctx); + + // Load additional dialects of which ops may get created. + ctx->loadDialect(); + }); +} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt index fb4958a..a7cdbb5 100644 --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -10,6 +10,7 @@ set(LLVM_OPTIONAL_SOURCES InferTypeOpInterface.cpp LoopLikeInterface.cpp ParallelCombiningOpInterface.cpp + RuntimeVerifiableOpInterface.cpp ShapedOpInterfaces.cpp SideEffectInterfaces.cpp TilingInterface.cpp @@ -44,6 +45,7 @@ add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(LoopLikeInterface) add_mlir_interface_library(ParallelCombiningOpInterface) +add_mlir_interface_library(RuntimeVerifiableOpInterface) add_mlir_interface_library(ShapedOpInterfaces) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp new file mode 100644 index 0000000..9205d8d --- /dev/null +++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp @@ -0,0 +1,17 @@ +//===- RuntimeVerifiableOpInterface.cpp - Op Verification -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" + +namespace mlir { +class Location; +class OpBuilder; +} // namespace mlir + +/// Include the definitions of the interface. +#include "mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc" diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 71ca9b0..3a6b064 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_library(MLIRTransforms Canonicalizer.cpp ControlFlowSink.cpp CSE.cpp + GenerateRuntimeVerification.cpp Inliner.cpp LocationSnapshot.cpp LoopInvariantCodeMotion.cpp @@ -26,6 +27,7 @@ add_mlir_library(MLIRTransforms MLIRCopyOpInterface MLIRLoopLikeInterface MLIRPass + MLIRRuntimeVerifiableOpInterface MLIRSideEffectInterfaces MLIRSupport MLIRTransformUtils diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp new file mode 100644 index 0000000..62db9ce --- /dev/null +++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp @@ -0,0 +1,40 @@ +//===- RuntimeOpVerification.cpp - Op Verification ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Passes.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" + +namespace mlir { +#define GEN_PASS_DEF_GENERATERUNTIMEVERIFICATION +#include "mlir/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct GenerateRuntimeVerificationPass + : public impl::GenerateRuntimeVerificationBase< + GenerateRuntimeVerificationPass> { + void runOnOperation() override; +}; +} // namespace + +void GenerateRuntimeVerificationPass::runOnOperation() { + getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) { + OpBuilder builder(getOperation()->getContext()); + builder.setInsertionPoint(verifiableOp); + verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc()); + }); +} + +std::unique_ptr mlir::createGenerateRuntimeVerificationPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/MemRef/runtime-verification.mlir b/mlir/test/Dialect/MemRef/runtime-verification.mlir new file mode 100644 index 0000000..f77717c --- /dev/null +++ b/mlir/test/Dialect/MemRef/runtime-verification.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -generate-runtime-verification -cse | FileCheck %s + +// CHECK-LABEL: func @expand_shape( +// CHECK-SAME: %[[m:.*]]: memref +// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[c5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[dim:.*]] = memref.dim %[[m]], %[[c0]] +// CHECK: %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]] +// CHECK: %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]] +// CHECK: cf.assert %[[cmpi]], "static result dims in reassoc group do not divide src dim evenly" +func.func @expand_shape(%m: memref) -> memref { + %0 = memref.expand_shape %m [[0, 1]] : memref into memref + return %0 : memref +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index ecd2400..3947862 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -1038,6 +1038,13 @@ td_library( ) td_library( + name = "RuntimeVerifiableOpInterfaceTdFiles", + srcs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.td"], + includes = ["include"], + deps = [":OpBaseTdFiles"], +) + +td_library( name = "SideEffectInterfacesTdFiles", srcs = [ "include/mlir/Interfaces/SideEffectInterfaceBase.td", @@ -2993,6 +3000,18 @@ cc_library( ) cc_library( + name = "RuntimeVerifiableOpInterface", + srcs = ["lib/Interfaces/RuntimeVerifiableOpInterface.cpp"], + hdrs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.h"], + includes = ["include"], + deps = [ + ":IR", + ":RuntimeVerifiableOpInterfaceIncGen", + "//llvm:Support", + ], +) + +cc_library( name = "VectorInterfaces", srcs = ["lib/Interfaces/VectorInterfaces.cpp"], hdrs = ["include/mlir/Interfaces/VectorInterfaces.h"], @@ -5716,6 +5735,24 @@ gentbl_cc_library( ) gentbl_cc_library( + name = "RuntimeVerifiableOpInterfaceIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Interfaces/RuntimeVerifiableOpInterface.td", + deps = [":RuntimeVerifiableOpInterfaceTdFiles"], +) + +gentbl_cc_library( name = "VectorInterfacesIncGen", strip_include_prefix = "include", tbl_outs = [ @@ -5818,6 +5855,7 @@ cc_library( ":LoopLikeInterface", ":Pass", ":Rewrite", + ":RuntimeVerifiableOpInterface", ":SideEffectInterfaces", ":Support", ":TransformUtils", @@ -9783,6 +9821,7 @@ cc_library( ":ArithDialect", ":ArithTransforms", ":ArithUtils", + ":ControlFlowDialect", ":DialectUtils", ":FuncDialect", ":IR", @@ -9791,6 +9830,7 @@ cc_library( ":MemRefDialect", ":MemRefPassIncGen", ":Pass", + ":RuntimeVerifiableOpInterface", ":TensorDialect", ":Transforms", ":VectorDialect", -- 2.7.4