return failure();
}
}
-
-// For shapes that were created by some operations, we can obtain partial
-// information on the shapes and sometimes determine if they will be
-// broadcastable with that.
-struct CstrBroadcastablePartialInfo
- : public OpRewritePattern<CstrBroadcastableOp> {
- using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(CstrBroadcastableOp op,
- PatternRewriter &rewriter) const override {
- SmallVector<int64_t, 6> lhsShape, rhsShape;
- if (failed(getShapeVec(op.lhs(), lhsShape)))
- return failure();
- if (failed(getShapeVec(op.rhs(), rhsShape)))
- return failure();
- if (!OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape))
- return failure();
-
- rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
- return success();
- }
-};
-
-// Scalars are always broadcastable.
-struct CstrBroadcastableScalar : public OpRewritePattern<CstrBroadcastableOp> {
- using OpRewritePattern<CstrBroadcastableOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(CstrBroadcastableOp op,
- PatternRewriter &rewriter) const override {
- SmallVector<int64_t, 6> shape;
- if (failed(getShapeVec(op.lhs(), shape)) || shape.size() > 0)
- return failure();
- if (failed(getShapeVec(op.rhs(), shape)) || shape.size() > 0)
- return failure();
-
- rewriter.replaceOpWithNewOp<ConstWitnessOp>(op.getOperation(), true);
- return success();
- }
-};
-
} // namespace
void CstrBroadcastableOp::getCanonicalizationPatterns(
// Canonicalization patterns have overlap with the considerations during
// folding in case additional shape information is inferred at some point that
// does not result in folding.
- patterns.insert<CstrBroadcastableEqOps, CstrBroadcastablePartialInfo,
- CstrBroadcastableScalar>(context);
+ patterns.insert<CstrBroadcastableEqOps>(context);
}
OpFoldResult CstrBroadcastableOp::fold(ArrayRef<Attribute> operands) {