[mlir] Add verifier for `shape.yield`.
authorAlexander Belyaev <pifon@google.com>
Fri, 5 Jun 2020 15:01:43 +0000 (17:01 +0200)
committerAlexander Belyaev <pifon@google.com>
Sun, 7 Jun 2020 13:40:11 +0000 (15:40 +0200)
Differential Revision: https://reviews.llvm.org/D81262

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/Shape/ops.mlir

index 63101b9..40fde48 100644 (file)
@@ -367,7 +367,11 @@ def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
   let hasFolder = 1;
 }
 
-def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
+def Shape_YieldOp : Shape_Op<"yield", 
+    [HasParent<"ReduceOp">,
+     NoSideEffect,
+     ReturnLike,
+     Terminator]> {
   let summary = "Returns the value to parent op";
 
   let arguments = (ins Variadic<AnyType>:$operands);
@@ -376,6 +380,7 @@ def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
     "OpBuilder &b, OperationState &result", [{ build(b, result, llvm::None); }]
   >];
 
+  let verifier = [{ return ::verify(*this); }];
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 }
 
index 343c427..d29f48e 100644 (file)
@@ -391,6 +391,26 @@ OpFoldResult SizeToIndexOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(YieldOp op) {
+  auto *parentOp = op.getParentOp();
+  auto results = parentOp->getResults();
+  auto operands = op.getOperands();
+
+  if (parentOp->getNumResults() != op.getNumOperands())
+    return op.emitOpError() << "number of operands does not match number of "
+                               "results of its parent";
+  for (auto e : llvm::zip(results, operands))
+    if (std::get<0>(e).getType() != std::get<1>(e).getType())
+      return op.emitOpError()
+             << "types mismatch between yield op and its parent";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // SplitAtOp
 //===----------------------------------------------------------------------===//
 
index 63589c8..41105dc 100644 (file)
@@ -4,7 +4,7 @@ func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
   // expected-error@+1 {{ReduceOp body is expected to have 3 arguments}}
   %num_elements = shape.reduce(%shape, %init) -> !shape.size {
     ^bb0(%index: index, %dim: !shape.size):
-      "shape.yield"(%dim) : (!shape.size) -> ()
+      shape.yield %dim : !shape.size
   }
 }
 
@@ -13,9 +13,10 @@ func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
 func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
   // expected-error@+1 {{argument 0 of ReduceOp body is expected to be of IndexType}}
   %num_elements = shape.reduce(%shape, %init) -> !shape.size {
-    ^bb0(%index: f32, %dim: !shape.size, %lci: !shape.size):
-      %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
-      "shape.yield"(%acc) : (!shape.size) -> ()
+    ^bb0(%index: f32, %dim: !shape.size, %acc: !shape.size):
+      %new_acc = "shape.add"(%acc, %dim)
+          : (!shape.size, !shape.size) -> !shape.size
+      shape.yield %new_acc : !shape.size
   }
 }
 
@@ -25,7 +26,7 @@ func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
   // expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType}}
   %num_elements = shape.reduce(%shape, %init) -> !shape.size {
     ^bb0(%index: index, %dim: f32, %lci: !shape.size):
-      "shape.yield"() : () -> ()
+      shape.yield
   }
 }
 
@@ -35,6 +36,27 @@ func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
   // expected-error@+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}}
   %num_elements = shape.reduce(%shape, %init) -> f32 {
     ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
-      "shape.yield"() : () -> ()
+      shape.yield
+  }
+}
+
+// -----
+
+func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error@+3 {{number of operands does not match number of results of its parent}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
+      shape.yield %dim, %dim : !shape.size, !shape.size
+  }
+}
+
+// -----
+
+func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error@+4 {{types mismatch between yield op and its parent}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
+      %c0 = constant 1 : index
+      shape.yield %c0 : index
   }
 }
index 0df58ed..51919ac 100644 (file)
@@ -10,7 +10,7 @@ func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
   %num_elements = shape.reduce(%shape, %init) -> !shape.size {
     ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
       %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
-      "shape.yield"(%acc) : (!shape.size) -> ()
+      shape.yield %acc : !shape.size
   }
   return %num_elements : !shape.size
 }