[mlir][sparse] Added new SparseTensorEncodingAttr::withoutOrdering factory
authorwren romano <2998727+wrengr@users.noreply.github.com>
Thu, 15 Dec 2022 22:32:58 +0000 (14:32 -0800)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Fri, 16 Dec 2022 02:14:54 +0000 (18:14 -0800)
Reviewed By: aartbik, Peiming

Differential Revision: https://reviews.llvm.org/D140171

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

index 47993b0..70c74bf 100644 (file)
@@ -169,6 +169,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
 
     /// Returns the type for index storage based on indexBitWidth
     Type getIndexType() const;
+
+    /// Constructs a new encoding with the dimOrdering and higherOrdering
+    /// reset to the default/identity.
+    SparseTensorEncodingAttr withoutOrdering() const;
   }];
 
   let genVerifyDecl = 1;
index 06a451f..aecde7d 100644 (file)
@@ -57,6 +57,12 @@ Type SparseTensorEncodingAttr::getIndexType() const {
   return idxWidth ? IntegerType::get(getContext(), idxWidth) : indexType;
 }
 
+SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutOrdering() const {
+  return SparseTensorEncodingAttr::get(
+      getContext(), getDimLevelType(), AffineMap(), AffineMap(),
+      getPointerBitWidth(), getIndexBitWidth());
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   if (failed(parser.parseLess()))
     return {};
index 4f4dbe4..b4d1491 100644 (file)
@@ -529,9 +529,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
   assert(elemTp == dstTp.getElementType() &&
          "reshape should not change element type");
   // Start an iterator over the source tensor (in original index order).
-  auto noPerm = SparseTensorEncodingAttr::get(
-      op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
-      encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
+  const auto noPerm = encSrc.withoutOrdering();
   SmallVector<Value> srcDimSizes =
       getDimSizes(rewriter, loc, encSrc, srcTp, adaptor.getSrc());
   NewCallParams params(rewriter, loc);
@@ -596,9 +594,7 @@ static void genSparseCOOIterationLoop(
   Type elemTp = tensorTp.getElementType();
 
   // Start an iterator over the tensor (in original index order).
-  auto noPerm = SparseTensorEncodingAttr::get(
-      rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
-      enc.getPointerBitWidth(), enc.getIndexBitWidth());
+  const auto noPerm = enc.withoutOrdering();
   SmallVector<Value> dimSizes = getDimSizes(rewriter, loc, noPerm, tensorTp, t);
   Value iter = NewCallParams(rewriter, loc)
                    .genBuffers(noPerm, dimSizes, tensorTp)
@@ -1485,9 +1481,7 @@ public:
     auto encSrc = getSparseTensorEncoding(srcType);
     SmallVector<Value> dimSizes =
         getDimSizes(rewriter, loc, encSrc, srcType, src);
-    auto enc = SparseTensorEncodingAttr::get(
-        op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
-        encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
+    const auto enc = encSrc.withoutOrdering();
     Value coo = NewCallParams(rewriter, loc)
                     .genBuffers(enc, dimSizes, srcType)
                     .genNewCall(Action::kToCOO, src);