[MLIR][Shape] Refactor verification
authorJacques Pienaar <jpienaar@google.com>
Sat, 25 Jul 2020 21:55:19 +0000 (14:55 -0700)
committerJacques Pienaar <jpienaar@google.com>
Sat, 25 Jul 2020 21:55:19 +0000 (14:55 -0700)
Based on https://reviews.llvm.org/D84439 but less restrictive, else we
don't allow shape_of to be able to produce a ranked output and doesn't
allow for iterative refinement here. We can consider making it more
restrictive later.

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/invalid.mlir

index 797dc0b..8c32fae 100644 (file)
@@ -207,7 +207,7 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
-  let verifier = [{ return ::verify(*this); }];
+  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
 }
 
 def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
@@ -252,7 +252,7 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
   }];
 
   let hasFolder = 1;
-  let verifier = [{ return ::verify(*this); }];
+  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
 }
 
 def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
@@ -325,7 +325,7 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
     $lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($result) attr-dict
   }];
 
-  let verifier = [{ return ::verify(*this); }];
+  let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
   let hasFolder = 1;
 }
 
@@ -412,7 +412,7 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
 
   let assemblyFormat = "$arg `:` type($arg) `->` type($result) attr-dict";
 
-  let verifier = [{ return ::verify(*this); }];
+  let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
   let hasFolder = 1;
 }
 
index d2b0dbd..104ab46 100644 (file)
@@ -28,13 +28,37 @@ static RankedTensorType getExtentTensorType(MLIRContext *ctx) {
   return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
 }
 
-static bool isErrorPropagationPossible(ArrayRef<Type> operandTypes) {
+static bool isErrorPropagationPossible(TypeRange operandTypes) {
   for (Type ty : operandTypes)
     if (ty.isa<SizeType>() || ty.isa<ShapeType>() || ty.isa<ValueShapeType>())
       return true;
   return false;
 }
 
+static LogicalResult verifySizeOrIndexOp(Operation *op) {
+  assert(op != nullptr && op->getNumResults() == 1);
+  Type resultTy = op->getResultTypes().front();
+  if (isErrorPropagationPossible(op->getOperandTypes())) {
+    if (!resultTy.isa<SizeType>())
+      return op->emitOpError()
+             << "if at least one of the operands can hold error values then "
+                "the result must be of type `size` to propagate them";
+  }
+  return success();
+}
+
+static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
+  assert(op != nullptr && op->getNumResults() == 1);
+  Type resultTy = op->getResultTypes().front();
+  if (isErrorPropagationPossible(op->getOperandTypes())) {
+    if (!resultTy.isa<ShapeType>())
+      return op->emitOpError()
+             << "if at least one of the operands can hold error values then "
+                "the result must be of type `shape` to propagate them";
+  }
+  return success();
+}
+
 ShapeDialect::ShapeDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
   addOperations<
