bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
SmallVectorImpl<int64_t> &resultShape);
+/// Returns true if a broadcast between the 2 shapes is guaranteed to be
+/// successful and not result in an error. False does not guarantee that the
+/// shapes are not broadcastable; it might guarantee that they are not
+/// broadcastable or it might mean that this function does not have enough
+/// information to know.
+///
+/// Conceptually, this returns true if getBroadcastedShape would have returned
+/// true and vice versa, with one exception. If a dimension is unknown in both
+/// shapes, getBroadcastedShape would return true and have a result with unknown
+/// dimension, while this function will return false because it's possible for
+/// both shapes to have a dimension greater than 1 and different which would
+/// fail to broadcast.
+bool staticallyKnownBroadcastable(ArrayRef<int64_t> shape1,
+ ArrayRef<int64_t> shape2);
+
/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
// CstrBroadcastableOp
//===----------------------------------------------------------------------===//
+namespace {
+// Given an input shape Value, try to obtain the shape's values.
+LogicalResult getShapeVec(Value input, SmallVectorImpl<int64_t> &shapeValues) {
+ if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
+ auto type = inputOp.arg().getType().dyn_cast<ShapedType>();
+ if (!type.hasRank())
+ return failure();
+ shapeValues = llvm::to_vector<6>(type.getShape());
+ return success();
+ } else if (auto inputOp = input.getDefiningOp<ConstShapeOp>()) {
+ shapeValues = llvm::to_vector<6>(inputOp.shape().getValues<int64_t>());
+ return success();
+ } else {
+ return failure();
+ }
+}
+
+// For shapes that were created by some operations, we can obtain partial
+// information on the shapes and sometimes determine if they will be
+// broadcastable with that.
+struct CstrBroadcastablePartialInfo
+ : public OpRewritePattern<CstrBroadcastableOp> {
+ using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CstrBroadcastableOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<int64_t, 6> lhsShape, rhsShape;
+ if (failed(getShapeVec(op.lhs(), lhsShape)))
+ return failure();
+ if (failed(getShapeVec(op.rhs(), rhsShape)))
+ return failure();
+ if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
+ return failure();
+
+ rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
+ return success();
+ }
+};
+
+// Scalars are always broadcastable.
+struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
+ using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CstrBroadcastableOp op,
+ PatternRewriter &rewriter) const override {
+ SmallVector<int64_t, 6> shape;
+ if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
+ return failure();
+ if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
+ return failure();
+
+ rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
+ return success();
+ }
+};
+
+} // namespace
+
void CstrBroadcastableOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
- // If inputs are equal, return passing witness
- patterns.insert<CstrBroadcastableEqOps>(context);
+ // Canonicalization patterns have overlap with the considerations during
+ // folding in case additional shape information is inferred at some point that
+ // does not result in folding.
+ patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
+ CstrBroadcastableScalar>(context);
}
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {
- if (!operands[0] || !operands[1])
+ // Both operands are not needed if one is a scalar.
+ if (operands[0] &&
+ operands[0].cast<DenseIntElementsAttr>().getNumElements() == 0)
+ return BoolAttr::get(true, getContext());
+ if (operands[1] &&
+ operands[1].cast<DenseIntElementsAttr>().getNumElements() == 0)
+ return BoolAttr::get(true, getContext());
+
+ if (operands[0] && operands[1]) {
+ auto lhsShape = llvm::to_vector<6>(
+ operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ auto rhsShape = llvm::to_vector<6>(
+ operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
+ SmallVector<int64_t, 6> resultShape;
+ if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
+ return BoolAttr::get(true, getContext());
+ }
+
+ // Lastly, see if folding can be completed based on what constraints are known
+ // on the input shapes.
+ SmallVector<int64_t, 6> lhsShape, rhsShape;
+ if (failed(getShapeVec(lhs(), lhsShape)))
return nullptr;
- auto lhsShape = llvm::to_vector<6>(
- operands[0].cast<DenseIntElementsAttr>().getValues<int64_t>());
- auto rhsShape = llvm::to_vector<6>(
- operands[1].cast<DenseIntElementsAttr>().getValues<int64_t>());
- SmallVector<int64_t, 6> resultShape;
- if (OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
+ if (failed(getShapeVec(rhs(), rhsShape)))
+ return nullptr;
+
+ if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
return BoolAttr::get(true, getContext());
// Because a failing witness result here represents an eventual assertion
// -----
// Broadcastable with non-broadcastable constant shapes is always false
-// CHECK-LABEL: func @f
-func @f() {
+// CHECK-LABEL: func @static_non_broadcastable
+func @static_non_broadcastable() {
// CHECK-NEXT: shape.const_shape
// CHECK-NEXT: shape.const_shape
// CHECK-NEXT: shape.cstr_broadcastable
return %result : !shape.size
}
+// -----
+
+// Canonicalize scalar cstr_broadcastable checks
+// CHECK-LABEL: @cstr_broadcastable_scalar
+func @cstr_broadcastable_scalar(%arg0 : tensor<?xf32>) {
+ // CHECK-NEXT: shape.const_witness true
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %0 = shape.const_shape []
+ %1 = shape.shape_of %arg0 : tensor<?xf32>
+ %2 = shape.cstr_broadcastable %0, %1
+ "consume.witness"(%2) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+
+// Do not canonicalize cstr_broadcastable checks with 2 unknowns
+// CHECK-LABEL: @cstr_broadcastable_unknown
+func @cstr_broadcastable_unknown(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) {
+ // CHECK-NEXT: shape.shape_of %arg0
+ // CHECK-NEXT: shape.shape_of %arg1
+ // CHECK-NEXT: shape.cstr_broadcastable
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %0 = shape.shape_of %arg0 : tensor<?xf32>
+ %1 = shape.shape_of %arg1 : tensor<?xf32>
+ %2 = shape.cstr_broadcastable %0, %1
+ "consume.witness"(%2) : (!shape.witness) -> ()
+ return
+}
+
+// -----
+
+// Scalars are safe to broadcast to unranked sizes.
+// CHECK-LABEL: @cstr_broadcastable_scalar_unranked
+func @cstr_broadcastable_scalar_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<index>) {
+ // CHECK-NEXT: shape.const_witness true
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %0 = shape.shape_of %arg1 : tensor<index>
+ %1 = shape.shape_of %arg0 : tensor<*xf32>
+ %2 = shape.cstr_broadcastable %0, %1
+ "consume.witness"(%2) : (!shape.witness) -> ()
+ return
+}