[mlir][memref] Add runtime verification for memref::CastOp
authorMatthias Springer <springerm@google.com>
Fri, 6 Jan 2023 13:24:30 +0000 (14:24 +0100)
committerMatthias Springer <springerm@google.com>
Fri, 6 Jan 2023 13:38:56 +0000 (14:38 +0100)
Verify unranked -> ranked casts and casts of dynamic sizes/offset/strides to static ones.

Differential Revision: https://reviews.llvm.org/D138671

mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
mlir/test/Dialect/MemRef/runtime-verification.mlir
mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir [new file with mode: 0644]

index 002e5f1..9ffb315 100644 (file)
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
 
+using namespace mlir;
+
+/// Generate an error message string for the given op and the specified error.
+static std::string generateErrorMessage(Operation *op, const std::string &msg) {
+  std::string buffer;
+  llvm::raw_string_ostream stream(buffer);
+  OpPrintingFlags flags;
+  stream << "ERROR: Runtime op verification failed\n";
+  op->print(stream, flags);
+  stream << "\n^ " << msg;
+  stream << "\nLocation: ";
+  op->getLoc().print(stream);
+  return stream.str();
+}
+
 namespace mlir {
 namespace memref {
 namespace {
+struct CastOpInterface
+    : public RuntimeVerifiableOpInterface::ExternalModel<CastOpInterface,
+                                                         CastOp> {
+  void generateRuntimeVerification(Operation *op, OpBuilder &builder,
+                                   Location loc) const {
+    auto castOp = cast<CastOp>(op);
+    auto srcType = castOp.getSource().getType().cast<BaseMemRefType>();
+
+    // Nothing to check if the result is an unranked memref.
+    auto resultType = castOp.getType().dyn_cast<MemRefType>();
+    if (!resultType)
+      return;
+
+    if (srcType.isa<UnrankedMemRefType>()) {
+      // Check rank.
+      Value srcRank = builder.create<RankOp>(loc, castOp.getSource());
+      Value resultRank =
+          builder.create<arith::ConstantIndexOp>(loc, resultType.getRank());
+      Value isSameRank = builder.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, srcRank, resultRank);
+      builder.create<cf::AssertOp>(loc, isSameRank,
+                                   generateErrorMessage(op, "rank mismatch"));
+    }
+
+    // Get source offset and strides. We do not have an op to get offsets and
+    // strides from unranked memrefs, so cast the source to a type with fully
+    // dynamic layout, from which we can then extract the offset and strides.
+    // (Rank was already verified.)
+    int64_t dynamicOffset = ShapedType::kDynamic;
+    SmallVector<int64_t> dynamicShape(resultType.getRank(),
+                                      ShapedType::kDynamic);
+    auto stridedLayout = StridedLayoutAttr::get(builder.getContext(),
+                                                dynamicOffset, dynamicShape);
+    auto dynStridesType =
+        MemRefType::get(dynamicShape, resultType.getElementType(),
+                        stridedLayout, resultType.getMemorySpace());
+    Value helperCast =
+        builder.create<CastOp>(loc, dynStridesType, castOp.getSource());
+    auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast);
+
+    // Check dimension sizes.
+    for (const auto &it : llvm::enumerate(resultType.getShape())) {
+      // Static dim size -> static/dynamic dim size does not need verification.
+      if (auto rankedSrcType = srcType.dyn_cast<MemRefType>())
+        if (!rankedSrcType.isDynamicDim(it.index()))
+          continue;
+
+      // Static/dynamic dim size -> dynamic dim size does not need verification.
+      if (resultType.isDynamicDim(it.index()))
+        continue;
+
+      Value srcDimSz =
+          builder.create<DimOp>(loc, castOp.getSource(), it.index());
+      Value resultDimSz =
+          builder.create<arith::ConstantIndexOp>(loc, it.value());
+      Value isSameSz = builder.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz);
+      builder.create<cf::AssertOp>(
+          loc, isSameSz,
+          generateErrorMessage(op, "size mismatch of dim " +
+                                       std::to_string(it.index())));
+    }
+
+    // Get result offset and strides.
+    int64_t resultOffset;
+    SmallVector<int64_t> resultStrides;
+    if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
+      return;
+
+    // Check offset.
+    if (resultOffset != ShapedType::kDynamic) {
+      // Static/dynamic offset -> dynamic offset does not need verification.
+      Value srcOffset = metadataOp.getResult(1);
+      Value resultOffsetVal =
+          builder.create<arith::ConstantIndexOp>(loc, resultOffset);
+      Value isSameOffset = builder.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal);
+      builder.create<cf::AssertOp>(loc, isSameOffset,
+                                   generateErrorMessage(op, "offset mismatch"));
+    }
+
+    // Check strides.
+    for (const auto &it : llvm::enumerate(resultStrides)) {
+      // Static/dynamic stride -> dynamic stride does not need verification.
+      if (it.value() == ShapedType::kDynamic)
+        continue;
+
+      Value srcStride =
+          metadataOp.getResult(2 + resultType.getRank() + it.index());
+      Value resultStrideVal =
+          builder.create<arith::ConstantIndexOp>(loc, it.value());
+      Value isSameStride = builder.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal);
+      builder.create<cf::AssertOp>(
+          loc, isSameStride,
+          generateErrorMessage(op, "stride mismatch of dim " +
+                                       std::to_string(it.index())));
+    }
+  }
+};
+
 struct ExpandShapeOpInterface
     : public RuntimeVerifiableOpInterface::ExternalModel<ExpandShapeOpInterface,
                                                          ExpandShapeOp> {
@@ -53,7 +169,8 @@ struct ExpandShapeOpInterface
           builder.create<arith::ConstantIndexOp>(loc, 0));
       builder.create<cf::AssertOp>(
           loc, isModZero,
-          "static result dims in reassoc group do not divide src dim evenly");
+          generateErrorMessage(op, "static result dims in reassoc group do not "
+                                   "divide src dim evenly"));
     }
   }
 };
