Fix crash in scf.parallel verifier
authorMehdi Amini <joker.eph@gmail.com>
Tue, 17 Jan 2023 16:20:20 +0000 (16:20 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 17 Jan 2023 16:21:28 +0000 (16:21 +0000)
Fixes #59989

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D141911

mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/test/Dialect/SCF/invalid.mlir

index 4e6cd2e..fc7ce76 100644 (file)
@@ -78,6 +78,23 @@ void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
   builder.create<scf::YieldOp>(loc);
 }
 
+/// Verifies that the first block of the given `region` is terminated by a
+/// TerminatorTy. Reports errors on the given operation if it is not the case.
+template <typename TerminatorTy>
+static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
+                                           StringRef errorMessage) {
+  Operation *terminatorOperation = nullptr;
+  if (!region.empty() && !region.front().empty()) {
+    terminatorOperation = &region.front().back();
+    if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
+      return yield;
+  }
+  auto diag = op->emitOpError(errorMessage);
+  if (terminatorOperation)
+    diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
@@ -2323,10 +2340,13 @@ LogicalResult ParallelOp::verify() {
           "expects arguments for the induction variable to be of index type");
 
   // Check that the yield has no results
-  Operation *yield = body->getTerminator();
+  auto yield = verifyAndGetTerminator<scf::YieldOp>(
+      *this, getRegion(), "expects body to terminate with 'scf.yield'");
+  if (!yield)
+    return failure();
   if (yield->getNumOperands() != 0)
-    return yield->emitOpError() << "not allowed to have operands inside '"
-                                << ParallelOp::getOperationName() << "'";
+    return yield.emitOpError() << "not allowed to have operands inside '"
+                               << ParallelOp::getOperationName() << "'";
 
   // Check that the number of results is the same as the number of ReduceOps.
   SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
@@ -2854,23 +2874,6 @@ static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
   return success();
 }
 
-/// Verifies that the first block of the given `region` is terminated by a
-/// YieldOp. Reports errors on the given operation if it is not the case.
-template <typename TerminatorTy>
-static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
-                                           StringRef errorMessage) {
-  Operation *terminatorOperation = nullptr;
-  if (!region.empty() && !region.front().empty()) {
-    terminatorOperation = &region.front().back();
-    if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
-      return yield;
-  }
-  auto diag = op.emitOpError(errorMessage);
-  if (terminatorOperation)
-    diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
-  return nullptr;
-}
-
 LogicalResult scf::WhileOp::verify() {
   auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
       *this, getBefore(),
index 498a3bc..c1c6639 100644 (file)
@@ -672,3 +672,16 @@ func.func @switch_missing_terminator(%arg0: index, %arg1: i32) {
     return
   }) {cases = array<i64: 1>} : (index) -> ()
 }
+
+// -----
+
+func.func @parallel_missing_terminator(%0 : index) {
+  // expected-error @below {{'scf.parallel' op expects body to terminate with 'scf.yield'}}
+  "scf.parallel"(%0, %0, %0) ({
+  ^bb0(%arg1: index):
+    // expected-note @below {{terminator here}}
+    %2 = "arith.constant"() {value = 1.000000e+00 : f32} : () -> f32
+  }) {operand_segment_sizes = array<i32: 1, 1, 1, 0>} : (index, index, index) -> ()
+  return
+}
+