static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
LinalgOp convOp);
-/// Return the given vector type if `elementType` is valid.
-static FailureOr<VectorType> getVectorType(ArrayRef<int64_t> shape,
- Type elementType) {
- if (!VectorType::isValidElementType(elementType)) {
- return failure();
- }
- return VectorType::get(shape, elementType);
-}
-
-/// Cast the given type to a vector type if its element type is valid.
-static FailureOr<VectorType> getVectorType(ShapedType type) {
- return getVectorType(type.getShape(), type.getElementType());
-}
-
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
template <typename OpType>
vector::BroadcastableToResult::Success)
return value;
Location loc = b.getInsertionPoint()->getLoc();
- return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
+  return b.createOrFold<vector::BroadcastOp>(loc, targetVectorType, value);
}
/// Create MultiDimReductionOp to compute the reduction for `reductionOp`. This
}
auto readType =
- getVectorType(readVecShape, getElementTypeOrSelf(opOperand->get()));
- if (!succeeded(readType))
- return failure();
+ VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get()));
SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
Operation *read = rewriter.create<vector::TransferReadOp>(
- loc, *readType, opOperand->get(), indices, readMap);
+ loc, readType, opOperand->get(), indices, readMap);
read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
Value readValue = read->getResult(0);
// 3.b. If masked, set in-bounds to true. Masking guarantees that the access
// will be in-bounds.
if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
- SmallVector<bool> inBounds(readType->getRank(), true);
+ SmallVector<bool> inBounds(readType.getRank(), true);
cast<vector::TransferReadOp>(maskOp.getMaskableOp())
.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
}
if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
return failure();
- auto readType = getVectorType(srcType);
- auto writeType = getVectorType(dstType);
- if (!(succeeded(readType) && succeeded(writeType)))
- return failure();
+ auto readType =
+ VectorType::get(srcType.getShape(), getElementTypeOrSelf(srcType));
+ auto writeType =
+ VectorType::get(dstType.getShape(), getElementTypeOrSelf(dstType));
Location loc = copyOp->getLoc();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> indices(srcType.getRank(), zero);
Value readValue = rewriter.create<vector::TransferReadOp>(
- loc, *readType, copyOp.getSource(), indices,
+ loc, readType, copyOp.getSource(), indices,
rewriter.getMultiDimIdentityMap(srcType.getRank()));
if (readValue.getType().cast<VectorType>().getRank() == 0) {
readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
- readValue =
- rewriter.create<vector::BroadcastOp>(loc, *writeType, readValue);
+ readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue);
}
Operation *writeValue = rewriter.create<vector::TransferWriteOp>(
loc, readValue, copyOp.getTarget(), indices,
auto sourceType = padOp.getSourceType();
auto resultType = padOp.getResultType();
- // Complex is not a valid vector element type.
- if (!VectorType::isValidElementType(sourceType.getElementType()))
- return failure();
-
// Copy cannot be vectorized if pad value is non-constant and source shape
// is dynamic. In case of a dynamic source shape, padding must be appended
// by TransferReadOp, but TransferReadOp supports only constant padding.
if (insertOp.getDest() == padOp.getResult())
return failure();
- auto vecType = getVectorType(padOp.getType());
- if (!succeeded(vecType))
- return failure();
- unsigned vecRank = vecType->getRank();
+ auto vecType = VectorType::get(padOp.getType().getShape(),
+ padOp.getType().getElementType());
+ unsigned vecRank = vecType.getRank();
unsigned tensorRank = insertOp.getType().getRank();
// Check if sizes match: Insert the entire tensor into most minor dims.
// (No permutations allowed.)
SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
- expectedSizes.append(vecType->getShape().begin(),
- vecType->getShape().end());
+ expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
if (!llvm::all_of(
llvm::zip(insertOp.getMixedSizes(), expectedSizes), [](auto it) {
return getConstantIntValue(std::get<0>(it)) == std::get<1>(it);
SmallVector<Value> readIndices(
vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0));
auto read = rewriter.create<vector::TransferReadOp>(
- padOp.getLoc(), *vecType, padOp.getSource(), readIndices, padValue);
+ padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue);
// Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
// specified offsets. Write is fully in-bounds because a InsertSliceOp's
auto rhsRank = rhsShapedType.getRank();
switch (oper) {
case Conv:
- if (rhsRank != 2 && rhsRank != 3)
+ if (rhsRank != 2 && rhsRank != 3)
return;
break;
case Pool:
Type lhsEltType = lhsShapedType.getElementType();
Type rhsEltType = rhsShapedType.getElementType();
Type resEltType = resShapedType.getElementType();
- auto lhsType = getVectorType(lhsShape, lhsEltType);
- auto rhsType = getVectorType(rhsShape, rhsEltType);
- auto resType = getVectorType(resShape, resEltType);
- if (!(succeeded(lhsType) && succeeded(rhsType) && succeeded(resType)))
- return failure();
+ auto lhsType = VectorType::get(lhsShape, lhsEltType);
+ auto rhsType = VectorType::get(rhsShape, rhsEltType);
+ auto resType = VectorType::get(resShape, resEltType);
// Read lhs slice of size {w * strideW + kw * dilationW, c, f} @ [0, 0,
// 0].
Value lhs = rewriter.create<vector::TransferReadOp>(
- loc, *lhsType, lhsShaped, ValueRange{zero, zero, zero});
+ loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
// Read rhs slice of size {kw, c, f} @ [0, 0, 0].
// This is needed only for Conv.
Value rhs = nullptr;
if (oper == Conv)
rhs = rewriter.create<vector::TransferReadOp>(
- loc, *rhsType, rhsShaped, ValueRange{zero, zero, zero});
+ loc, rhsType, rhsShaped, ValueRange{zero, zero, zero});
// Read res slice of size {n, w, f} @ [0, 0, 0].
Value res = rewriter.create<vector::TransferReadOp>(
- loc, *resType, resShaped, ValueRange{zero, zero, zero});
+ loc, resType, resShaped, ValueRange{zero, zero, zero});
// The base vectorization case is input: {n,w,c}, weight: {kw,c,f}, output:
// {n,w,f}. To reuse the base pattern vectorization case, we do pre
Type lhsEltType = lhsShapedType.getElementType();
Type rhsEltType = rhsShapedType.getElementType();
Type resEltType = resShapedType.getElementType();
- auto lhsType = getVectorType(
+ VectorType lhsType = VectorType::get(
{nSize,
// iw = ow * sw + kw * dw - 1
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
cSize},
lhsEltType);
- auto rhsType = getVectorType({kwSize, cSize}, rhsEltType);
- auto resType = getVectorType({nSize, wSize, cSize}, resEltType);
- if (!(succeeded(lhsType) && succeeded(rhsType) && succeeded(resType)))
- return failure();
+ VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+ VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
// 0].
Value lhs = rewriter.create<vector::TransferReadOp>(
- loc, *lhsType, lhsShaped, ValueRange{zero, zero, zero});
+ loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
// Read rhs slice of size {kw, c} @ [0, 0].
- Value rhs = rewriter.create<vector::TransferReadOp>(
- loc, *rhsType, rhsShaped, ValueRange{zero, zero});
+ Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+ ValueRange{zero, zero});
// Read res slice of size {n, w, c} @ [0, 0, 0].
Value res = rewriter.create<vector::TransferReadOp>(
- loc, *resType, resShaped, ValueRange{zero, zero, zero});
+ loc, resType, resShaped, ValueRange{zero, zero, zero});
//===------------------------------------------------------------------===//
// Begin vector-only rewrite part