Add verifyCompatibleShape function overload with shapes
authorSmit Hinsu <hinsu@google.com>
Sat, 14 Dec 2019 19:18:01 +0000 (11:18 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Sat, 14 Dec 2019 19:18:38 +0000 (11:18 -0800)
PiperOrigin-RevId: 285574334

mlir/include/mlir/IR/TypeUtilities.h
mlir/lib/IR/TypeUtilities.cpp

index 6512f8f..c1d1095 100644 (file)
@@ -52,6 +52,12 @@ SmallVector<Type, 10> getFlattenedTypes(TupleType t);
 /// dialect and typeData.
 bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData);
 
+/// Returns success if the given two shapes are compatible. That is, they have
+/// the same size and each pair of the elements are equal or one of them is
+/// dynamic.
+LogicalResult verifyCompatibleShape(ArrayRef<int64_t> shape1,
+                                    ArrayRef<int64_t> shape2);
+
 /// Returns success if the given two types have compatible shape. That is,
 /// they are both scalars (not shaped), or they are both shaped types and at
 /// least one is unranked or they have compatible dimensions. Dimensions are
index 0172141..54b1bf6 100644 (file)
@@ -61,6 +61,23 @@ bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect,
   return false;
 }
 
+/// Returns success if the given two shapes are compatible. That is, they have
+/// the same size and each pair of the elements are equal or one of them is
+/// dynamic.
+LogicalResult mlir::verifyCompatibleShape(ArrayRef<int64_t> shape1,
+                                          ArrayRef<int64_t> shape2) {
+  if (shape1.size() != shape2.size())
+    return failure();
+  for (const auto &dims : llvm::zip(shape1, shape2)) {
+    int64_t dim1 = std::get<0>(dims);
+    int64_t dim2 = std::get<1>(dims);
+    if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
+        dim1 != dim2)
+      return failure();
+  }
+  return success();
+}
+
 /// Returns success if the given two types have compatible shape. That is,
 /// they are both scalars (not shaped), or they are both shaped types and at
 /// least one is unranked or they have compatible dimensions. Dimensions are
@@ -79,17 +96,7 @@ LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) {
   if (!sType1.hasRank() || !sType2.hasRank())
     return success();
 
-  if (sType1.getRank() != sType2.getRank())
-    return failure();
-
-  for (const auto &dims : llvm::zip(sType1.getShape(), sType2.getShape())) {
-    int64_t dim1 = std::get<0>(dims);
-    int64_t dim2 = std::get<1>(dims);
-    if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) &&
-        dim1 != dim2)
-      return failure();
-  }
-  return success();
+  return verifyCompatibleShape(sType1.getShape(), sType2.getShape());
 }
 
 OperandElementTypeIterator::OperandElementTypeIterator(