return getKindForOp(combinerOps[0]);
}
-/// If `value` of assumed VectorType has a shape different than `shape`, try to
-/// build and return a new vector.broadcast to `shape`.
-/// Otherwise, just return `value`.
-// TODO: this is best effort atm and there is currently no guarantee of
-// correctness for the broadcast semantics.
+/// Broadcast `value` to a vector of `shape` if possible. Return value
+/// otherwise.
static Value broadcastIfNeeded(OpBuilder &b, Value value,
ArrayRef<int64_t> shape) {
- unsigned numDimsGtOne = std::count_if(shape.begin(), shape.end(),
- [](int64_t val) { return val > 1; });
- auto vecType = value.getType().dyn_cast<VectorType>();
- if (shape.empty() ||
- (vecType != nullptr &&
- (vecType.getShape() == shape || vecType.getRank() > numDimsGtOne)))
+ // If no shape to broadcast to, just return `value`.
+ if (shape.empty())
+ return value;
+ VectorType targetVectorType =
+ VectorType::get(shape, getElementTypeOrSelf(value));
+ if (vector::isBroadcastableTo(value.getType(), targetVectorType) !=
+ vector::BroadcastableToResult::Success)
return value;
- auto newVecType = VectorType::get(shape, vecType ? vecType.getElementType()
- : value.getType());
- return b.create<vector::BroadcastOp>(b.getInsertionPoint()->getLoc(),
- newVecType, value);
+ Location loc = b.getInsertionPoint()->getLoc();
+ return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}
/// If value of assumed VectorType has a shape different than `shape`, build and
// by TransferReadOp, but TransferReadOp supports only constant padding.
auto padValue = padOp.getConstantPaddingValue();
if (!padValue) {
- if (!sourceType.hasStaticShape()) return failure();
+ if (!sourceType.hasStaticShape())
+ return failure();
// Create dummy padding value.
auto elemType = sourceType.getElementType();
padValue = rewriter.create<ConstantOp>(padOp.getLoc(), elemType,
// If `dest` is a FillOp and the TransferWriteOp would overwrite the entire
// tensor, write directly to the FillOp's operand.
- if (llvm::equal(vecShape, resultType.getShape())
- && llvm::all_of(writeInBounds, [](bool b) { return b; }))
+ if (llvm::equal(vecShape, resultType.getShape()) &&
+ llvm::all_of(writeInBounds, [](bool b) { return b; }))
if (auto fill = dest.getDefiningOp<FillOp>())
dest = fill.output();
// Generate TransferWriteOp.
- auto writeIndices = ofrToIndexValues(
- rewriter, padOp.getLoc(), padOp.getMixedLowPad());
+ auto writeIndices =
+ ofrToIndexValues(rewriter, padOp.getLoc(), padOp.getMixedLowPad());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
padOp, read, dest, writeIndices, writeInBounds);
return success(changed);
}
- protected:
- virtual LogicalResult rewriteUser(
- PatternRewriter &rewriter, PadTensorOp padOp, OpTy op) const = 0;
+protected:
+ virtual LogicalResult rewriteUser(PatternRewriter &rewriter,
+ PadTensorOp padOp, OpTy op) const = 0;
};
/// Rewrite use of PadTensorOp result in TransferReadOp. E.g.:
/// - Single, scalar padding value.
struct PadTensorOpVectorizationWithTransferReadPattern
: public VectorizePadTensorOpUserPattern<vector::TransferReadOp> {
- using VectorizePadTensorOpUserPattern<vector::TransferReadOp>
- ::VectorizePadTensorOpUserPattern;
+ using VectorizePadTensorOpUserPattern<
+ vector::TransferReadOp>::VectorizePadTensorOpUserPattern;
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
vector::TransferReadOp xferOp) const override {
// Low padding must be static 0.
- if (!padOp.hasZeroLowPad()) return failure();
+ if (!padOp.hasZeroLowPad())
+ return failure();
// Pad value must be a constant.
auto padValue = padOp.getConstantPaddingValue();
- if (!padValue) return failure();
+ if (!padValue)
+ return failure();
// Padding value of existing `xferOp` is unused.
- if (xferOp.hasOutOfBoundsDim() || xferOp.mask()) return failure();
+ if (xferOp.hasOutOfBoundsDim() || xferOp.mask())
+ return failure();
rewriter.updateRootInPlace(xferOp, [&]() {
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
/// - Single, scalar padding value.
struct PadTensorOpVectorizationWithTransferWritePattern
: public VectorizePadTensorOpUserPattern<vector::TransferWriteOp> {
- using VectorizePadTensorOpUserPattern<vector::TransferWriteOp>
- ::VectorizePadTensorOpUserPattern;
+ using VectorizePadTensorOpUserPattern<
+ vector::TransferWriteOp>::VectorizePadTensorOpUserPattern;
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
vector::TransferWriteOp xferOp) const override {
// Low padding must be static 0.
- if (!padOp.hasZeroLowPad()) return failure();
+ if (!padOp.hasZeroLowPad())
+ return failure();
// Pad value must be a constant.
auto padValue = padOp.getConstantPaddingValue();
- if (!padValue) return failure();
+ if (!padValue)
+ return failure();
// TransferWriteOp result must be directly consumed by an ExtractSliceOp.
- if (!xferOp->hasOneUse()) return failure();
+ if (!xferOp->hasOneUse())
+ return failure();
auto trimPadding = dyn_cast<tensor::ExtractSliceOp>(*xferOp->user_begin());
- if (!trimPadding) return failure();
+ if (!trimPadding)
+ return failure();
// Only static zero offsets supported when trimming padding.
- if (!trimPadding.hasZeroOffset()) return failure();
+ if (!trimPadding.hasZeroOffset())
+ return failure();
// trimPadding must remove the amount of padding that was added earlier.
- if (!hasSameTensorSize(padOp.source(), trimPadding)) return failure();
+ if (!hasSameTensorSize(padOp.source(), trimPadding))
+ return failure();
// Insert the new TransferWriteOp at position of the old TransferWriteOp.
rewriter.setInsertionPoint(xferOp);
// If the input to PadTensorOp is a CastOp, try with with both CastOp result
// and CastOp operand.
if (auto castOp = beforePadding.getDefiningOp<tensor::CastOp>())
- if (hasSameTensorSize(castOp.source(), afterTrimming)) return true;
+ if (hasSameTensorSize(castOp.source(), afterTrimming))
+ return true;
auto t1 = beforePadding.getType().dyn_cast<RankedTensorType>();
auto t2 = afterTrimming.getType().dyn_cast<RankedTensorType>();
// Only RankedTensorType supported.
- if (!t1 || !t2) return false;
+ if (!t1 || !t2)
+ return false;
// Rank of both values must be the same.
- if (t1.getRank() != t2.getRank()) return false;
+ if (t1.getRank() != t2.getRank())
+ return false;
// All static dimensions must be the same. Mixed cases (e.g., dimension
// static in `t1` but dynamic in `t2`) are not supported.
}
// Nothing more to check if all dimensions are static.
- if (t1.getNumDynamicDims() == 0) return true;
+ if (t1.getNumDynamicDims() == 0)
+ return true;
// All dynamic sizes must be the same. The only supported case at the moment
// is when `beforePadding` is an ExtractSliceOp (or a cast thereof).
assert(static_cast<size_t>(t1.getRank()) ==
beforeSlice.getMixedSizes().size());
- assert(static_cast<size_t>(t2.getRank())
- == afterTrimming.getMixedSizes().size());
+ assert(static_cast<size_t>(t2.getRank()) ==
+ afterTrimming.getMixedSizes().size());
for (unsigned i = 0; i < t1.getRank(); ++i) {
// Skip static dimensions.
- if (!t1.isDynamicDim(i)) continue;
+ if (!t1.isDynamicDim(i))
+ continue;
auto size1 = beforeSlice.getMixedSizes()[i];
auto size2 = afterTrimming.getMixedSizes()[i];
// Case 1: Same value or same constant int.
- if (isEqualConstantIntOrValue(size1, size2)) continue;
+ if (isEqualConstantIntOrValue(size1, size2))
+ continue;
// Other cases: Take a deeper look at defining ops of values.
auto v1 = size1.dyn_cast<Value>();
auto v2 = size2.dyn_cast<Value>();
- if (!v1 || !v2) return false;
+ if (!v1 || !v2)
+ return false;
// Case 2: Both values are identical AffineMinOps. (Should not happen if
// CSE is run.)
auto minOp1 = v1.getDefiningOp<AffineMinOp>();
auto minOp2 = v2.getDefiningOp<AffineMinOp>();
- if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap()
- && minOp1.operands() == minOp2.operands()) continue;
+ if (minOp1 && minOp2 && minOp1.getAffineMap() == minOp2.getAffineMap() &&
+ minOp1.operands() == minOp2.operands())
+ continue;
// Add additional cases as needed.
}
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
tensor::InsertSliceOp insertOp) const override {
// Low padding must be static 0.
- if (!padOp.hasZeroLowPad()) return failure();
+ if (!padOp.hasZeroLowPad())
+ return failure();
// Only unit stride supported.
- if (!insertOp.hasUnitStride()) return failure();
+ if (!insertOp.hasUnitStride())
+ return failure();
// Pad value must be a constant.
auto padValue = padOp.getConstantPaddingValue();
if (!padValue)
void mlir::linalg::populatePadTensorOpVectorizationPatterns(
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
- patterns.add<GenericPadTensorOpVectorizationPattern>(
- patterns.getContext(), baseBenefit);
+ patterns.add<GenericPadTensorOpVectorizationPattern>(patterns.getContext(),
+ baseBenefit);
// Try these specialized patterns first before resorting to the generic one.
patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
PadTensorOpVectorizationWithTransferWritePattern,
// BroadcastOp
//===----------------------------------------------------------------------===//
-static LogicalResult verify(BroadcastOp op) {
- VectorType srcVectorType = op.getSourceType().dyn_cast<VectorType>();
- VectorType dstVectorType = op.getVectorType();
- // Scalar to vector broadcast is always valid. A vector
- // to vector broadcast needs some additional checking.
- if (srcVectorType) {
- int64_t srcRank = srcVectorType.getRank();
- int64_t dstRank = dstVectorType.getRank();
- if (srcRank > dstRank)
- return op.emitOpError("source rank higher than destination rank");
- // Source has an exact match or singleton value for all trailing dimensions
- // (all leading dimensions are simply duplicated).
- int64_t lead = dstRank - srcRank;
- for (int64_t r = 0; r < srcRank; ++r) {
- int64_t srcDim = srcVectorType.getDimSize(r);
- int64_t dstDim = dstVectorType.getDimSize(lead + r);
- if (srcDim != 1 && srcDim != dstDim)
- return op.emitOpError("dimension mismatch (")
- << srcDim << " vs. " << dstDim << ")";
+BroadcastableToResult
+mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
+ std::pair<int, int> *mismatchingDims) {
+ // Broadcast scalar to vector of the same element type.
+ if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
+ getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
+ return BroadcastableToResult::Success;
+ // From now on, only vectors broadcast.
+ VectorType srcVectorType = srcType.dyn_cast<VectorType>();
+ if (!srcVectorType)
+ return BroadcastableToResult::SourceTypeNotAVector;
+
+ int64_t srcRank = srcVectorType.getRank();
+ int64_t dstRank = dstVectorType.getRank();
+ if (srcRank > dstRank)
+ return BroadcastableToResult::SourceRankHigher;
+ // Source has an exact match or singleton value for all trailing dimensions
+ // (all leading dimensions are simply duplicated).
+ int64_t lead = dstRank - srcRank;
+ for (int64_t r = 0; r < srcRank; ++r) {
+ int64_t srcDim = srcVectorType.getDimSize(r);
+ int64_t dstDim = dstVectorType.getDimSize(lead + r);
+ if (srcDim != 1 && srcDim != dstDim) {
+ if (mismatchingDims) {
+ mismatchingDims->first = srcDim;
+ mismatchingDims->second = dstDim;
+ }
+ return BroadcastableToResult::DimensionMismatch;
}
}
- return success();
+
+ return BroadcastableToResult::Success;
+}
+
+static LogicalResult verify(BroadcastOp op) {
+ std::pair<int, int> mismatchingDims;
+ BroadcastableToResult res = isBroadcastableTo(
+ op.getSourceType(), op.getVectorType(), &mismatchingDims);
+ if (res == BroadcastableToResult::Success)
+ return success();
+ if (res == BroadcastableToResult::SourceRankHigher)
+ return op.emitOpError("source rank higher than destination rank");
+ if (res == BroadcastableToResult::DimensionMismatch)
+ return op.emitOpError("dimension mismatch (")
+ << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+ if (res == BroadcastableToResult::SourceTypeNotAVector)
+ return op.emitOpError("source type is not a vector");
+ llvm_unreachable("unexpected vector.broadcast op error");
}
OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
+ if (getSourceType() == getVectorType())
+ return source();
if (!operands[0])
return {};
auto vectorType = getVectorType();