}))
return failure();
- auto castOp = sliceOp.getSource().getDefiningOp<tensor::CastOp>();
+ auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
if (!castOp)
return failure();
return failure();
/// Deduce the type of the result to use for the canonicalized operation.
+ Location loc = sliceOp.getLoc();
+ auto sliceOpType = sliceOp.getType();
RankedTensorType resultType =
ExtractSliceOp::inferCanonicalRankReducedResultType(
- sliceOp.getType().getRank(), sliceOp.getSourceType(),
+ sliceOpType.getRank(), sliceOp.getSourceType(),
sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
sliceOp.getMixedStrides());
- Value newSlice = rewriter.create<ExtractSliceOp>(
- sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(),
+ Value newResult = rewriter.create<ExtractSliceOp>(
+ loc, resultType, castOp.getSource(), sliceOp.getOffsets(),
sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
- rewriter.replaceOpWithNewOp<tensor::CastOp>(sliceOp, sliceOp.getType(),
- newSlice);
+ if (newResult.getType() != sliceOpType)
+ newResult = rewriter.create<CastOp>(loc, sliceOpType, newResult);
+ rewriter.replaceOp(sliceOp, newResult);
return success();
}
};