Allow dynamic but ranked types in ops with SameOperandsAndResultShape and SameOperand...
authorSmit Hinsu <hinsu@google.com>
Wed, 9 Oct 2019 02:36:37 +0000 (19:36 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 9 Oct 2019 02:37:11 +0000 (19:37 -0700)
Currently SameOperandsAndResultShape trait allows operands to have tensor<*xf32> and tensor<2xf32> but doesn't allow tensor<?xf32> and tensor<10xf32>.

Also, use the updated shape compatibility helper function in TensorCastOp::areCastCompatible method.

PiperOrigin-RevId: 273658336

mlir/include/mlir/IR/TypeUtilities.h
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/TypeUtilities.cpp
mlir/test/IR/invalid-ops.mlir
mlir/test/IR/traits.mlir
mlir/test/lib/TestDialect/TestOps.td

index ce0169f..49d57e8 100644 (file)
@@ -52,6 +52,13 @@ SmallVector<Type, 10> getFlattenedTypes(TupleType t);
 /// dialect and typeData.
 bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData);
 
+/// Returns success if the given two types have compatible shape. That is,
+/// they are both scalars (not shaped), or they are both shaped types and at
+/// least one is unranked or they have compatible dimensions. Dimensions are
+/// compatible if at least one is dynamic or both are equal. The element type
+/// does not matter.
+LogicalResult verifyCompatibleShape(Type type1, Type type2);
+
 //===----------------------------------------------------------------------===//
 // Utility Iterators
 //===----------------------------------------------------------------------===//
index f71fde7..5cbdb67 100644 (file)
@@ -2215,24 +2215,7 @@ bool TensorCastOp::areCastCompatible(Type a, Type b) {
   if (aT.getElementType() != bT.getElementType())
     return false;
 
-  // If the either are unranked, then the cast is valid.
-  auto aRType = aT.dyn_cast<RankedTensorType>();
-  auto bRType = bT.dyn_cast<RankedTensorType>();
-  if (!aRType || !bRType)
-    return true;
-
-  // If they are both ranked, they have to have the same rank, and any specified
-  // dimensions must match.
-  if (aRType.getRank() != bRType.getRank())
-    return false;
-
-  for (unsigned i = 0, e = aRType.getRank(); i != e; ++i) {
-    int64_t aDim = aRType.getDimSize(i), bDim = bRType.getDimSize(i);
-    if (aDim != -1 && bDim != -1 && aDim != bDim)
-      return false;
-  }
-
-  return true;
+  return succeeded(verifyCompatibleShape(aT, bT));
 }
 
 OpFoldResult TensorCastOp::fold(ArrayRef<Attribute> operands) {
index 23983bc..adf38ca 100644 (file)
@@ -748,33 +748,13 @@ LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
   return success();
 }
 
-/// Returns success if the given two types have the same shape. That is,
-/// they are both scalars (not shaped), or they are both shaped types and at
-/// least one is unranked or they have the same shape. The element type does not
-/// matter.
-static LogicalResult verifyShapeMatch(Type type1, Type type2) {
-  auto sType1 = type1.dyn_cast<ShapedType>();
-  auto sType2 = type2.dyn_cast<ShapedType>();
-
-  // Either both or neither type should be shaped.
-  if (!sType1)
-    return success(!sType2);
-  if (!sType2)
-    return failure();
-
-  if (!sType1.hasRank() || !sType2.hasRank())
-    return success();
-
-  return success(sType1.getShape() == sType2.getShape());
-}
-
 LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
   if (failed(verifyAtLeastNOperands(op, 1)))
     return failure();
 
   auto type = op->getOperand(0)->getType();
   for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
-    if (failed(verifyShapeMatch(opType, type)))
+    if (failed(verifyCompatibleShape(opType, type)))
       return op->emitOpError() << "requires the same shape for all operands";
   }
   return success();
@@ -787,12 +767,12 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
 
   auto type = op->getOperand(0)->getType();
   for (auto resultType : op->getResultTypes()) {
-    if (failed(verifyShapeMatch(resultType, type)))
+    if (failed(verifyCompatibleShape(resultType, type)))
       return op->emitOpError()
              << "requires the same shape for all operands and results";
   }
   for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
-    if (failed(verifyShapeMatch(opType, type)))
+    if (failed(verifyCompatibleShape(opType, type)))
       return op->emitOpError()
              << "requires the same shape for all operands and results";
   }
@@ -843,13 +823,16 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
     return failure();
 
   auto type = op->getResult(0)->getType();
+  auto elementType = getElementTypeOrSelf(type);
   for (auto resultType : llvm::drop_begin(op->getResultTypes(), 1)) {
-    if (resultType != type)
+    if (getElementTypeOrSelf(resultType) != elementType ||
+        failed(verifyCompatibleShape(resultType, type)))
       return op->emitOpError()
              << "requires the same type for all operands and results";
   }
   for (auto opType : op->getOperandTypes()) {
-    if (opType != type)
+    if (getElementTypeOrSelf(opType) != elementType ||
+        failed(verifyCompatibleShape(opType, type)))
       return op->emitOpError()
              << "requires the same type for all operands and results";
   }
