using namespace mlir;
-/// Returns true if the given `type` supports NumPy broadcast semantics.
-/// Specifically, the given `type` must be integer type, floating point type,
-/// vector type, or ranked tensor type from integer or floating point types.
-static bool isBroadcastableType(Type type) {
- switch (type.getKind()) {
- case StandardTypes::BF16:
- case StandardTypes::F16:
- case StandardTypes::F32:
- case StandardTypes::F64:
- case StandardTypes::Integer:
- case StandardTypes::Vector:
- return true;
- case StandardTypes::RankedTensor:
- case StandardTypes::UnrankedTensor:
- return type.cast<TensorType>().getElementType().isIntOrFloat();
- default:
- break;
- }
- return false;
-}
-
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
ArrayRef<int64_t> shape2,
SmallVectorImpl<int64_t> &resultShape) {
return true;
}
+/// Returns the shape of the given type. Scalars will be considered as having a
+/// shape with zero dimensions.
+static ArrayRef<int64_t> getShape(Type type) {
+ if (auto vtType = type.dyn_cast<VectorOrTensorType>())
+ return vtType.getShape();
+ return {};
+}
+
/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
/// given types are not broadcast-compatible.
Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
- // Make sure both types are able to participate in broadcasting.
- if (!isBroadcastableType(type1) || !isBroadcastableType(type2))
- return {};
-
// Returns the scalar type out of the given type.
auto getScalarType = [](Type type) -> Type {
if (auto vtType = type.dyn_cast<VectorOrTensorType>())
resultCompositeKind = compositeKind2;
}
- // Returns the shape of the given type.
- auto getShape = [](Type type) -> ArrayRef<int64_t> {
- if (auto vtType = type.dyn_cast<VectorOrTensorType>())
- return vtType.getShape();
- return {};
- };
-
// Get the shape of each type.
SmallVector<int64_t, 4> resultShape;
if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
return scalarType;
}
-/// Returns true if the two given types are both vectors or ranked tensors and
-/// they have the same shape, regardless of element types.
-static bool isSameShapedVectorOrTensor(Type type1, Type type2) {
- if (auto vType1 = type1.dyn_cast<RankedTensorType>())
- if (auto vType2 = type2.dyn_cast<RankedTensorType>())
- return vType1.getShape() == vType2.getShape();
- if (auto vType1 = type1.dyn_cast<VectorType>())
- if (auto vType2 = type2.dyn_cast<VectorType>())
- return vType1.getShape() == vType2.getShape();
- return false;
+/// Returns true if the given types has both vector types and tensor types.
+static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
+ return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
+ llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
}
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
auto type2 = op->getOperand(1)->getType();
auto retType = op->getResult(0)->getType();
- auto broadcastedType = util::getBroadcastedType(type1, type2);
+ // We forbid broadcasting vector and tensor.
+ if (hasBothVectorAndTensorType({type1, type2, retType}))
+ return op->emitError("cannot broadcast vector with tensor");
- if (!broadcastedType)
- return op->emitOpError("operands don't have broadcast-compatible types");
+ // Broadcasting unranked tensor with ranked/unranked tensor is allowed but
+ // the result should be unranked tensor.
+ if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
+ if (!retType.isa<UnrankedTensorType>())
+ return op->emitError(
+ "broadcast unranked tensor should result in unranked tensor");
+ return success();
+ }
- bool hasCompatRetType = (retType == broadcastedType) ||
- retType.isa<UnrankedTensorType>() ||
- isSameShapedVectorOrTensor(retType, broadcastedType);
- if (!hasCompatRetType)
- return op->emitOpError()
- << "result type '" << retType
- << "' does not have the same shape as the broadcasted type '"
- << broadcastedType << "' computed from the operand types";
+ SmallVector<int64_t, 4> resultShape;
+ if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
+ return op->emitOpError("operands don't have broadcast-compatible shapes");
+
+ if (!retType.isa<UnrankedTensorType>() &&
+ llvm::makeArrayRef(resultShape) != getShape(retType))
+ return op->emitOpError() << "result type '" << retType
+ << "' does not have the same shape as the one "
+ "computed from the operand types";
return success();
}