[mlir][sparse] Refactor the code that reshapes the values buffer for annotated all...
authorbixia1 <bixia@google.com>
Wed, 11 Jan 2023 16:59:00 +0000 (08:59 -0800)
committerbixia1 <bixia@google.com>
Thu, 12 Jan 2023 00:02:46 +0000 (16:02 -0800)
Move the functionality to codegen utils for sharing with the codegen path.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D141514

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

index ff53640..03d9e6b 100644 (file)
@@ -526,6 +526,38 @@ void sparse_tensor::foreachInSparseConstant(
   }
 }
 
+void sparse_tensor::storeIndices(OpBuilder &builder, Location loc,
+                                 unsigned rank, Value ind, ValueRange ivs,
+                                 unsigned offsetDim, Value offset) {
+  for (unsigned i = 0; i < rank; i++) {
+    Value idx = ivs[i];
+    if (offsetDim == i && offset)
+      idx = builder.create<arith::AddIOp>(loc, idx, offset);
+    builder.create<memref::StoreOp>(loc, idx, ind,
+                                    constantIndex(builder, loc, i));
+  }
+}
+
+Value sparse_tensor::reshapeValuesToLevels(
+    OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc,
+    const SmallVectorImpl<Value> &dimSizes, Value valuesBuffer,
+    Value idxBuffer) {
+  // Use the dstIdx to store the level sizes.
+  unsigned rank = enc.getDimLevelType().size();
+  SmallVector<Value> lvlSizes;
+  for (unsigned i = 0; i < dimSizes.size(); i++)
+    lvlSizes.push_back(dimSizes[toOrigDim(enc, i)]);
+  storeIndices(builder, loc, rank, idxBuffer, lvlSizes);
+  // The memref ReshapeOp requires the sizes buffer to have a static
+  // shape.
+  idxBuffer = builder.create<memref::CastOp>(
+      loc, MemRefType::get({rank}, builder.getIndexType()), idxBuffer);
+  SmallVector<int64_t> shape(rank, ShapedType::kDynamic);
+  Type elemTp = valuesBuffer.getType().cast<MemRefType>().getElementType();
+  return builder.create<memref::ReshapeOp>(loc, MemRefType::get(shape, elemTp),
+                                           valuesBuffer, idxBuffer);
+}
+
 Value sparse_tensor::genToPointers(OpBuilder &builder, Location loc,
                                    Value tensor, uint64_t d) {
   RankedTensorType srcTp = tensor.getType().cast<RankedTensorType>();
index 1c8cad5..2ff60eb 100644 (file)
@@ -217,6 +217,20 @@ void foreachInSparseConstant(
     Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
     function_ref<void(ArrayRef<Value>, Value)> callback);
 
+/// Converts the vector indices and store it into the memory pointed by
+/// `ind`, apply (optional) `offset` on `offsetDim`.
+void storeIndices(OpBuilder &builder, Location loc, unsigned rank, Value ind,
+                  ValueRange ivs, unsigned offsetDim = 0,
+                  Value offset = Value());
+
+/// Reshapes the linear values buffer for an annotated all dense sparse tensor
+/// to match the shape of the corresponding dense tensor to support direct
+/// access of the buffer through indices.
+Value reshapeValuesToLevels(OpBuilder &builder, Location loc,
+                            SparseTensorEncodingAttr enc,
+                            const SmallVectorImpl<Value> &dimSizes,
+                            Value valuesBuffer, Value idxBuffer);
+
 //===----------------------------------------------------------------------===//
 // Inlined constant generators.
 //
index bf61164..aaeb041 100644 (file)
@@ -428,20 +428,6 @@ static SmallVector<Value> loadIndices(OpBuilder &builder, Location loc,
   return ivs;
 }
 
-/// Converts the vector indices and store it into the memory pointed by
-/// `ind`, apply (optional) `offset` on `offsetDim`.
-static void storeIndices(OpBuilder &builder, Location loc, unsigned rank,
-                         Value ind, ValueRange ivs, unsigned offsetDim = 0,
-                         Value offset = Value()) {
-  for (unsigned i = 0; i < rank; i++) {
-    Value idx = ivs[i];
-    if (offsetDim == i && offset)
-      idx = builder.create<arith::AddIOp>(loc, idx, offset);
-    builder.create<memref::StoreOp>(loc, idx, ind,
-                                    constantIndex(builder, loc, i));
-  }
-}
-
 /// Inserts a value stored in `elemPtr` into a dense tensor created by
 /// allocDenseTensor().
 static void insertScalarIntoDenseTensor(OpBuilder &builder, Location loc,
@@ -1375,19 +1361,8 @@ public:
         dst = genValuesCall(rewriter, loc,
                             MemRefType::get({ShapedType::kDynamic}, elemTp),
                             {dst});
-
         // Use the dstIdx to store the level sizes.
-        SmallVector<Value> lvlSizes;
-        for (unsigned i = 0; i < sizes.size(); i++)
-          lvlSizes.push_back(sizes[toOrigDim(encDst, i)]);
-        storeIndices(rewriter, loc, rank, dstIdx, lvlSizes);
-        // The memref ReshapeOp requires the sizes buffer to have a static
-        // shape.
-        Value typedBuffer = rewriter.create<memref::CastOp>(
-            loc, MemRefType::get({rank}, rewriter.getIndexType()), dstIdx);
-        SmallVector<int64_t> shape(rank, ShapedType::kDynamic);
-        dst = rewriter.create<memref::ReshapeOp>(
-            loc, MemRefType::get(shape, elemTp), dst, typedBuffer);
+        dst = reshapeValuesToLevels(rewriter, loc, encDst, sizes, dst, dstIdx);
       } else {
         dstPerm = params.getDim2LvlMap();
         elemPtr = genAllocaScalar(rewriter, loc, elemTp);