[mlir][sparse] Factoring out allocaIndices()
authorwren romano <2998727+wrengr@users.noreply.github.com>
Fri, 1 Oct 2021 00:47:43 +0000 (17:47 -0700)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Fri, 1 Oct 2021 21:18:56 +0000 (14:18 -0700)
This is preliminary work towards D110790. Depends On D110882.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D110883

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

index b501593..154fe3f 100644 (file)
@@ -295,6 +295,18 @@ static Value genIndexAndValueForSparse(ConversionPatternRewriter &rewriter,
   return rewriter.create<tensor::ExtractOp>(loc, values, ivs[0]);
 }
 
+/// Generates code to stack-allocate a `memref<?xindex>` where the `?`
+/// is the given `rank`.  This array is intended to serve as a reusable
+/// buffer for storing the indices of a single tensor element, to avoid
+/// allocation in the body of loops.
+static Value allocaIndices(ConversionPatternRewriter &rewriter, Location loc,
+                           int64_t rank) {
+  auto indexTp = rewriter.getIndexType();
+  auto memTp = MemRefType::get({ShapedType::kDynamicSize}, indexTp);
+  Value arg = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(rank));
+  return rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+}
+
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
@@ -413,13 +425,9 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
     // loop is generated by genAddElt().
     Location loc = op->getLoc();
     ShapedType shape = resType.cast<ShapedType>();
-    auto memTp =
-        MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType());
     Value perm;
     Value ptr = genNewCall(rewriter, op, encDst, 2, perm);
-    Value arg = rewriter.create<ConstantOp>(
-        loc, rewriter.getIndexAttr(shape.getRank()));
-    Value ind = rewriter.create<memref::AllocaOp>(loc, memTp, ValueRange{arg});
+    Value ind = allocaIndices(rewriter, loc, shape.getRank());
     SmallVector<Value> lo;
     SmallVector<Value> hi;
     SmallVector<Value> st;