From 9d7da415d244124af93e42a7a378eb79c2fb391f Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 17 Jan 2023 16:20:20 +0000 Subject: [PATCH] Fix crash in scf.parallel verifier Fixes #59989 Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D141911 --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 43 ++++++++++++++++++++------------------ mlir/test/Dialect/SCF/invalid.mlir | 13 ++++++++++++ 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 4e6cd2e..fc7ce76 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -78,6 +78,23 @@ void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) { builder.create(loc); } +/// Verifies that the first block of the given `region` is terminated by a +/// TerminatorTy. Reports errors on the given operation if it is not the case. +template +static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion, + StringRef errorMessage) { + Operation *terminatorOperation = nullptr; + if (!region.empty() && !region.front().empty()) { + terminatorOperation = ®ion.front().back(); + if (auto yield = dyn_cast_or_null(terminatorOperation)) + return yield; + } + auto diag = op->emitOpError(errorMessage); + if (terminatorOperation) + diag.attachNote(terminatorOperation->getLoc()) << "terminator here"; + return nullptr; +} + //===----------------------------------------------------------------------===// // ExecuteRegionOp //===----------------------------------------------------------------------===// @@ -2323,10 +2340,13 @@ LogicalResult ParallelOp::verify() { "expects arguments for the induction variable to be of index type"); // Check that the yield has no results - Operation *yield = body->getTerminator(); + auto yield = verifyAndGetTerminator( + *this, getRegion(), "expects body to terminate with 'scf.yield'"); + if (!yield) + return failure(); if (yield->getNumOperands() != 0) - return yield->emitOpError() << "not allowed to have operands inside '" - << ParallelOp::getOperationName() << "'"; + return yield.emitOpError() << "not allowed to have operands inside '" + << ParallelOp::getOperationName() << "'"; // Check that the number of results is the same as the number of ReduceOps. SmallVector reductions(body->getOps()); @@ -2854,23 +2874,6 @@ static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left, return success(); } -/// Verifies that the first block of the given `region` is terminated by a -/// YieldOp. Reports errors on the given operation if it is not the case. -template -static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region ®ion, - StringRef errorMessage) { - Operation *terminatorOperation = nullptr; - if (!region.empty() && !region.front().empty()) { - terminatorOperation = ®ion.front().back(); - if (auto yield = dyn_cast_or_null(terminatorOperation)) - return yield; - } - auto diag = op.emitOpError(errorMessage); - if (terminatorOperation) - diag.attachNote(terminatorOperation->getLoc()) << "terminator here"; - return nullptr; -} - LogicalResult scf::WhileOp::verify() { auto beforeTerminator = verifyAndGetTerminator( *this, getBefore(), diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index 498a3bc..c1c6639 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -672,3 +672,16 @@ func.func @switch_missing_terminator(%arg0: index, %arg1: i32) { return }) {cases = array} : (index) -> () } + +// ----- + +func.func @parallel_missing_terminator(%0 : index) { + // expected-error @below {{'scf.parallel' op expects body to terminate with 'scf.yield'}} + "scf.parallel"(%0, %0, %0) ({ + ^bb0(%arg1: index): + // expected-note @below {{terminator here}} + %2 = "arith.constant"() {value = 1.000000e+00 : f32} : () -> f32 + }) {operand_segment_sizes = array} : (index, index, index) -> () + return +} + -- 2.7.4