@@ -64,6 +181,7 @@ struct ExpandShapeOpInterface
 void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    CastOp::attachInterface<CastOpInterface>(*ctx);
     ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx);
 
     // Load additional dialects of which ops may get created.
index f77717c..4d7fcf6 100644 (file)
@@ -7,7 +7,7 @@
 //   CHECK-DAG:   %[[dim:.*]] = memref.dim %[[m]], %[[c0]]
 //       CHECK:   %[[mod:.*]] = arith.remsi %[[dim]], %[[c5]]
 //       CHECK:   %[[cmpi:.*]] = arith.cmpi eq, %[[mod]], %[[c0]]
-//       CHECK:   cf.assert %[[cmpi]], "static result dims in reassoc group do not divide src dim evenly"
+//       CHECK:   cf.assert %[[cmpi]], "ERROR: Runtime op verification failed
 func.func @expand_shape(%m: memref<?xf32>) -> memref<?x5xf32> {
   %0 = memref.expand_shape %m [[0, 1]] : memref<?xf32> into memref<?x5xf32>
   return %0 : memref<?x5xf32>
diff --git a/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Memref/cast-runtime-verification.mlir
new file mode 100644 (file)
index 0000000..5f551f8
--- /dev/null
@@ -0,0 +1,69 @@
+// RUN: mlir-opt %s -generate-runtime-verification -convert-memref-to-llvm \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-func-to-llvm -reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext 2>&1 | \
+// RUN: FileCheck %s
+
+func.func @cast_to_static_dim(%m: memref<?xf32>) -> memref<10xf32> {
+  %0 = memref.cast %m : memref<?xf32> to memref<10xf32>
+  return %0 : memref<10xf32>
+}
+
+func.func @cast_to_ranked(%m: memref<*xf32>) -> memref<f32> {
+  %0 = memref.cast %m : memref<*xf32> to memref<f32>
+  return %0 : memref<f32>
+}
+
+func.func @cast_to_static_strides(%m: memref<?xf32, strided<[?], offset: ?>>)
+    -> memref<?xf32, strided<[9], offset: 5>> {
+  %0 = memref.cast %m : memref<?xf32, strided<[?], offset: ?>>
+                     to memref<?xf32, strided<[9], offset: 5>>
+  return %0 : memref<?xf32, strided<[9], offset: 5>>
+}
+
+func.func @valid_cast(%m: memref<*xf32>) -> memref<?xf32> {
+  %0 = memref.cast %m : memref<*xf32> to memref<?xf32>
+  return %0 : memref<?xf32>
+}
+
+func.func @main() {
+  // All casts inside the called functions are invalid at runtime, except for
+  // the last one.
+  %alloc = memref.alloc() : memref<5xf32>
+
+  //      CHECK: ERROR: Runtime op verification failed
+  // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32> to memref<10xf32>
+  // CHECK-NEXT: ^ size mismatch of dim 0
+  // CHECK-NEXT: Location: loc({{.*}})
+  %1 = memref.cast %alloc : memref<5xf32> to memref<?xf32>
+  func.call @cast_to_static_dim(%1) : (memref<?xf32>) -> (memref<10xf32>)
+
+  // CHECK-NEXT: ERROR: Runtime op verification failed
+  // CHECK-NEXT: memref.cast %{{.*}} : memref<*xf32> to memref<f32>
+  // CHECK-NEXT: ^ rank mismatch
+  // CHECK-NEXT: Location: loc({{.*}})
+  %3 = memref.cast %alloc : memref<5xf32> to memref<*xf32>
+  func.call @cast_to_ranked(%3) : (memref<*xf32>) -> (memref<f32>)
+
+  // CHECK-NEXT: ERROR: Runtime op verification failed
+  // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
+  // CHECK-NEXT: ^ offset mismatch
+  // CHECK-NEXT: Location: loc({{.*}})
+
+  // CHECK-NEXT: ERROR: Runtime op verification failed
+  // CHECK-NEXT: memref.cast %{{.*}} : memref<?xf32, strided<[?], offset: ?>> to memref<?xf32, strided<[9], offset: 5>>
+  // CHECK-NEXT: ^ stride mismatch of dim 0
+  // CHECK-NEXT: Location: loc({{.*}})
+  %4 = memref.cast %alloc
+      : memref<5xf32> to memref<?xf32, strided<[?], offset: ?>>
+  func.call @cast_to_static_strides(%4)
+      : (memref<?xf32, strided<[?], offset: ?>>)
+     -> (memref<?xf32, strided<[9], offset: 5>>)
+
+  // A last cast that actually succeeds.
+  // CHECK-NOT: ERROR: Runtime op verification failed
+  func.call @valid_cast(%3) : (memref<*xf32>) -> (memref<?xf32>)
+
+  return
+}