```
}];
let extraClassDeclaration = [{
- VectorType getVectorType() {
+ VectorType getSourceVectorType() {
return getVector().getType().cast<VectorType>();
}
}];
}];
let extraClassDeclaration = [{
Type getSourceType() { return getSource().getType(); }
- VectorType getVectorType() {
+ VectorType getResultVectorType() {
return getVector().getType().cast<VectorType>();
}
/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
/// the helper will assert. This means:
/// 1. `dstShape` must not be empty.
- /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
+ /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getResultVectorType)]
/// 2. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
// must match the `value` shape.
static Value createOrFoldBroadcastOp(
VectorType getV2VectorType() {
return getV2().getType().cast<VectorType>();
}
- VectorType getVectorType() {
+ VectorType getResultVectorType() {
return getVector().getType().cast<VectorType>();
}
}];
OpBuilder<(ins "Value":$source)>,
];
let extraClassDeclaration = [{
- VectorType getVectorType() {
+ VectorType getSourceVectorType() {
return getVector().getType().cast<VectorType>();
}
}];
];
let extraClassDeclaration = [{
static StringRef getPositionAttrStrName() { return "position"; }
- VectorType getVectorType() {
+ VectorType getSourceVectorType() {
return getVector().getType().cast<VectorType>();
}
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
? VectorType()
: (*getAcc().begin()).getType().cast<VectorType>();
}
- VectorType getVectorType() {
+ VectorType getResultVectorType() {
return getResult().getType().cast<VectorType>();
}
static constexpr StringRef getKindAttrStrName() {
static StringRef getOffsetsAttrStrName() { return "offsets"; }
static StringRef getSizesAttrStrName() { return "sizes"; }
static StringRef getStridesAttrStrName() { return "strides"; }
- VectorType getVectorType(){ return getVector().getType().cast<VectorType>(); }
+ VectorType getSourceVectorType() {
+ return getVector().getType().cast<VectorType>();
+ }
void getOffsets(SmallVectorImpl<int64_t> &results);
bool hasNonUnitStrides() {
return llvm::any_of(getStrides(), [](Attribute attr) {
OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
];
let extraClassDeclaration = [{
- VectorType getVectorType() {
+ VectorType getSourceVectorType() {
return getVector().getType().cast<VectorType>();
}
- VectorType getResultType() {
+ VectorType getResultVectorType() {
return getResult().getType().cast<VectorType>();
}
void getTransp(SmallVectorImpl<int64_t> &results);
/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
- return broadcastOp.getVectorType().getRank() == 2;
+ return broadcastOp.getResultVectorType().getRank() == 2;
}
/// Return true if this integer extend op can be folded into a contract op.
SmallVector<int64_t> sizes;
populateFromInt64AttrArray(op.getSizes(), sizes);
- ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
+ ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
// Compute offset in vector registers. Note that the mma.sync vector registers
// are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
assert(broadcastSupportsMMAMatrixType(op));
const char *fragType = inferFragType(op);
- auto vecType = op.getVectorType();
+ auto vecType = op.getResultVectorType();
gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
auto loc = shuffleOp->getLoc();
auto v1Type = shuffleOp.getV1VectorType();
auto v2Type = shuffleOp.getV2VectorType();
- auto vectorType = shuffleOp.getVectorType();
+ auto vectorType = shuffleOp.getResultVectorType();
Type llvmType = typeConverter->convertType(vectorType);
auto maskArrayAttr = shuffleOp.getMask();
LogicalResult
matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto vectorType = extractEltOp.getVectorType();
+ auto vectorType = extractEltOp.getSourceVectorType();
auto llvmType = typeConverter->convertType(vectorType.getElementType());
// Bail if result type cannot be lowered.
LogicalResult
matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type resultType = getTypeConverter()->convertType(castOp.getVectorType());
+ Type resultType =
+ getTypeConverter()->convertType(castOp.getResultVectorType());
if (!resultType)
return failure();
return success();
}
- SmallVector<Value, 4> source(castOp.getVectorType().getNumElements(),
+ SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
adaptor.getSource());
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
- castOp, castOp.getVectorType(), source);
+ castOp, castOp.getResultVectorType(), source);
return success();
}
};
LogicalResult
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto oldResultType = shuffleOp.getVectorType();
+ auto oldResultType = shuffleOp.getResultVectorType();
if (!spirv::CompositeType::isValid(oldResultType))
return failure();
Type newResultType = getTypeConverter()->convertType(oldResultType);
LogicalResult ReductionOp::verify() {
// Verify for 0-D and 1-D vector.
- int64_t rank = getVectorType().getRank();
+ int64_t rank = getSourceVectorType().getRank();
if (rank > 1)
return emitOpError("unsupported reduction rank: ") << rank;
/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
- auto vecType = getVectorType();
+ auto vecType = getSourceVectorType();
return vecType.cloneWith(std::nullopt,
IntegerType::get(vecType.getContext(), /*width=*/1));
}
}
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
- return llvm::to_vector<4>(getVectorType().getShape());
+ return llvm::to_vector<4>(getSourceVectorType().getShape());
}
namespace {
if (maskableOp.isMasked())
return failure();
- auto vectorType = reductionOp.getVectorType();
+ auto vectorType = reductionOp.getSourceVectorType();
if (vectorType.getRank() != 0 && vectorType.getDimSize(0) != 1)
return failure();
}
LogicalResult vector::ExtractElementOp::verify() {
- VectorType vectorType = getVectorType();
+ VectorType vectorType = getSourceVectorType();
if (vectorType.getRank() == 0) {
if (getPosition())
return emitOpError("expected position to be empty with 0-D vector");
LogicalResult vector::ExtractOp::verify() {
auto positionAttr = getPosition().getValue();
- if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
+ if (positionAttr.size() >
+ static_cast<unsigned>(getSourceVectorType().getRank()))
return emitOpError(
"expected position attribute of rank smaller than vector rank");
for (const auto &en : llvm::enumerate(positionAttr)) {
auto attr = en.value().dyn_cast<IntegerAttr>();
if (!attr || attr.getInt() < 0 ||
- attr.getInt() >= getVectorType().getDimSize(en.index()))
+ attr.getInt() >= getSourceVectorType().getDimSize(en.index()))
return emitOpError("expected position attribute #")
<< (en.index() + 1)
<< " to be a non-negative integer smaller than the corresponding "
ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
ExtractOp e)
- : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
+ : extractOp(e), vectorRank(extractOp.getSourceVectorType().getRank()),
extractedRank(extractOp.getPosition().size()) {
assert(vectorRank >= extractedRank && "extracted pos overflow");
sentinels.reserve(vectorRank - extractedRank);
int64_t stride = 1;
for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
strides.push_back(stride);
- stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
+ stride *=
+ getDimReverse(extractOp.getSourceVectorType(), i + destinationRank);
}
int64_t position = linearize(extractedPos, strides);
size_t lastOffset = sliceOffsets.size() - 1;
if (sliceOffsets.back() != 0 ||
extractStridedSliceOp.getType().getDimSize(lastOffset) !=
- extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
+ extractStridedSliceOp.getSourceVectorType().getDimSize(lastOffset))
break;
sliceOffsets.pop_back();
}
destinationRank = vecType.getRank();
// The dimensions of the result need to be untouched by the
// extractStridedSlice op.
- if (destinationRank >
- extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
+ if (destinationRank > extractStridedSliceOp.getSourceVectorType().getRank() -
+ sliceOffsets.size())
return Value();
auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
assert(extractedPos.size() >= sliceOffsets.size());
if (!srcVectorType)
return {};
return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
- getVectorType().getShape());
+ getResultVectorType().getShape());
}
/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
LogicalResult BroadcastOp::verify() {
std::pair<int, int> mismatchingDims;
- BroadcastableToResult res =
- isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims);
+ BroadcastableToResult res = isBroadcastableTo(
+ getSourceType(), getResultVectorType(), &mismatchingDims);
if (res == BroadcastableToResult::Success)
return success();
if (res == BroadcastableToResult::SourceRankHigher)
}
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
- if (getSourceType() == getVectorType())
+ if (getSourceType() == getResultVectorType())
return getSource();
if (!adaptor.getSource())
return {};
- auto vectorType = getVectorType();
+ auto vectorType = getResultVectorType();
if (adaptor.getSource().isa<IntegerAttr, FloatAttr>())
return DenseElementsAttr::get(vectorType, adaptor.getSource());
if (auto attr = adaptor.getSource().dyn_cast<SplatElementsAttr>())
auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
if (!srcBroadcast)
return failure();
- rewriter.replaceOpWithNewOp<BroadcastOp>(
- broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(broadcastOp,
+ broadcastOp.getResultVectorType(),
+ srcBroadcast.getSource());
return success();
}
};
}
LogicalResult ShuffleOp::verify() {
- VectorType resultType = getVectorType();
+ VectorType resultType = getResultVectorType();
VectorType v1Type = getV1VectorType();
VectorType v2Type = getV2VectorType();
// Verify ranks.
}
}
- return DenseElementsAttr::get(getVectorType(), results);
+ return DenseElementsAttr::get(getResultVectorType(), results);
}
namespace {
Type tRHS = getOperandTypeRHS();
VectorType vLHS = getOperandVectorTypeLHS(),
vRHS = tRHS.dyn_cast<VectorType>(),
- vACC = getOperandVectorTypeACC(), vRES = getVectorType();
+ vACC = getOperandVectorTypeACC(), vRES = getResultVectorType();
if (vLHS.getRank() != 1)
return emitOpError("expected 1-d vector for operand #1");
/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized."
Type OuterProductOp::getExpectedMaskType() {
- auto vecType = this->getVectorType();
+ auto vecType = this->getResultVectorType();
return VectorType::get(vecType.getShape(),
IntegerType::get(vecType.getContext(), /*width=*/1));
}
}
LogicalResult ExtractStridedSliceOp::verify() {
- auto type = getVectorType();
+ auto type = getSourceVectorType();
auto offsets = getOffsetsAttr();
auto sizes = getSizesAttr();
auto strides = getStridesAttr();
/*halfOpen=*/false)))
return failure();
- auto resultType =
- inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
+ auto resultType = inferStridedSliceOpResultType(getSourceVectorType(),
+ offsets, sizes, strides);
if (getResult().getType() != resultType)
return emitOpError("expected result type to be ") << resultType;
ArrayAttr extractSizes = op.getSizes();
auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
while (insertOp) {
- if (op.getVectorType().getRank() !=
+ if (op.getSourceVectorType().getRank() !=
insertOp.getSourceVectorType().getRank())
return failure();
ArrayAttr insertOffsets = insertOp.getOffsets();
}
OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
- if (getVectorType() == getResult().getType())
+ if (getSourceVectorType() == getResult().getType())
return getVector();
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
return getResult();
// Eliminate splat constant transpose ops.
if (auto attr = adaptor.getVector().dyn_cast_or_null<DenseElementsAttr>())
if (attr.isSplat())
- return attr.reshape(getResultType());
+ return attr.reshape(getResultVectorType());
// Eliminate identity transpose ops. This happens when the dimensions of the
// input vector remain in their original order after the transpose operation.
}
LogicalResult vector::TransposeOp::verify() {
- VectorType vectorType = getVectorType();
- VectorType resultType = getResultType();
+ VectorType vectorType = getSourceVectorType();
+ VectorType resultType = getResultVectorType();
int64_t rank = resultType.getRank();
if (vectorType.getRank() != rank)
return emitOpError("vector result rank mismatch: ") << rank;
}
std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
- return llvm::to_vector<4>(getResultType().getShape());
+ return llvm::to_vector<4>(getResultVectorType().getShape());
}
namespace {
auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
if (!srcVectorType || srcVectorType.getNumElements() == 1) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- transposeOp, transposeOp.getResultType(), bcastOp.getSource());
+ transposeOp, transposeOp.getResultVectorType(), bcastOp.getSource());
return success();
}
return failure();
rewriter.replaceOpWithNewOp<vector::SplatOp>(
- transposeOp, transposeOp.getResultType(), splatOp.getInput());
+ transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
return success();
}
};
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
- VectorType extractSrcType = extractOp.getVectorType();
+ VectorType extractSrcType = extractOp.getSourceVectorType();
Location loc = extractOp.getLoc();
// "vector.extract %v[] : vector<f32>" is an invalid op.
SmallVector<size_t> newRetIndices;
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
rewriter, warpOp, {extractOp.getVector()},
- {extractOp.getVectorType()}, newRetIndices);
+ {extractOp.getSourceVectorType()}, newRetIndices);
rewriter.setInsertionPointAfter(newWarpOp);
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
// Extract from distributed vector.
return failure();
unsigned int operandNumber = operand->getOperandNumber();
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
- VectorType extractSrcType = extractOp.getVectorType();
+ VectorType extractSrcType = extractOp.getSourceVectorType();
bool is0dOrVec1Extract = extractSrcType.getNumElements() == 1;
Type elType = extractSrcType.getElementType();
VectorType distributedVecType;
// vector.extract_strided_slice requires the input and output vector to have
// the same rank. Here we drop leading one dimensions from the input vector
// type to make sure we don't cause mismatch.
- VectorType oldSrcType = extractOp.getVectorType();
+ VectorType oldSrcType = extractOp.getSourceVectorType();
VectorType newSrcType = trimLeadingOneDims(oldSrcType);
if (newSrcType.getRank() == oldSrcType.getRank())
LogicalResult matchAndRewrite(vector::BroadcastOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- VectorType dstType = op.getVectorType();
+ VectorType dstType = op.getResultVectorType();
VectorType srcType = op.getSourceType().dyn_cast<VectorType>();
Type eltType = dstType.getElementType();
auto loc = op.getLoc();
Value input = op.getVector();
- VectorType inputType = op.getVectorType();
- VectorType resType = op.getResultType();
+ VectorType inputType = op.getSourceVectorType();
+ VectorType resType = op.getResultVectorType();
// Set up convenience transposition table.
SmallVector<int64_t> transp;
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
- VectorType srcType = op.getVectorType();
+ VectorType srcType = op.getSourceVectorType();
if (srcType.getRank() != 2)
return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
Value shuffled =
rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
- shuffled);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ op, op.getResultVectorType(), shuffled);
return success();
}
VectorType lhsType = op.getOperandVectorTypeLHS();
VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
- VectorType resType = op.getVectorType();
+ VectorType resType = op.getResultVectorType();
Type eltType = resType.getElementType();
bool isInt = eltType.isa<IntegerType, IndexType>();
Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
continue;
// contractionOp can only take vector as operands.
auto srcType = broadcast.getSourceType().dyn_cast<VectorType>();
- if (!srcType || srcType.getRank() == broadcast.getVectorType().getRank())
+ if (!srcType ||
+ srcType.getRank() == broadcast.getResultVectorType().getRank())
continue;
int64_t rankDiff =
- broadcast.getVectorType().getRank() - srcType.getRank();
+ broadcast.getResultVectorType().getRank() - srcType.getRank();
bool innerDimBroadcast = false;
SmallVector<AffineExpr> originalDims;
for (const auto &dim : llvm::enumerate(srcType.getShape())) {
- if (dim.value() !=
- broadcast.getVectorType().getDimSize(rankDiff + dim.index())) {
+ if (dim.value() != broadcast.getResultVectorType().getDimSize(
+ rankDiff + dim.index())) {
innerDimBroadcast = true;
break;
}
// of non-unit size.
bool nonUnitDimReductionBroadcast = false;
for (int64_t i = 0; i < rankDiff; ++i) {
- if (broadcast.getVectorType().getDimSize(i) != 1 &&
+ if (broadcast.getResultVectorType().getDimSize(i) != 1 &&
isReductionIterator(contractOp.getIteratorTypes()
.getValue()[map.getDimPosition(i)])) {
nonUnitDimReductionBroadcast = true;
continue;
AffineMap broadcastMap =
- AffineMap::get(broadcast.getVectorType().getRank(), 0, originalDims,
- contractOp.getContext());
+ AffineMap::get(broadcast.getResultVectorType().getRank(), 0,
+ originalDims, contractOp.getContext());
map = broadcastMap.compose(map);
*operand = broadcast.getSource();
changed = true;
auto transposeOp = operand.getDefiningOp<vector::TransposeOp>();
if (transposeOp) {
transposeMaps.push_back(transposeOp.getTransp());
- srcType = transposeOp.getVectorType();
+ srcType = transposeOp.getSourceVectorType();
} else if (!matchPattern(operand, m_Constant())) {
return failure();
}
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
// Only support extracting scalars for now.
- if (extractOp.getVectorType().getRank() != 1)
+ if (extractOp.getSourceVectorType().getRank() != 1)
return failure();
auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
[](const APInt &val) { return !val.isOneValue(); }))
return failure();
- unsigned rank = extractOp.getVectorType().getRank();
+ unsigned rank = extractOp.getSourceVectorType().getRank();
assert(castDstLastDim % castSrcLastDim == 0);
int64_t expandRatio = castDstLastDim / castSrcLastDim;
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
- if (transposeOp.getResultType().getRank() == 0)
+ if (transposeOp.getResultVectorType().getRank() == 0)
return failure();
auto targetShape = getTargetShape(options, transposeOp);
if (!targetShape)
return failure();
- auto originalVectorType = transposeOp.getResultType();
+ auto originalVectorType = transposeOp.getResultVectorType();
SmallVector<int64_t> strides(targetShape->size(), 1);
Location loc = transposeOp.getLoc();
ArrayRef<int64_t> originalSize = originalVectorType.getShape();
// Check if the source vector type is supported. AVX2 patterns can only be
// applied to f32 vector types with two dimensions greater than one.
- VectorType srcType = op.getVectorType();
+ VectorType srcType = op.getSourceVectorType();
if (!srcType.getElementType().isF32())
return rewriter.notifyMatchFailure(op, "Unsupported vector element type");
// Reshape the n-D input vector with only two dimensions greater than one
// to a 2-D vector.
auto flattenedType =
- VectorType::get({n * m}, op.getVectorType().getElementType());
+ VectorType::get({n * m}, op.getSourceVectorType().getElementType());
auto reshInputType = VectorType::get({m, n}, srcType.getElementType());
auto reshInput =
ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());
// We have to transpose their dimensions and retrieve its original rank
// (e.g., 1x8x1x4x1).
res = ib.create<vector::ShapeCastOp>(flattenedType, res);
- res = ib.create<vector::ShapeCastOp>(op.getResultType(), res);
+ res = ib.create<vector::ShapeCastOp>(op.getResultVectorType(), res);
rewriter.replaceOp(op, res);
return success();
};