[mlir] Add RuntimeVerifiableOpInterface and transform
authorMatthias Springer <springerm@google.com>
Wed, 21 Dec 2022 09:51:10 +0000 (10:51 +0100)
committerMatthias Springer <springerm@google.com>
Wed, 21 Dec 2022 09:57:14 +0000 (10:57 +0100)
Static op verification cannot detect cases where an op is valid at compile time but may be invalid at runtime.

An example of such an op is `memref::ExpandShapeOp`.

Invalid at compile time: `memref.expand_shape %m [[0, 1]] : memref<11xf32> into memref<2x5xf32>`

Valid at compile time (because we do not know any better): `memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>`. This op may or may not be valid at runtime depending on the runtime shape of `%m`.

Invalid runtime ops such as the one above are hard to debug because they can crash the program execution at a seemingly unrelated position or (even worse) compute an invalid result without crashing.

This revision adds a new op interface `RuntimeVerifiableOpInterface` that can be implemented by ops that provide additional runtime verification. Such runtime verification can be computationally expensive, so it is only generated on an opt-in basis by running `-generate-runtime-verification`. A simple runtime verifier for `memref::ExpandShapeOp` is provided as an example.

Differential Revision: https://reviews.llvm.org/D138576

15 files changed:
mlir/include/mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h [new file with mode: 0644]
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/Interfaces/CMakeLists.txt
mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h [new file with mode: 0644]
mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td [new file with mode: 0644]
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp [new file with mode: 0644]
mlir/lib/Interfaces/CMakeLists.txt
mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp [new file with mode: 0644]
mlir/lib/Transforms/CMakeLists.txt
mlir/lib/Transforms/GenerateRuntimeVerification.cpp [new file with mode: 0644]
mlir/test/Dialect/MemRef/runtime-verification.mlir [new file with mode: 0644]
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h b/mlir/include/mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h
new file mode 100644 (file)
index 0000000..df6ce9f
--- /dev/null
@@ -0,0 +1,21 @@
+//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
+#define MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerRuntimeVerifiableOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_RUNTIMEOPVERIFICATION_H
index a51f15f..adbbb84 100644 (file)
@@ -45,6 +45,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
@@ -130,6 +131,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
       registry);
   linalg::registerBufferizableOpInterfaceExternalModels(registry);
   linalg::registerTilingInterfaceExternalModels(registry);
+  memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);
   sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
index 721a9de..f24d340 100644 (file)
@@ -8,6 +8,7 @@ add_mlir_interface(InferIntRangeInterface)
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(ParallelCombiningOpInterface)
+add_mlir_interface(RuntimeVerifiableOpInterface)
 add_mlir_interface(ShapedOpInterfaces)
 add_mlir_interface(SideEffectInterfaces)
 add_mlir_interface(TilingInterface)
diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.h
new file mode 100644 (file)
index 0000000..2995e4a
--- /dev/null
@@ -0,0 +1,17 @@
+//===- RuntimeVerifiableOpInterface.h - Op Verification ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
+#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc"
+
+#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td
new file mode 100644 (file)
index 0000000..d5f11d0
--- /dev/null
@@ -0,0 +1,40 @@
+//===- RuntimeVerifiableOpInterface.td - Op Verification ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
+#define MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> {
+  let description = [{
+    Implementations of this interface generate IR for runtime op verification.
+
+    Incorrect op usage can often be caught by op verifiers based on static
+    program information. However, in the absence of static program information,
+    it can remain undetected at compile time (e.g., in case of dynamic memref
+    strides instead of static memref strides). Such cases can be checked at
+    runtime. The op-specific checks are generated by this interface.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Generate IR to verify this op at runtime, aborting runtime execution if
+        verification fails.
+      }],
+      /*retTy=*/"void",
+      /*methodName=*/"generateRuntimeVerification",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+                    "::mlir::Location":$loc)
+    >,
+  ];
+}
+
+#endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE
index 9d800c3..2ff8a4c 100644 (file)
@@ -64,6 +64,9 @@ std::unique_ptr<Pass> createControlFlowSinkPass();
 /// Creates a pass to perform common sub expression elimination.
 std::unique_ptr<Pass> createCSEPass();
 
+/// Creates a pass that generates IR to verify ops at runtime.
+std::unique_ptr<Pass> createGenerateRuntimeVerificationPass();
+
 /// Creates a loop invariant code motion pass that hoists loop invariant
 /// instructions out of the loop.
 std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
