Only forbid mixing tensor and vector when considering broadcasting behavior
authorLei Zhang <antiagainst@google.com>
Thu, 9 May 2019 20:35:43 +0000 (13:35 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sat, 11 May 2019 02:26:43 +0000 (19:26 -0700)
    The previous approach is too restrictive; we end up forbidding all dialect-specific
    types as element types. Changed to not consider element types entirely.

--

PiperOrigin-RevId: 247486537

mlir/lib/Dialect/Traits.cpp

index 8988dcd..6762a60 100644 (file)
 
 using namespace mlir;
 
-/// Returns true if the given `type` supports NumPy broadcast semantics.
-/// Specifically, the given `type` must be integer type, floating point type,
-/// vector type, or ranked tensor type from integer or floating point types.
-static bool isBroadcastableType(Type type) {
-  switch (type.getKind()) {
-  case StandardTypes::BF16:
-  case StandardTypes::F16:
-  case StandardTypes::F32:
-  case StandardTypes::F64:
-  case StandardTypes::Integer:
-  case StandardTypes::Vector:
-    return true;
-  case StandardTypes::RankedTensor:
-  case StandardTypes::UnrankedTensor:
-    return type.cast<TensorType>().getElementType().isIntOrFloat();
-  default:
-    break;
-  }
-  return false;
-}
-
 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                         ArrayRef<int64_t> shape2,
                                         SmallVectorImpl<int64_t> &resultShape) {
@@ -98,15 +77,19 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
   return true;
 }
 
+/// Returns the shape of the given type. Scalars will be considered as having a
+/// shape with zero dimensions.
+static ArrayRef<int64_t> getShape(Type type) {
+  if (auto vtType = type.dyn_cast<VectorOrTensorType>())
+    return vtType.getShape();
+  return {};
+}
+
 /// Returns the result broadcast composition type from the two given types by
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
 /// given types are not broadcast-compatible.
 Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
-  // Make sure both types are able to participate in broadcasting.
-  if (!isBroadcastableType(type1) || !isBroadcastableType(type2))
-    return {};
-
   // Returns the scalar type out of the given type.
   auto getScalarType = [](Type type) -> Type {
     if (auto vtType = type.dyn_cast<VectorOrTensorType>())
@@ -152,13 +135,6 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
     resultCompositeKind = compositeKind2;
   }
 
-  // Returns the shape of the given type.
-  auto getShape = [](Type type) -> ArrayRef<int64_t> {
-    if (auto vtType = type.dyn_cast<VectorOrTensorType>())
-      return vtType.getShape();
-    return {};
-  };
-
   // Get the shape of each type.
   SmallVector<int64_t, 4> resultShape;
   if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
@@ -172,16 +148,10 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
   return scalarType;
 }
 
-/// Returns true if the two given types are both vectors or ranked tensors and
-/// they have the same shape, regardless of element types.
-static bool isSameShapedVectorOrTensor(Type type1, Type type2) {
-  if (auto vType1 = type1.dyn_cast<RankedTensorType>())
-    if (auto vType2 = type2.dyn_cast<RankedTensorType>())
-      return vType1.getShape() == vType2.getShape();
-  if (auto vType1 = type1.dyn_cast<VectorType>())
-    if (auto vType2 = type2.dyn_cast<VectorType>())
-      return vType1.getShape() == vType2.getShape();
-  return false;
+/// Returns true if the given types has both vector types and tensor types.
+static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
+  return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
+         llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
 }
 
 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
@@ -194,19 +164,28 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
   auto type2 = op->getOperand(1)->getType();
   auto retType = op->getResult(0)->getType();
 
-  auto broadcastedType = util::getBroadcastedType(type1, type2);
+  // We forbid broadcasting vector and tensor.
+  if (hasBothVectorAndTensorType({type1, type2, retType}))
+    return op->emitError("cannot broadcast vector with tensor");
 
-  if (!broadcastedType)
-    return op->emitOpError("operands don't have broadcast-compatible types");
+  // Broadcasting unranked tensor with ranked/unranked tensor is allowed but
+  // the result should be unranked tensor.
+  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
+    if (!retType.isa<UnrankedTensorType>())
+      return op->emitError(
+          "broadcast unranked tensor should result in unranked tensor");
+    return success();
+  }
 
-  bool hasCompatRetType = (retType == broadcastedType) ||
-                          retType.isa<UnrankedTensorType>() ||
-                          isSameShapedVectorOrTensor(retType, broadcastedType);
-  if (!hasCompatRetType)
-    return op->emitOpError()
-           << "result type '" << retType
-           << "' does not have the same shape as the broadcasted type '"
-           << broadcastedType << "' computed from the operand types";
+  SmallVector<int64_t, 4> resultShape;
+  if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
+    return op->emitOpError("operands don't have broadcast-compatible shapes");
+
+  if (!retType.isa<UnrankedTensorType>() &&
+      llvm::makeArrayRef(resultShape) != getShape(retType))
+    return op->emitOpError() << "result type '" << retType
+                             << "' does not have the same shape as the one "
+                                "computed from the operand types";
 
   return success();
 }