[mlir][sparse] Moving/renaming genBuffer to allocaBuffer
authorwren romano <2998727+wrengr@users.noreply.github.com>
Wed, 14 Dec 2022 20:58:50 +0000 (12:58 -0800)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Wed, 14 Dec 2022 22:29:36 +0000 (14:29 -0800)
This allows allocaBuffer to be used outside of SparseTensorConversion.cpp, which will be helpful for a some future commits.

Reviewed By: aartbik, Peiming

Differential Revision: https://reviews.llvm.org/D140047

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

index ede9c56..b5c82bd 100644 (file)
@@ -1151,6 +1151,18 @@ Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
   return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
 }
 
+Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
+                                        ValueRange values) {
+  const unsigned sz = values.size();
+  assert(sz >= 1);
+  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
+  for (unsigned i = 0; i < sz; i++) {
+    Value idx = constantIndex(builder, loc, i);
+    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
+  }
+  return buffer;
+}
+
 Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
                                             RankedTensorType tensorTp,
                                             ValueRange sizes) {
index 4d5805e..a121522 100644 (file)
@@ -136,6 +136,11 @@ Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp);
 /// of the given type, and returns the `memref<$tp>`.
 Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp);
 
+/// Generates a temporary buffer, initializes it with the given contents,
+/// and returns it as type `memref<? x $tp>` (rather than specifying the
+/// size of the buffer).
+Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values);
+
 /// Generates code to allocate a buffer of the given type, and zero
 /// initialize it.  If the buffer type has any dynamic sizes, then the
 /// `sizes` parameter should be as filled by sizesFromPtr(); that way
index 2603b8d..4f4dbe4 100644 (file)
@@ -207,18 +207,6 @@ static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) {
   return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{sz});
 }
 
-/// Generates a temporary buffer of the given type and given contents.
-static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
-  unsigned sz = values.size();
-  assert(sz >= 1);
-  Value buffer = genAlloca(builder, loc, sz, values[0].getType());
-  for (unsigned i = 0; i < sz; i++) {
-    Value idx = constantIndex(builder, loc, i);
-    builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
-  }
-  return buffer;
-}
-
 /// Generates a temporary buffer for the level-types of the given encoding.
 static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
                                SparseTensorEncodingAttr enc) {
@@ -227,7 +215,7 @@ static Value genLvlTypesBuffer(OpBuilder &builder, Location loc,
   lvlTypes.reserve(dlts.size());
   for (auto dlt : dlts)
     lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
-  return genBuffer(builder, loc, lvlTypes);
+  return allocaBuffer(builder, loc, lvlTypes);
 }
 
 /// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
@@ -329,7 +317,7 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
   // Dimension-sizes array of the enveloping tensor.  Useful for either
   // verification of external data, or for construction of internal data.
   assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
-  params[kParamDimSizes] = genBuffer(builder, loc, dimSizes);
+  params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizes);
   // The level-sizes array must be passed as well, since for arbitrary
   // dim2lvl mappings it cannot be trivially reconstructed at runtime.
   // For now however, since we're still assuming permutations, we will
@@ -358,10 +346,10 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
       lvlSizes[i] = dimSizes[i];
     }
   }
-  params[kParamLvlSizes] = genBuffer(builder, loc, lvlSizes);
-  params[kParamLvl2Dim] = genBuffer(builder, loc, lvl2dim);
+  params[kParamLvlSizes] = allocaBuffer(builder, loc, lvlSizes);
+  params[kParamLvl2Dim] = allocaBuffer(builder, loc, lvl2dim);
   params[kParamDim2Lvl] =
-      dimOrder ? genBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim];
+      dimOrder ? allocaBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim];
   // Secondary and primary types encoding.
   setTemplateTypes(enc, stp);
   // Finally, make note that initialization is complete.
@@ -780,7 +768,7 @@ public:
     // Construct the dimShape.
     const auto dimShape = stp.getShape();
     SmallVector<Value> dimShapeValues = getDimShape(rewriter, loc, stp);
-    Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues);
+    Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues);
     // Allocate `SparseTensorReader` and perform all initial setup that
     // does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc).
     Type opaqueTp = getOpaquePointerType(rewriter);
@@ -833,9 +821,9 @@ public:
                 ? rewriter.create<memref::LoadOp>(loc, dimSizesBuffer, dim)
                 : dimShapeValues[d];
       }
-      lvlSizesBuffer = genBuffer(rewriter, loc, lvlSizeValues);
-      lvl2dimBuffer = genBuffer(rewriter, loc, lvl2dimValues);
-      dim2lvlBuffer = genBuffer(rewriter, loc, dim2lvlValues);
+      lvlSizesBuffer = allocaBuffer(rewriter, loc, lvlSizeValues);
+      lvl2dimBuffer = allocaBuffer(rewriter, loc, lvl2dimValues);
+      dim2lvlBuffer = allocaBuffer(rewriter, loc, dim2lvlValues);
     } else {
       assert(dimRank == lvlRank && "Rank mismatch");
       SmallVector<Value> iotaValues;
@@ -843,7 +831,7 @@ public:
       for (unsigned i = 0; i < lvlRank; i++)
         iotaValues.push_back(constantIndex(rewriter, loc, i));
       lvlSizesBuffer = dimSizesBuffer ? dimSizesBuffer : dimShapeBuffer;
-      dim2lvlBuffer = lvl2dimBuffer = genBuffer(rewriter, loc, iotaValues);
+      dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(rewriter, loc, iotaValues);
     }
     // Use the `reader` to parse the file.
     SmallVector<Value, 8> params{