From: Lei Zhang
Date: Thu, 9 May 2019 20:35:43 +0000 (-0700)
Subject: Only forbid mixing tensor and vector when considering broadcasting behavior
X-Git-Tag: llvmorg-11-init~1466^2~1784
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b0be00c74682eeac83e3de92411319a3734339a7;p=platform%2Fupstream%2Fllvm.git

Only forbid mixing tensor and vector when considering broadcasting behavior

The previous approach is too restrictive; we end up forbidding all
dialect-specific types as element types. Changed to not consider element types
entirely.

--
PiperOrigin-RevId: 247486537
---

diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index 8988dcd..6762a60 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -21,27 +21,6 @@
 
 using namespace mlir;
 
-/// Returns true if the given `type` supports NumPy broadcast semantics.
-/// Specifically, the given `type` must be integer type, floating point type,
-/// vector type, or ranked tensor type from integer or floating point types.
-static bool isBroadcastableType(Type type) {
-  switch (type.getKind()) {
-  case StandardTypes::BF16:
-  case StandardTypes::F16:
-  case StandardTypes::F32:
-  case StandardTypes::F64:
-  case StandardTypes::Integer:
-  case StandardTypes::Vector:
-    return true;
-  case StandardTypes::RankedTensor:
-  case StandardTypes::UnrankedTensor:
-    return type.cast<VectorOrTensorType>().getElementType().isIntOrFloat();
-  default:
-    break;
-  }
-  return false;
-}
-
 bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                         ArrayRef<int64_t> shape2,
                                         SmallVectorImpl<int64_t> &resultShape) {
@@ -98,15 +77,19 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
   return true;
 }
 
+/// Returns the shape of the given type. Scalars will be considered as having a
+/// shape with zero dimensions.
+static ArrayRef<int64_t> getShape(Type type) {
+  if (auto vtType = type.dyn_cast<VectorOrTensorType>())
+    return vtType.getShape();
+  return {};
+}
+
 /// Returns the result broadcast composition type from the two given types by
 /// following NumPy broadcast semantics. Returned type may have dynamic shape if
 /// either of the input types has dynamic shape. Returns null type if the two
 /// given types are not broadcast-compatible.
 Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
-  // Make sure both types are able to participate in broadcasting.
-  if (!isBroadcastableType(type1) || !isBroadcastableType(type2))
-    return {};
-
   // Returns the scalar type out of the given type.
   auto getScalarType = [](Type type) -> Type {
     if (auto vtType = type.dyn_cast<VectorOrTensorType>())
@@ -152,13 +135,6 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
     resultCompositeKind = compositeKind2;
   }
 
-  // Returns the shape of the given type.
-  auto getShape = [](Type type) -> ArrayRef<int64_t> {
-    if (auto vtType = type.dyn_cast<VectorOrTensorType>())
-      return vtType.getShape();
-    return {};
-  };
-
   // Get the shape of each type.
   SmallVector<int64_t, 4> resultShape;
   if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
@@ -172,16 +148,10 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
   return scalarType;
 }
 
-/// Returns true if the two given types are both vectors or ranked tensors and
-/// they have the same shape, regardless of element types.
-static bool isSameShapedVectorOrTensor(Type type1, Type type2) {
-  if (auto vType1 = type1.dyn_cast<VectorType>())
-    if (auto vType2 = type2.dyn_cast<VectorType>())
-      return vType1.getShape() == vType2.getShape();
-  if (auto vType1 = type1.dyn_cast<RankedTensorType>())
-    if (auto vType2 = type2.dyn_cast<RankedTensorType>())
-      return vType1.getShape() == vType2.getShape();
-  return false;
+/// Returns true if the given types has both vector types and tensor types.
+static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
+  return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
+         llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
 }
 
 LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
@@ -194,19 +164,28 @@ LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
   auto type2 = op->getOperand(1)->getType();
   auto retType = op->getResult(0)->getType();
 
-  auto broadcastedType = util::getBroadcastedType(type1, type2);
+  // We forbid broadcasting vector and tensor.
+  if (hasBothVectorAndTensorType({type1, type2, retType}))
+    return op->emitError("cannot broadcast vector with tensor");
 
-  if (!broadcastedType)
-    return op->emitOpError("operands don't have broadcast-compatible types");
+  // Broadcasting unranked tensor with ranked/unranked tensor is allowed but
+  // the result should be unranked tensor.
+  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
+    if (!retType.isa<UnrankedTensorType>())
+      return op->emitError(
+          "broadcast unranked tensor should result in unranked tensor");
+    return success();
+  }
 
-  bool hasCompatRetType = (retType == broadcastedType) ||
-                          retType.isa<UnrankedTensorType>() ||
-                          isSameShapedVectorOrTensor(retType, broadcastedType);
-  if (!hasCompatRetType)
-    return op->emitOpError()
-           << "result type '" << retType
-           << "' does not have the same shape as the broadcasted type '"
-           << broadcastedType << "' computed from the operand types";
+  SmallVector<int64_t, 4> resultShape;
+  if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
+    return op->emitOpError("operands don't have broadcast-compatible shapes");
+
+  if (!retType.isa<UnrankedTensorType>() &&
+      llvm::makeArrayRef(resultShape) != getShape(retType))
+    return op->emitOpError() << "result type '" << retType
+                             << "' does not have the same shape as the one "
+                                "computed from the operand types";
 
   return success();
 }