[mlir][sparse] factoring out getRankedTensorType helper function
authorwren romano <2998727+wrengr@users.noreply.github.com>
Thu, 19 Jan 2023 03:11:48 +0000 (19:11 -0800)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Sat, 21 Jan 2023 03:36:01 +0000 (19:36 -0800)
Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D142074

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

index f47d304..a73d627 100644 (file)
@@ -565,7 +565,7 @@ Value sparse_tensor::reshapeValuesToLevels(
 
 Value sparse_tensor::genToPointers(OpBuilder &builder, Location loc,
                                    Value tensor, uint64_t d) {
-  RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+  RankedTensorType srcTp = getRankedTensorType(tensor);
   SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
   Type ptrTp = get1DMemRefType(getPointerOverheadType(builder, encSrc),
                                /*withLayout=*/false);
@@ -575,7 +575,7 @@ Value sparse_tensor::genToPointers(OpBuilder &builder, Location loc,
 
 Value sparse_tensor::genToIndices(OpBuilder &builder, Location loc,
                                   Value tensor, uint64_t d, uint64_t cooStart) {
-  RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+  RankedTensorType srcTp = getRankedTensorType(tensor);
   SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
   Type indTp = get1DMemRefType(getIndexOverheadType(builder, encSrc),
                                /*withLayout=*/d >= cooStart);
@@ -585,7 +585,7 @@ Value sparse_tensor::genToIndices(OpBuilder &builder, Location loc,
 
 Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
                                  Value tensor) {
-  RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
+  RankedTensorType srcTp = getRankedTensorType(tensor);
   Type valTp = get1DMemRefType(srcTp.getElementType(),
                                /*withLayout=*/false);
   return builder.create<ToValuesOp>(loc, valTp, tensor);
index b07991e..8d8b0f8 100644 (file)
@@ -78,6 +78,11 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
 // Misc code generators and utilities.
 //===----------------------------------------------------------------------===//
 
+template <typename T>
+inline RankedTensorType getRankedTensorType(T t) {
+  return t.getType().template cast<RankedTensorType>();
+}
+
 /// Generates a 1-valued attribute of the given type.  This supports
 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
 /// for unsupported types we raise `llvm_unreachable` rather than
index 41a4c05..88981fc 100644 (file)
@@ -82,7 +82,7 @@ void LoopEmitter::initialize(ValueRange tensors, StringAttr loopTag,
     // a scalar or 0-dimension tensors
     if (isZeroRankedTensorOrScalar(t.getType()))
       continue;
-    auto rtp = t.getType().cast<RankedTensorType>();
+    auto rtp = getRankedTensorType(t);
     auto rank = static_cast<size_t>(rtp.getRank());
     auto enc = getSparseTensorEncoding(rtp);
     // We always treat sparse output tensor as dense so that we always iterate
index 4a1a0c9..2ce29e5 100644 (file)
@@ -756,8 +756,7 @@ public:
       return failure();
     Location loc = op->getLoc();
     auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
-    RankedTensorType srcType =
-        op.getTensor().getType().cast<RankedTensorType>();
+    auto srcType = getRankedTensorType(op.getTensor());
     Type eltType = srcType.getElementType();
     Type boolType = rewriter.getIntegerType(1);
     Type idxType = rewriter.getIndexType();
index c0b2caa..22ec479 100644 (file)
@@ -268,7 +268,7 @@ public:
         !isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
         !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
       return failure();
-    auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
+    auto outputType = getRankedTensorType(op.getResult(0));
     // Yielding zero on newly allocated (all-zero) sparse tensors can be
     // optimized out directly (regardless of dynamic or static size).
     if (getSparseTensorEncoding(outputType)) {
@@ -405,8 +405,8 @@ public:
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Value srcTensor = op.getSrc();
-    auto srcTp = srcTensor.getType().template cast<RankedTensorType>();
-    auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+    auto srcTp = getRankedTensorType(srcTensor);
+    auto dstTp = getRankedTensorType(op.getResult());
     SparseTensorEncodingAttr encSrc = getSparseTensorEncoding(srcTp);
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
     if (!encDst || !encSrc) {
@@ -483,8 +483,7 @@ public:
       return failure();
     }
     if (encSrc) {
-      RankedTensorType rtp =
-          op.getSrc().getType().template cast<RankedTensorType>();
+      auto rtp = getRankedTensorType(op.getSrc());
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
       auto convert = rewriter.create<ConvertOp>(loc, denseTp, op.getSrc());
@@ -492,8 +491,7 @@ public:
       return success();
     }
     if (encDst) {
-      RankedTensorType rtp =
-          op.getResult().getType().template cast<RankedTensorType>();
+      auto rtp = getRankedTensorType(op.getResult());
       auto denseTp =
           RankedTensorType::get(rtp.getShape(), rtp.getElementType());
       auto reshape = rewriter.create<ReshapeOp>(loc, denseTp, op.getSrc(),
@@ -511,7 +509,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
   LogicalResult matchAndRewrite(ConcatenateOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto dstTp = op.getType().cast<RankedTensorType>();
+    auto dstTp = getRankedTensorType(op);
     uint64_t conDim = op.getDimension().getZExtValue();
     SmallVector<Value> sizes;
     concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(), conDim);
@@ -547,7 +545,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       // CSC matrices along column).
       if (!allDense && conDim == 0 && encDst.hasIdDimOrdering()) {
         for (auto i : op.getInputs()) {
-          auto rtp = i.getType().cast<RankedTensorType>();
+          auto rtp = getRankedTensorType(i);
           auto srcEnc = getSparseTensorEncoding(rtp);
           if (isAllDimOrdered(rtp) && (!srcEnc || srcEnc.hasIdDimOrdering())) {
             allOrdered = true;
@@ -623,7 +621,7 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset
       // dynamically.
-      int64_t d = input.getType().cast<RankedTensorType>().getShape()[conDim];
+      int64_t d = getRankedTensorType(input).getShape()[conDim];
       assert(!ShapedType::isDynamic(d));
       offset = rewriter.create<arith::AddIOp>(loc, offset,
                                               constantIndex(rewriter, loc, d));
@@ -699,7 +697,7 @@ private:
                                     PatternRewriter &rewriter) const {
     Location loc = op.getLoc();
     Value src = op.getSource();
-    RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+    auto dstTp = getRankedTensorType(op);
     SmallVector<Value> sizes;
     sizesFromSrc(rewriter, sizes, loc, src);
     SmallVector<Value> dynSizes;
@@ -769,9 +767,9 @@ private:
   LogicalResult sparse2DenseRewrite(ConvertOp op,
                                     PatternRewriter &rewriter) const {
     Location loc = op->getLoc();
-    RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+    RankedTensorType dstTp = getRankedTensorType(op);
     Value src = op.getSource();
-    RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
+    RankedTensorType srcTp = getRankedTensorType(src);
 
     SmallVector<Value> sizes;
     sizesForTensor(rewriter, sizes, loc, srcTp, src);
@@ -808,8 +806,8 @@ private:
                                      PatternRewriter &rewriter) const {
     Location loc = op->getLoc();
     Value src = op.getSource();
-    RankedTensorType srcTp = src.getType().cast<RankedTensorType>();
-    RankedTensorType dstTp = op.getType().cast<RankedTensorType>();
+    RankedTensorType srcTp = getRankedTensorType(src);
+    RankedTensorType dstTp = getRankedTensorType(op);
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
     int64_t rank = dstTp.getRank();
 
@@ -928,7 +926,7 @@ public:
     auto loc = op.getLoc();
     Value input = op.getTensor();
     SmallVector<Value> reduc = op.getInitArgs();
-    auto rtp = input.getType().cast<RankedTensorType>();
+    auto rtp = getRankedTensorType(input);
     int64_t rank = rtp.getRank();
 
     // Special-case: for each over a sparse constant uses its own rewriting
@@ -1015,7 +1013,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
   LogicalResult matchAndRewrite(NewOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    auto dstTp = op.getResult().getType().template cast<RankedTensorType>();
+    auto dstTp = getRankedTensorType(op.getResult());
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(dstTp);
     if (!encDst)
       return failure();
@@ -1138,7 +1136,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
     Value nnz = rewriter.create<NumberOfEntriesOp>(loc, src);
 
     // Allocate a temporary buffer for storing dimension sizes and indices.
-    auto srcTp = src.getType().template cast<RankedTensorType>();
+    auto srcTp = getRankedTensorType(src);
     uint64_t rank = srcTp.getRank();
     Type indexTp = rewriter.getIndexType();
     Value dimSizes = genAlloca(rewriter, loc, rank, indexTp);
index 719b1c6..c31f20e 100644 (file)
@@ -1589,7 +1589,7 @@ private:
       // TODO: investigate fusing the conversion with computation,
       //       especially if it is a direct yield!
       //
-      auto srcTp = tval.getType().cast<RankedTensorType>();
+      auto srcTp = getRankedTensorType(tval);
       auto dstEnc = SparseTensorEncodingAttr::get(
           getContext(), srcEnc.getDimLevelType(),
           permute(env, env.op().getMatchingIndexingMap(t)), // new order