From 2d22b1e04e7f75ecc38247fc2a5cd18058374cc0 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Sat, 14 Dec 2019 11:18:01 -0800 Subject: [PATCH] Add verifyCompatibleShape function overload with shapes PiperOrigin-RevId: 285574334 --- mlir/include/mlir/IR/TypeUtilities.h | 6 ++++++ mlir/lib/IR/TypeUtilities.cpp | 29 ++++++++++++++++++----------- 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h index 6512f8f..c1d1095 100644 --- a/mlir/include/mlir/IR/TypeUtilities.h +++ b/mlir/include/mlir/IR/TypeUtilities.h @@ -52,6 +52,12 @@ SmallVector getFlattenedTypes(TupleType t); /// dialect and typeData. bool isOpaqueTypeWithName(Type type, StringRef dialect, StringRef typeData); +/// Returns success if the given two shapes are compatible. That is, they have +/// the same size and each pair of the elements are equal or one of them is +/// dynamic. +LogicalResult verifyCompatibleShape(ArrayRef shape1, + ArrayRef shape2); + /// Returns success if the given two types have compatible shape. That is, /// they are both scalars (not shaped), or they are both shaped types and at /// least one is unranked or they have compatible dimensions. Dimensions are diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp index 0172141..54b1bf6 100644 --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -61,6 +61,23 @@ bool mlir::isOpaqueTypeWithName(Type type, StringRef dialect, return false; } +/// Returns success if the given two shapes are compatible. That is, they have +/// the same size and each pair of the elements are equal or one of them is +/// dynamic. +LogicalResult mlir::verifyCompatibleShape(ArrayRef shape1, + ArrayRef shape2) { + if (shape1.size() != shape2.size()) + return failure(); + for (const auto &dims : llvm::zip(shape1, shape2)) { + int64_t dim1 = std::get<0>(dims); + int64_t dim2 = std::get<1>(dims); + if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) && + dim1 != dim2) + return failure(); + } + return success(); +} + /// Returns success if the given two types have compatible shape. That is, /// they are both scalars (not shaped), or they are both shaped types and at /// least one is unranked or they have compatible dimensions. Dimensions are @@ -79,17 +96,7 @@ LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { if (!sType1.hasRank() || !sType2.hasRank()) return success(); - if (sType1.getRank() != sType2.getRank()) - return failure(); - - for (const auto &dims : llvm::zip(sType1.getShape(), sType2.getShape())) { - int64_t dim1 = std::get<0>(dims); - int64_t dim2 = std::get<1>(dims); - if (!ShapedType::isDynamic(dim1) && !ShapedType::isDynamic(dim2) && - dim1 != dim2) - return failure(); - } - return success(); + return verifyCompatibleShape(sType1.getShape(), sType2.getShape()); } OperandElementTypeIterator::OperandElementTypeIterator( -- 2.7.4