From a4cbe4ebe1a6735bda4630919900b8eeb03b7626 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 16 Jul 2019 12:45:05 -0700 Subject: [PATCH] Verify that ReturnOp only appears within the region of a FuncOp. The invariants of ReturnOp are directly tied to FuncOp, making ReturnOp invalid in any other context. PiperOrigin-RevId: 258421200 --- mlir/include/mlir/IR/OpDefinition.h | 6 +++++- mlir/lib/StandardOps/Ops.cpp | 6 +++--- mlir/test/IR/invalid-ops.mlir | 10 ++++++++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 32f7efa..c275d03 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -71,7 +71,11 @@ public: /// Return the operation that this refers to. Operation *getOperation() { return state; } - /// Return the closes surrounding parent operation that is of type 'OpTy'. + /// Returns the closest surrounding operation that contains this operation + /// or nullptr if this is a top-level operation. + Operation *getParentOp() { return getOperation()->getParentOp(); } + + /// Return the closest surrounding parent operation that is of type 'OpTy'. template OpTy getParentOfType() { return getOperation()->getParentOfType(); } diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 2fa37ff..738446a 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1868,9 +1868,9 @@ static void print(OpAsmPrinter *p, ReturnOp op) { } static LogicalResult verify(ReturnOp op) { - // TODO(b/137008268): Return op should verify that it is nested directly - // within a function operation. - auto function = op.getParentOfType(); + auto function = dyn_cast_or_null(op.getParentOp()); + if (!function) + return op.emitOpError() << "must be nested within a 'func' region"; // The operand number and types must match the function signature. const auto &results = function.getType().getResults(); diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index c13514f..2991d12 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -811,3 +811,13 @@ func @std_if_illegal_block_argument(%arg0: i1) { }, {}): (i1) -> () return } + +// ----- + +func @return_not_in_function() { + "foo.region"() ({ + // expected-error@+1 {{must be nested within a 'func' region}} + return + }): () -> () + return +} -- 2.7.4