@@ -542,23 +566,6 @@ OpFoldResult FromExtentsOp::fold(ArrayRef<Attribute> operands) {
 // GetExtentOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(GetExtentOp op) {
-  Type shapeTy = op.shape().getType();
-  Type dimTy = op.dim().getType();
-  Type extentTy = op.extent().getType();
-  if (isErrorPropagationPossible({shapeTy, dimTy})) {
-    if (!extentTy.isa<SizeType>())
-      op.emitError()
-          << "if at least one of the operands can hold error values then the "
-             "result must be of type `size` to propagate them";
-  } else {
-    if (extentTy.isa<SizeType>())
-      op.emitError() << "if none of the operands can hold error values then "
-                        "the result must be of type `index`";
-  }
-  return success();
-}
-
 Optional<int64_t> GetExtentOp::getConstantDim() {
   if (auto constSizeOp = dim().getDefiningOp<ConstSizeOp>())
     return constSizeOp.value().getLimitedValue();
@@ -597,15 +604,6 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
 // RankOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(shape::RankOp op) {
-  if (op.shape().getType().isa<ShapeType>() &&
-      !op.rank().getType().isa<SizeType>())
-    return op.emitOpError()
-           << "if operand is of type `shape` then the result must be of type "
-              "`size` to propagate potential errors";
-  return success();
-}
-
 OpFoldResult shape::RankOp::fold(ArrayRef<Attribute> operands) {
   auto shape = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
   if (!shape)
@@ -680,21 +678,6 @@ OpFoldResult NumElementsOp::fold(ArrayRef<Attribute> operands) {
 // MulOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(MulOp op) {
-  Type resultTy = op.result().getType();
-  if (isErrorPropagationPossible({op.lhs().getType(), op.rhs().getType()})) {
-    if (!resultTy.isa<SizeType>())
-      return op.emitOpError()
-             << "if at least one of the operands can hold error values then "
-                "the result must be of type `size` to propagate them";
-  } else {
-    if (resultTy.isa<SizeType>())
-      return op.emitError() << "if none of the operands can hold error values "
-                               "then the result must be of type `index`";
-  }
-  return success();
-}
-
 OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
   auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
   if (!lhs)
@@ -719,21 +702,6 @@ OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
   return builder.getIndexTensorAttr(type.getShape());
 }
 
-static LogicalResult verify(ShapeOfOp op) {
-  Type resultTy = op.result().getType();
-  if (isErrorPropagationPossible(op.arg().getType())) {
-    if (!resultTy.isa<ShapeType>())
-      return op.emitOpError()
-             << "if operand is of type `value_shape` then the result must be "
-                "of type `shape` to propagate potential error shapes";
-  } else {
-    if (resultTy != getExtentTensorType(op.getContext()))
-      return op.emitOpError() << "if operand is a shaped type then the result "
-                                 "must be an extent tensor";
-  }
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // SizeToIndexOp
 //===----------------------------------------------------------------------===//
index b4900e4..20f4e87 100644 (file)
@@ -90,39 +90,21 @@ func @assuming_all_op_too_few_operands() {
 
 func @shape_of(%value_arg : !shape.value_shape,
                %shaped_arg : tensor<?x3x4xf32>) {
-  // expected-error@+1 {{if operand is of type `value_shape` then the result must be of type `shape` to propagate potential error shapes}}
+  // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
   %0 = shape.shape_of %value_arg : !shape.value_shape -> tensor<?xindex>
   return
 }
 
 // -----
 
-func @shape_of(%value_arg : !shape.value_shape,
-               %shaped_arg : tensor<?x3x4xf32>) {
-  // expected-error@+1 {{if operand is a shaped type then the result must be an extent tensor}}
-  %1 = shape.shape_of %shaped_arg : tensor<?x3x4xf32> -> !shape.shape
-  return
-}
-
-// -----
-
 func @rank(%arg : !shape.shape) {
-  // expected-error@+1 {{if operand is of type `shape` then the result must be of type `size` to propagate potential errors}}
+  // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %0 = shape.rank %arg : !shape.shape -> index
   return
 }
 
 // -----
 
-func @get_extent_error_free(%arg : tensor<?xindex>) -> !shape.size {
-  %c0 = constant 0 : index
-  // expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
-  %result = shape.get_extent %arg, %c0 : tensor<?xindex>, index -> !shape.size
-  return %result : !shape.size
-}
-
-// -----
-
 func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
   %c0 = shape.const_size 0
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
@@ -132,14 +114,6 @@ func @get_extent_error_possible(%arg : tensor<?xindex>) -> index {
 
 // -----
 
-func @mul_error_free(%arg : index) -> !shape.size {
-  // expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}}
-  %result = shape.mul %arg, %arg : index, index -> !shape.size
-  return %result : !shape.size
-}
-
-// -----
-
 func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index {
   // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}}
   %result = shape.mul %lhs, %rhs : !shape.size, index -> index