/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
/// given types are not broadcast-compatible.
-Type getBroadcastedType(Type type1, Type type2);
+///
+/// If elementType is specified, it will be used as the element type of the
+/// broadcasted result type. Otherwise, type1 and type2 are required to have
+/// the same element type, which will be used as the resultant element type.
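+///
+/// For example, broadcasting tensor<?x2xf32> with tensor<2xf32> yields
+/// tensor<?x2xf32>, while tensor<2xf32> and tensor<3xf32> are not
+/// broadcast-compatible, so a null Type is returned.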
+Type getBroadcastedType(Type type1, Type type2, Type elementType = nullptr);
+
} // namespace util
-/// This class provides the API for ops that are known to have broadcast-
-/// compatible operand and result types. Specifically, starting from the
-/// most varying dimension, each dimension pair of the two operands' types
-/// should either be the same or one of them is one. Also, the result type
-/// should have the corresponding dimension equal to the larger one, if known.
-/// Shapes are checked partially if ranks or dimensions are not known. For
-/// example, an op with tensor<? x 2 x f32> and tensor <2 x f32> as operand
-/// types and tensor<3 x 2 x f32> as the result type is broadcast-compatible.
-///
-/// Ths trait assumes the op has two operands and one result, and it asserts
-/// if the pre-condition is not satisfied.
+/// Trait for ops that are known to have broadcast-compatible operand and
+/// result types. Specifically, starting from the most varying dimension, each
+/// pair of corresponding operand dimensions must either be equal or one of
+/// them must be one. Also, each result's shape should have the corresponding
+/// dimension equal to the larger one, if known. Shapes are checked partially
+/// if ranks or dimensions are not known. For example, an op with
+/// tensor<?x2xf32> and tensor<2xf32> as operand types and tensor<5x3x2xi16>
+/// as the result type has broadcast-compatible operand and result types.
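+///
+/// A minimal usage sketch (MyAddOp is a hypothetical op):
+///
+///   class MyAddOp
+///       : public Op<MyAddOp, OpTrait::ResultsBroadcastableShape> { ... };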
template <typename ConcreteType>
-class BroadcastableTwoOperandsOneResult
- : public TraitBase<ConcreteType, BroadcastableTwoOperandsOneResult> {
+class ResultsBroadcastableShape
+ : public TraitBase<ConcreteType, ResultsBroadcastableShape> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyCompatibleOperandBroadcast(op);
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
/// given types are not broadcast-compatible.
-Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
- // Returns the scalar type out of the given type.
- auto getScalarType = [](Type type) -> Type {
- if (auto shapedType = type.dyn_cast<ShapedType>())
- return shapedType.getElementType();
- return type;
- };
-
- // Make sure underlying scalar type is the same.
- auto scalarType = getScalarType(type1);
- if (scalarType != getScalarType(type2))
- return {};
+///
+/// If elementType is specified, it will be used as the element type of the
+/// broadcasted result type. Otherwise, type1 and type2 are required to have
+/// the same element type, which will be used as the resultant element type.
+Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
+ Type elementType) {
+  // If the elementType is not specified, then use the common element type of
+  // the inputs, or fail if there is no common element type.
+ if (!elementType) {
+ elementType = getElementTypeOrSelf(type1);
+ if (elementType != getElementTypeOrSelf(type2))
+ return {};
+ }
// If one of the types is unranked tensor, then the other type shouldn't be
// vector and the result should have unranked tensor type.
if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
if (type1.isa<VectorType>() || type2.isa<VectorType>())
return {};
- return UnrankedTensorType::get(scalarType);
+ return UnrankedTensorType::get(elementType);
}
// Returns the type kind if the given type is a vector or ranked tensor type.
// Compose the final broadcasted type
if (resultCompositeKind == StandardTypes::Vector)
- return VectorType::get(resultShape, scalarType);
+ return VectorType::get(resultShape, elementType);
if (resultCompositeKind == StandardTypes::RankedTensor)
- return RankedTensorType::get(resultShape, scalarType);
- return scalarType;
+ return RankedTensorType::get(resultShape, elementType);
+ return elementType;
}
-/// Returns true if the given types has both vector types and tensor types.
-static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
- return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
- llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
+/// Returns a tuple of booleans indicating whether the given type range
+/// contains any tensor types and any vector types, respectively.
+template <typename iterator_range>
+static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
+ return std::make_tuple(
+ llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
+ llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
}
static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
return true;
}
-LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
- assert(op->getNumOperands() == 2 &&
- "only support broadcast check on two operands");
- assert(op->getNumResults() == 1 &&
- "only support broadcast check on one result");
-
- auto type1 = op->getOperand(0).getType();
- auto type2 = op->getOperand(1).getType();
- auto retType = op->getResult(0).getType();
+static std::string getShapeString(ArrayRef<int64_t> shape) {
+  // TODO: Replace with a more uniform way of printing shapes, both here and
+  // when printing types.
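+  // E.g., the shape [4, 3, 2] is printed as '4x3x2' (including the quotes).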
+ return formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end()));
+}
- // We forbid broadcasting vector and tensor.
- if (hasBothVectorAndTensorType({type1, type2, retType}))
+LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
+  // Ensure that either only tensor types or only vector types are involved;
+  // broadcasting a vector with a tensor is not allowed.
+ auto operandsHasTensorVectorType =
+ hasTensorOrVectorType(op->getOperandTypes());
+ auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
+ if ((std::get<0>(operandsHasTensorVectorType) ||
+ std::get<0>(resultsHasTensorVectorType)) &&
+ (std::get<1>(operandsHasTensorVectorType) ||
+ std::get<1>(resultsHasTensorVectorType)))
return op->emitError("cannot broadcast vector with tensor");
- if (retType.isa<UnrankedTensorType>())
- return success();
-
- bool isUnranked1 = type1.isa<UnrankedTensorType>();
- bool isUnranked2 = type2.isa<UnrankedTensorType>();
+ auto rankedOperands = make_filter_range(
+ op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
- // If both operands are unranked, then all result shapes are possible.
- if (isUnranked1 && isUnranked2)
+ // If all operands are unranked, then all result shapes are possible.
+ if (rankedOperands.empty())
return success();
- // If one of the operands is unranked, then the known dimensions in the result
- // should be compatible with the other shaped operand.
- if (isUnranked1 || isUnranked2) {
- // Result should have higher rank than the shaped operand's rank and then
- // the result's trailing dimensions should be compatible with the operand
- // shape.
- ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2);
- ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size());
- if (!areCompatibleShapes(actualSuffix, shape))
- return op->emitOpError()
- << "result type " << retType
- << " has shape incompatible with a ranked operand type";
- return success();
+  // Compute the broadcasted shape of the ranked operands (which requires that
+  // the operands are broadcast compatible). The result shapes then need to be
+  // broadcast compatible with this shape.
+ SmallVector<int64_t, 4> resultShape;
+ (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
+ resultShape);
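+  // Note: broadcasting with an empty shape above simply copies the first
+  // ranked shape into resultShape; re-broadcasting that operand in the loop
+  // below is then a harmless no-op.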
+ for (auto other : make_early_inc_range(rankedOperands)) {
+ SmallVector<int64_t, 4> temp = resultShape;
+ if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
+ return op->emitOpError("operands don't have broadcast-compatible shapes");
}
- // If both operands are shaped, then the computed broadcasted shape should be
- // compatible with the result shape.
- SmallVector<int64_t, 4> resultShape;
- if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
- return op->emitOpError("operands don't have broadcast-compatible shapes");
+ auto rankedResults = make_filter_range(
+ op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
- if (!areCompatibleShapes(resultShape, getShape(retType)))
- return op->emitOpError() << "result type " << retType
- << " does not have shape compatible with the one "
- "computed from the operand types";
+  // If all of the results are unranked, then no further verification is
+  // needed.
+ if (rankedResults.empty())
+ return success();
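+  // A ranked result may have higher rank than the broadcasted operand shape;
+  // only its trailing dimensions need to be compatible with that shape.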
+ for (auto type : rankedResults) {
+ ArrayRef<int64_t> actualSuffix =
+ getShape(type).take_back(resultShape.size());
+ if (!areCompatibleShapes(actualSuffix, resultShape))
+ return op->emitOpError()
+ << "result type " << getShapeString(getShape(type))
+ << " not broadcast compatible with broadcasted operands's shapes "
+ << getShapeString(resultShape);
+ }
return success();
}
// Check incompatible result type with known dimension
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32> {
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x1xi32>):
- // expected-error @+1 {{does not have shape compatible with the one computed}}
+  // expected-error @+1 {{op result type '4x3x3' not broadcast compatible with broadcasted operands' shapes '4x3x2'}}
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32>
return %0 : tensor<4x3x3xi32>
}
// Check incompatible result type with known dimension
func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32> {
^bb0(%arg0: tensor<8x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
- // expected-error @+1 {{does not have shape compatible with the one computed}}
+  // expected-error @+1 {{op result type '8x7x6x1' not broadcast compatible with broadcasted operands' shapes '8x7x6x5'}}
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32>
return %0 : tensor<8x7x6x1xi32>
}
// Unranked operand and compatible ranked result
func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32> {
^bb0(%arg0: tensor<3x2xi32>, %arg1: tensor<*xi32>):
- %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
+ %0 = "test.broadcastable"(%arg0, %arg0, %arg1) : (tensor<3x2xi32>, tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
return %0 : tensor<4x3x2xi32>
}
func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32> {
^bb0(%arg0: tensor<3x2xi32>, %arg1: tensor<*xi32>):
- // expected-error @+1 {{shape incompatible with a ranked operand type}}
+  // expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands' shapes '3x2'}}
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}