index 95895af..a963a8d 100644 (file)
@@ -61,6 +61,37 @@ bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
   return false;
 }
 
+/// Returns success if the given two types have compatible shape. That is,
+/// they are both scalars (not shaped), or they are both shaped types and at
+/// least one is unranked or they have compatible dimensions. Dimensions are
+/// compatible if at least one is dynamic or both are equal. The element type
+/// does not matter.
+LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
+  auto sType1 = type1.dyn_cast<ShapedType>();
+  auto sType2 = type2.dyn_cast<ShapedType>();
+
+  // Either both or neither type should be shaped.
+  if (!sType1)
+    return success(!sType2);
+  if (!sType2)
+    return failure();
+
+  if (!sType1.hasRank() || !sType2.hasRank())
+    return success();
+
+  if (sType1.getRank() != sType2.getRank())
+    return failure();
+
+  for (const auto &dims : llvm::zip(sType1.getShape(), sType2.getShape())) {
+    int64_t dim1 = std::get<0>(dims);
+    int64_t dim2 = std::get<1>(dims);
+    if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
+        dim1 != dim2)
+      return failure();
+  }
+  return success();
+}
+
 OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it)
     : llvm::mapped_iterator<OperandIterator, Type (*)(Value *)>(it, &unwrap) {}
 
index d28200b..be44a6b 100644 (file)
@@ -297,14 +297,6 @@ func @func_with_ops(i1, tensor<42xi32>, tensor<?xi32>) {
 
 // -----
 
-func @func_with_ops(tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) {
-^bb0(%cond : tensor<?xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
-  // expected-error@+1 {{requires the same shape for all operands and results}}
-  %r = "std.select"(%cond, %t, %f) : (tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
-}
-
-// -----
-
 func @test_vector.transfer_read(memref<?x?xf32>) {
 ^bb0(%arg0: memref<?x?xf32>):
   %c3 = constant 3 : index
index 926547c..6c7fddb 100644 (file)
@@ -113,10 +113,11 @@ func @failedSameOperandShape_no_operands() {
 // -----
 
 // CHECK: succeededSameOperandAndResultShape
-func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>) {
+func @succeededSameOperandAndResultShape(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>, %t1d: tensor<?xf32>) {
   %0 = "test.same_operand_and_result_shape"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
   %1 = "test.same_operand_and_result_shape"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
   %2 = "test.same_operand_and_result_shape"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> tensor<1xf32>
+  %3 = "test.same_operand_and_result_shape"(%t1, %t1d) : (tensor<1xf32>, tensor<?xf32>) -> tensor<1xf32>
   return
 }
 
@@ -143,6 +144,24 @@ func @failedSameOperandAndResultShape_no_operands(%t1: tensor<1xf32>) {
 
 // -----
 
+// CHECK: succeededSameOperandAndResultType
+func @succeededSameOperandAndResultType(%t10x10 : tensor<10x10xf32>, %t1: tensor<1xf32>, %tr: tensor<*xf32>, %t1d: tensor<?xf32>) {
+  %0 = "test.same_operand_and_result_type"(%t1, %t1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %1 = "test.same_operand_and_result_type"(%t10x10, %t10x10) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
+  %2 = "test.same_operand_and_result_type"(%t1, %tr) : (tensor<1xf32>, tensor<*xf32>) -> tensor<1xf32>
+  %3 = "test.same_operand_and_result_type"(%t1, %t1d) : (tensor<1xf32>, tensor<?xf32>) -> tensor<1xf32>
+  return
+}
+
+// -----
+
+func @failedSameOperandAndResultType_operand_result_mismatch(%t10 : tensor<10xf32>, %t20 : tensor<20xf32>) {
+  // expected-error@+1 {{requires the same type for all operands and results}}
+  %0 = "test.same_operand_and_result_type"(%t10, %t20) : (tensor<10xf32>, tensor<20xf32>) -> tensor<10xf32>
+}
+
+// -----
+
 func @failedHasParent_wrong_parent() {
   "some.op"() ({
    // expected-error@+1 {{'test.child' op expects parent op 'test.parent'}}
index dd620de..68bae70 100644 (file)
@@ -257,6 +257,12 @@ def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
   let results = (outs Variadic<AnyVectorOrTensor>);
 }
 
+def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type",
+    [SameOperandsAndResultType]> {
+  let arguments = (ins Variadic<AnyVectorOrTensor>);
+  let results = (outs Variadic<AnyVectorOrTensor>);
+}
+
 def ArgAndResHaveFixedElementTypesOp :
     TEST_Op<"arg_and_res_have_fixed_element_types",
       [PredOpTrait<"fixed type combination",