From d5a02fcd9642007b98f7c80cc4a0bb204a78805d Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 29 Jul 2019 10:45:17 -0700 Subject: [PATCH] Add a `HasParent` operation trait to enforce a specific parent on an operation (NFC) PiperOrigin-RevId: 260532592 --- mlir/include/mlir/IR/OpBase.td | 4 ++++ mlir/include/mlir/IR/OpDefinition.h | 14 ++++++++++++++ mlir/include/mlir/StandardOps/Ops.td | 2 +- mlir/lib/StandardOps/Ops.cpp | 4 +--- mlir/test/IR/invalid-ops.mlir | 2 +- mlir/test/IR/traits.mlir | 10 ++++++++++ mlir/test/lib/TestDialect/TestOps.td | 4 ++++ 7 files changed, 35 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index c23331c..7eb1d7e 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1018,6 +1018,10 @@ def Terminator : NativeOpTrait<"IsTerminator">; class SingleBlockImplicitTerminator : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>; +// Op's parent operation is the provided one. +class HasParent + : ParamNativeOpTrait<"HasParent", op>; + // Op result type is derived from the first attribute. If the attribute is an // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the // attribute content is used. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 4ab4996..e59093b 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -845,6 +845,20 @@ template struct SingleBlockImplicitTerminator { }; }; +/// This class provides a verifier for ops that are expecting a specific parent. +template struct HasParent { + template + class Impl : public TraitBase { + public: + static LogicalResult verifyTrait(Operation *op) { + if (isa(op->getParentOp())) + return success(); + return op->emitOpError() << "expects parent op '" + << ParentOpType::getOperationName() << "'"; + } + }; +}; + } // end namespace OpTrait //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td index 15dd843..b6bf2cf 100644 --- a/mlir/include/mlir/StandardOps/Ops.td +++ b/mlir/include/mlir/StandardOps/Ops.td @@ -765,7 +765,7 @@ def RemIUOp : IntArithmeticOp<"remiu"> { let hasFolder = 1; } -def ReturnOp : Std_Op<"return", [Terminator]> { +def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> { let summary = "return operation"; let description = [{ The "return" operation represents a return operation within a function. diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 9d3d5e9..df99f00 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -1868,9 +1868,7 @@ static void print(OpAsmPrinter *p, ReturnOp op) { } static LogicalResult verify(ReturnOp op) { - auto function = dyn_cast_or_null(op.getParentOp()); - if (!function) - return op.emitOpError() << "must be nested within a 'func' region"; + auto function = cast(op.getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getType().getResults(); diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 33944cd..c04b53b 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -720,7 +720,7 @@ func @sitofp_f32_to_i32(%arg0 : f32) { func @return_not_in_function() { "foo.region"() ({ - // expected-error@+1 {{must be nested within a 'func' region}} + // expected-error@+1 {{'std.return' op expects parent op 'func'}} return }): () -> () return diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index 281e3c0..c8aedee 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -40,3 +40,13 @@ func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tenso // expected-error@+1 {{requires the same shape for all operands and results}} %0 = "test.same_operand_and_result_shape"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> } + +// ----- + +func @hasParent() { + "some.op"() ({ + // expected-error@+1 {{'test.child' op expects parent op 'test.parent'}} + "test.child"() : () -> () + }) : () -> () +} + diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index fc54460..ae87037 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -201,6 +201,10 @@ def BroadcastableOp : TEST_Op<"broadcastable", [Broadcastable]> { let results = (outs AnyTensor:$res); } +// There the "HasParent" trait. +def ParentOp : TEST_Op<"parent">; +def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>; + //===----------------------------------------------------------------------===// // Test Patterns //===----------------------------------------------------------------------===// -- 2.7.4