Value sparse_tensor::genToPointers(OpBuilder &builder, Location loc,
Value tensor, uint64_t d) {
- RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+ RankedTensorType srcTp = getRankedTensorType(tensor);
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
Type ptrTp = get1DMemRefType(getPointerOverheadType(builder, encSrc),
/*withLayout=*/false);
Value sparse_tensor::genToIndices(OpBuilder &builder, Location loc,
Value tensor, uint64_t d, uint64_t cooStart) {
- RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+ RankedTensorType srcTp = getRankedTensorType(tensor);
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
Type indTp = get1DMemRefType(getIndexOverheadType(builder, encSrc),
/*withLayout=*/d >= cooStart);
Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
Value tensor) {
- RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+ RankedTensorType srcTp = getRankedTensorType(tensor);
Type valTp = get1DMemRefType(srcTp.getElementType(),
/*withLayout=*/false);
return builder.create<ToValuesOp>(loc, valTp, tensor);
!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
!isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
return failure();
- auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
+ auto outputType = getRankedTensorType(op.getResult(0));
// Yielding zero on newly allocated (all-zero) sparse tensors can be
// optimized out directly (regardless of dynamic or static size).
if (getSparseTensorEncoding(outputType)) {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value srcTensor = op.getSrc();
- auto srcTp = srcTensor.getType().template cast<RankedTensorType>();
- auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+ auto srcTp = getRankedTensorType(srcTensor);
+ auto dstTp = getRankedTensorType(op.getResult());
SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
if (!encDst || !encSrc) {
return failure();
}
if (encSrc) {
- RankedTensorType rtp =
- op.getSrc().getType().template cast<RankedTensorType>();
+ auto rtp = getRankedTensorType(op.getSrc());
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
return success();
}
if (encDst) {
- RankedTensorType rtp =
- op.getResult().getType().template cast<RankedTensorType>();
+ auto rtp = getRankedTensorType(op.getResult());
auto denseTp =
RankedTensorType::get(rtp.getShape(), rtp.getElementType());
auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
LogicalResult matchAndRewrite(ConcatenateOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto dstTp = op.getType().cast<RankedTensorType>();
+ auto dstTp = getRankedTensorType(op);
uint64_t conDim = op.getDimension().getZExtValue();
SmallVector<Value> sizes;
concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
// CSC matrices along column).
if (!allDense && conDim == 0 && encDst.hasIdDimOrdering()) {
for (auto i : op.getInputs()) {
- auto rtp = i.getType().cast<RankedTensorType>();
+ auto rtp = getRankedTensorType(i);
auto srcEnc = getSparseTensorEncoding(rtp);
if (isAllDimOrdered(rtp) && (!srcEnc || srcEnc.hasIdDimOrdering())) {
allOrdered = true;
// Accumulates the offset. Note that only static-shaped inputs are allowed
// by concatenate op verifier, which saves us from computing the offset
// dynamically.
- int64_t d = input.getType().cast<RankedTensorType>().getShape()[conDim];
+ int64_t d = getRankedTensorType(input).getShape()[conDim];
assert(!ShapedType::isDynamic(d));
offset = rewriter.create<arith::AddIOp>(loc, offset,
constantIndex(rewriter, loc, d));
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
Value src = op.getSource();
- RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+ auto dstTp = getRankedTensorType(op);
SmallVector<Value> sizes;
sizesFromSrc(rewriter, sizes, loc, src);
SmallVector<Value> dynSizes;
LogicalResult sparse2DenseRewrite(ConvertOp op,
PatternRewriter &rewriter) const {
Location loc = op->getLoc();
- RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+ RankedTensorType dstTp = getRankedTensorType(op);
Value src = op.getSource();
- RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
+ RankedTensorType srcTp = getRankedTensorType(src);
SmallVector<Value> sizes;
sizesForTensor(rewriter, sizes, loc, srcTp, src);
PatternRewriter &rewriter) const {
Location loc = op->getLoc();
Value src = op.getSource();
- RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
- RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+ RankedTensorType srcTp = getRankedTensorType(src);
+ RankedTensorType dstTp = getRankedTensorType(op);
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
int64_t rank = dstTp.getRank();
auto loc = op.getLoc();
Value input = op.getTensor();
SmallVector<Value> reduc = op.getInitArgs();
- auto rtp = input.getType().cast<RankedTensorType>();
+ auto rtp = getRankedTensorType(input);
int64_t rank = rtp.getRank();
// Special-case: for each over a sparse constant uses its own rewriting
LogicalResult matchAndRewrite(NewOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+ auto dstTp = getRankedTensorType(op.getResult());
SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
if (!encDst)
return failure();
Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
// Allocate a temporary buffer for storing dimension sizes and indices.
- auto srcTp = src.getType().template cast<RankedTensorType>();
+ auto srcTp = getRankedTensorType(src);
uint64_t rank = srcTp.getRank();
Type indexTp = rewriter.getIndexType();
Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);