Add a `HasParent` operation trait to enforce a specific parent on an operation (NFC)
authorMehdi Amini <aminim@google.com>
Mon, 29 Jul 2019 17:45:17 +0000 (10:45 -0700)
committerjpienaar <jpienaar@google.com>
Tue, 30 Jul 2019 13:17:11 +0000 (06:17 -0700)
PiperOrigin-RevId: 260532592

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/StandardOps/Ops.td
mlir/lib/StandardOps/Ops.cpp
mlir/test/IR/invalid-ops.mlir
mlir/test/IR/traits.mlir
mlir/test/lib/TestDialect/TestOps.td

index c23331c..7eb1d7e 100644 (file)
@@ -1018,6 +1018,10 @@ def Terminator       : NativeOpTrait<"IsTerminator">;
 class SingleBlockImplicitTerminator<string op>
     : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>;
 
+// Op's parent operation is the provided one.
+class HasParent<string op>
+    : ParamNativeOpTrait<"HasParent", op>;
+
 // Op result type is derived from the first attribute. If the attribute is an
 // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
 // attribute content is used.
index 4ab4996..e59093b 100644 (file)
@@ -845,6 +845,20 @@ template <typename TerminatorOpType> struct SingleBlockImplicitTerminator {
   };
 };
 
+/// This class provides a verifier for ops that are expecting a specific parent.
+template <typename ParentOpType> struct HasParent {
+  template <typename ConcreteType>
+  class Impl : public TraitBase<ConcreteType, Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      if (isa<ParentOpType>(op->getParentOp()))
+        return success();
+      return op->emitOpError() << "expects parent op '"
+                               << ParentOpType::getOperationName() << "'";
+    }
+  };
+};
+
 } // end namespace OpTrait
 
 //===----------------------------------------------------------------------===//
index 15dd843..b6bf2cf 100644 (file)
@@ -765,7 +765,7 @@ def RemIUOp : IntArithmeticOp<"remiu"> {
   let hasFolder = 1;
 }
 
-def ReturnOp : Std_Op<"return", [Terminator]> {
+def ReturnOp : Std_Op<"return", [Terminator, HasParent<"FuncOp">]> {
   let summary = "return operation";
   let description = [{
     The "return" operation represents a return operation within a function.
index 9d3d5e9..df99f00 100644 (file)
@@ -1868,9 +1868,7 @@ static void print(OpAsmPrinter *p, ReturnOp op) {
 }
 
 static LogicalResult verify(ReturnOp op) {
-  auto function = dyn_cast_or_null<FuncOp>(op.getParentOp());
-  if (!function)
-    return op.emitOpError() << "must be nested within a 'func' region";
+  auto function = cast<FuncOp>(op.getParentOp());
 
   // The operand number and types must match the function signature.
   const auto &results = function.getType().getResults();
index 33944cd..c04b53b 100644 (file)
@@ -720,7 +720,7 @@ func @sitofp_f32_to_i32(%arg0 : f32) {
 
 func @return_not_in_function() {
   "foo.region"() ({
-    // expected-error@+1 {{must be nested within a 'func' region}}
+    // expected-error@+1 {{'std.return' op expects parent op 'func'}}
     return
   }): () -> ()
   return
index 281e3c0..c8aedee 100644 (file)
@@ -40,3 +40,13 @@ func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tenso
   // expected-error@+1 {{requires the same shape for all operands and results}}
   %0 = "test.same_operand_and_result_shape"(%t1, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
 }
+
+// -----
+
+func @hasParent() {
+  "some.op"() ({
+   // expected-error@+1 {{'test.child' op expects parent op 'test.parent'}}
+    "test.child"() : () -> ()
+  }) : () -> ()
+}
+
index fc54460..ae87037 100644 (file)
@@ -201,6 +201,10 @@ def BroadcastableOp : TEST_Op<"broadcastable", [Broadcastable]> {
   let results = (outs AnyTensor:$res);
 }
 
+// There the "HasParent" trait.
+def ParentOp : TEST_Op<"parent">;
+def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;
+
 //===----------------------------------------------------------------------===//
 // Test Patterns
 //===----------------------------------------------------------------------===//