index 8b8e6a1..d45f5f0 100644 (file)
@@ -77,6 +77,16 @@ def CSE : Pass<"cse"> {
   ];
 }
 
+def GenerateRuntimeVerification : Pass<"generate-runtime-verification"> {
+  let summary = "Generate additional runtime op verification checks";
+  let description = [{
+    This pass generates op-specific runtime checks using the
+    `RuntimeVerifiableOpInterface`. It can be run for debugging purposes after
+    passes that are suspected to introduce faulty IR.
+  }];
+  let constructor = "mlir::createGenerateRuntimeVerificationPass()";
+}
+
 def Inliner : Pass<"inline"> {
   let summary = "Inline function calls";
   let constructor = "mlir::createInlinerPass()";
index 7c65605..ceccc51 100644 (file)
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp
+  RuntimeOpVerification.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/MemRef
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
new file mode 100644 (file)
index 0000000..af36855
--- /dev/null
@@ -0,0 +1,70 @@
+//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+namespace memref {
+namespace {
+struct ExpandShapeOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
+                                                         ExpandShapeOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto expandShapeOp = cast<ExpandShapeOp>(op);
+
+    // Verify that the expanded dim sizes are a product of the collapsed dim
+    // size.
+    for (auto it : llvm::enumerate(expandShapeOp.getReassociationIndices())) {
+      Value srcDimSz =
+          builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index());
+      int64_t groupSz = 1;
+      bool foundDynamicDim = false;
+      for (int64_t resultDim : it.value()) {
+        if (expandShapeOp.getResultType().isDynamicDim(resultDim)) {
+          // Keep this assert here in case the op is extended in the future.
+          assert(!foundDynamicDim &&
+                 "more than one dynamic dim found in reassoc group");
+          foundDynamicDim = true;
+          continue;
+        }
+        groupSz *= expandShapeOp.getResultType().getDimSize(resultDim);
+      }
+      Value staticResultDimSz =
+          builder.create<arith::ConstantIndexOp>(loc, groupSz);
+      // staticResultDimSz must divide srcDimSz evenly.
+      Value mod =
+          builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz);
+      Value isModZero = builder.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, mod,
+          builder.create<arith::ConstantIndexOp>(loc, 0));
+      builder.create<cf::AssertOp>(
+          loc, isModZero,
+          "static result dims in reassoc group do not divide src dim evenly");
+    }
+  }
+};
+} // namespace
+} // namespace memref
+} // namespace mlir
+
+void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
+
+    // Load additional dialects of which ops may get created.
+    ctx->loadDialect<arith::ArithDialect, cf::ControlFlowDialect>();
+  });
+}
index fb4958a..a7cdbb5 100644 (file)
@@ -10,6 +10,7 @@ set(LLVM_OPTIONAL_SOURCES
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
   ParallelCombiningOpInterface.cpp
+  RuntimeVerifiableOpInterface.cpp
   ShapedOpInterfaces.cpp
   SideEffectInterfaces.cpp
   TilingInterface.cpp
@@ -44,6 +45,7 @@ add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 add_mlir_interface_library(LoopLikeInterface)
 add_mlir_interface_library(ParallelCombiningOpInterface)
+add_mlir_interface_library(RuntimeVerifiableOpInterface)
 add_mlir_interface_library(ShapedOpInterfaces)
 add_mlir_interface_library(SideEffectInterfaces)
 add_mlir_interface_library(TilingInterface)
diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp
new file mode 100644 (file)
index 0000000..9205d8d
--- /dev/null
@@ -0,0 +1,17 @@
+//===- RuntimeVerifiableOpInterface.cpp - Op Verification -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+class Location;
+class OpBuilder;
+} // namespace mlir
+
+/// Include the definitions of the interface.
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc"
index 71ca9b0..3a6b064 100644 (file)
@@ -4,6 +4,7 @@ add_mlir_library(MLIRTransforms
   Canonicalizer.cpp
   ControlFlowSink.cpp
   CSE.cpp
+  GenerateRuntimeVerification.cpp
   Inliner.cpp
   LocationSnapshot.cpp
   LoopInvariantCodeMotion.cpp
@@ -26,6 +27,7 @@ add_mlir_library(MLIRTransforms
   MLIRCopyOpInterface
   MLIRLoopLikeInterface
   MLIRPass
+  MLIRRuntimeVerifiableOpInterface
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRTransformUtils
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
new file mode 100644 (file)
index 0000000..62db9ce
--- /dev/null
@@ -0,0 +1,40 @@
+//===- RuntimeOpVerification.cpp - Op Verification ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_GENERATERUNTIMEVERIFICATION
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct GenerateRuntimeVerificationPass
+    : public impl::GenerateRuntimeVerificationBase<
+          GenerateRuntimeVerificationPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void GenerateRuntimeVerificationPass::runOnOperation() {
+  getOperation()->walk([&](RuntimeVerifiableOpInterface verifiableOp) {
+    OpBuilder builder(getOperation()->getContext());
+    builder.setInsertionPoint(verifiableOp);
+    verifiableOp.generateRuntimeVerification(builder, verifiableOp.getLoc());
+  });
+}
+
+std::unique_ptr<Pass> mlir::createGenerateRuntimeVerificationPass() {
+  return std::make_unique<GenerateRuntimeVerificationPass>();
+}
diff --git a/mlir/test/Dialect/MemRef/runtime-verification.mlir b/mlir/test/Dialect/MemRef/runtime-verification.mlir
new file mode 100644 (file)
index 0000000..f77717c
--- /dev/null
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -generate-runtime-verification -cse | FileCheck %s
+
+// CHECK-LABEL: func @expand_shape(
+//  CHECK-SAME:     %[[m:.*]]: memref<?xf32>
+//   CHECK-DAG:   %[[c0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[c5:.*]] = arith.constant 5 : index
+//   CHECK-DAG:   %[[dim:.*]] = memref.dim %[[m]], %[[c0]]
+//       CHECK:   %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]]
+//       CHECK:   %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]]
+//       CHECK:   cf.assert %[[cmpi]], "static result dims in reassoc group do not divide src dim evenly"
+func.func @expand_shape(%m: memref<?xf32>) -> memref<?x5xf32> {
+  %0 = memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>
+  return %0 : memref<?x5xf32>
+}
index ecd2400..3947862 100644 (file)
@@ -1038,6 +1038,13 @@ td_library(
 )
 
 td_library(
+    name = "RuntimeVerifiableOpInterfaceTdFiles",
+    srcs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.td"],
+    includes = ["include"],
+    deps = [":OpBaseTdFiles"],
+)
+
+td_library(
     name = "SideEffectInterfacesTdFiles",
     srcs = [
         "include/mlir/Interfaces/SideEffectInterfaceBase.td",
@@ -2993,6 +3000,18 @@ cc_library(
 )
 
 cc_library(
+    name = "RuntimeVerifiableOpInterface",
+    srcs = ["lib/Interfaces/RuntimeVerifiableOpInterface.cpp"],
+    hdrs = ["include/mlir/Interfaces/RuntimeVerifiableOpInterface.h"],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":RuntimeVerifiableOpInterfaceIncGen",
+        "//llvm:Support",
+    ],
+)
+
+cc_library(
     name = "VectorInterfaces",
     srcs = ["lib/Interfaces/VectorInterfaces.cpp"],
     hdrs = ["include/mlir/Interfaces/VectorInterfaces.h"],
@@ -5716,6 +5735,24 @@ gentbl_cc_library(
 )
 
 gentbl_cc_library(
+    name = "RuntimeVerifiableOpInterfaceIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Interfaces/RuntimeVerifiableOpInterface.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Interfaces/RuntimeVerifiableOpInterface.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Interfaces/RuntimeVerifiableOpInterface.td",
+    deps = [":RuntimeVerifiableOpInterfaceTdFiles"],
+)
+
+gentbl_cc_library(
     name = "VectorInterfacesIncGen",
     strip_include_prefix = "include",
     tbl_outs = [
@@ -5818,6 +5855,7 @@ cc_library(
         ":LoopLikeInterface",
         ":Pass",
         ":Rewrite",
+        ":RuntimeVerifiableOpInterface",
         ":SideEffectInterfaces",
         ":Support",
         ":TransformUtils",
@@ -9783,6 +9821,7 @@ cc_library(
         ":ArithDialect",
         ":ArithTransforms",
         ":ArithUtils",
+        ":ControlFlowDialect",
         ":DialectUtils",
         ":FuncDialect",
         ":IR",
@@ -9791,6 +9830,7 @@ cc_library(
         ":MemRefDialect",
         ":MemRefPassIncGen",
         ":Pass",
+        ":RuntimeVerifiableOpInterface",
         ":TensorDialect",
         ":Transforms",
         ":VectorDialect",