[mlir][sparse] Factoring out NewCallParams class in SparseTensorConversion.cpp
authorwren romano <2998727+wrengr@users.noreply.github.com>
Wed, 9 Nov 2022 00:43:44 +0000 (16:43 -0800)
committerwren romano <2998727+wrengr@users.noreply.github.com>
Wed, 9 Nov 2022 01:19:54 +0000 (17:19 -0800)
The new class helps encapsulate the arguments to `_mlir_ciface_newSparseTensor` so that client code doesn't depend on the details of the API.  (This makes way for the next differential which significantly alters the API.)

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137680

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

index f41a579..6ca6cfc 100644 (file)
@@ -70,16 +70,6 @@ static Value genDimSizeCall(OpBuilder &builder, Location loc,
       .getResult(0);
 }
 
-/// Generates a call into the "swiss army knife" method of the sparse runtime
-/// support library for materializing sparse tensors into the computation.
-static Value genNewCall(OpBuilder &builder, Location loc,
-                        ArrayRef<Value> params) {
-  StringRef name = "newSparseTensor";
-  Type pTp = getOpaquePointerType(builder);
-  return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
-      .getResult(0);
-}
-
 /// Compute the size from type (for static sizes) or from an already-converted
 /// opaque pointer source (for dynamic sizes) at the given dimension.
 static Value sizeFromPtrAtDim(OpBuilder &builder, Location loc,
@@ -168,41 +158,132 @@ static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) {
   return buffer;
 }
 
-/// Populates parameters required to call the "swiss army knife" method of the
-/// sparse runtime support library for materializing sparse tensors into the
-/// computation.
-static void newParams(OpBuilder &builder, SmallVector<Value, 8> &params,
-                      Location loc, ShapedType stp,
-                      SparseTensorEncodingAttr &enc, Action action,
-                      ValueRange szs, Value ptr = Value()) {
-  ArrayRef<DimLevelType> dlt = enc.getDimLevelType();
-  unsigned sz = dlt.size();
+/// This class abstracts over the API of `_mlir_ciface_newSparseTensor`:
+/// the "swiss army knife" method of the sparse runtime support library
+/// for materializing sparse tensors into the computation.  This abstraction
+/// reduces the need to make modifications to client code whenever that
+/// API changes.
+class NewCallParams final {
+public:
+  /// Allocates the `ValueRange` for the `func::CallOp` parameters,
+  /// but does not initialize them.
+  NewCallParams(OpBuilder &builder, Location loc)
+      : builder(builder), loc(loc), pTp(getOpaquePointerType(builder)) {}
+
+  /// Initializes all static parameters (i.e., those which indicate
+  /// type-level information such as the encoding and sizes), generating
+  /// MLIR buffers as needed, and returning `this` for method chaining.
+  /// This method does not set the action and pointer arguments, since
+  /// those are handled by `genNewCall` instead.
+  NewCallParams &genBuffers(SparseTensorEncodingAttr enc, ValueRange sizes,
+                            ShapedType stp);
+
+  /// (Re)sets the C++ template type parameters, and returns `this`
+  /// for method chaining.  This is already done as part of `genBuffers`,
+  /// but is factored out so that it can also be called independently
+  /// whenever subsequent `genNewCall` calls want to reuse the same
+  /// buffers but different type parameters.
+  //
+  // TODO: This is only ever used by sparse2sparse-viaCOO `ConvertOp`;
+  // is there a better way to handle that than this one-off setter method?
+  NewCallParams &setTemplateTypes(SparseTensorEncodingAttr enc,
+                                  ShapedType stp) {
+    params[kParamPtrTp] = constantPointerTypeEncoding(builder, loc, enc);
+    params[kParamIndTp] = constantIndexTypeEncoding(builder, loc, enc);
+    params[kParamValTp] =
+        constantPrimaryTypeEncoding(builder, loc, stp.getElementType());
+    return *this;
+  }
+
+  /// Checks whether all the static parameters have been initialized.
+  bool isInitialized() const {
+    for (unsigned i = 0; i < kNumStaticParams; ++i)
+      if (!params[i])
+        return false;
+    return true;
+  }
+
+  /// Gets the dimension-to-level mapping.
+  //
+  // TODO: This is only ever used for passing into `genAddEltCall`;
+  // is there a better way to encapsulate that pattern (both to avoid
+  // this one-off getter, and to avoid potential mixups)?
+  Value getDim2LvlMap() const {
+    assert(isInitialized() && "Must initialize before getDim2LvlMap");
+    return params[kParamDim2Lvl];
+  }
+
+  /// Generates a function call, with the current static parameters
+  /// and the given dynamic arguments.
+  Value genNewCall(Action action, Value ptr = Value()) {
+    assert(isInitialized() && "Must initialize before genNewCall");
+    StringRef name = "newSparseTensor";
+    params[kParamAction] = constantAction(builder, loc, action);
+    params[kParamPtr] = ptr ? ptr : builder.create<LLVM::NullOp>(loc, pTp);
+    return createFuncCall(builder, loc, name, pTp, params, EmitCInterface::On)
+        .getResult(0);
+  }
+
+private:
+  static constexpr unsigned kNumStaticParams = 6;
+  static constexpr unsigned kNumDynamicParams = 2;
+  static constexpr unsigned kNumParams = kNumStaticParams + kNumDynamicParams;
+  static constexpr unsigned kParamLvlTypes = 0;
+  static constexpr unsigned kParamDimSizes = 1;
+  static constexpr unsigned kParamDim2Lvl = 2;
+  static constexpr unsigned kParamPtrTp = 3;
+  static constexpr unsigned kParamIndTp = 4;
+  static constexpr unsigned kParamValTp = 5;
+  static constexpr unsigned kParamAction = 6;
+  static constexpr unsigned kParamPtr = 7;
+
+  OpBuilder &builder;
+  Location loc;
+  Type pTp;
+  Value params[kNumParams];
+};
+
+// TODO: see the note at `_mlir_ciface_newSparseTensor` about how
+// the meaning of the various arguments (e.g., "sizes" vs "shapes")
+// is inconsistent between the different actions.
+NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc,
+                                         ValueRange dimSizes, ShapedType stp) {
+  const unsigned lvlRank = enc.getDimLevelType().size();
+  const unsigned dimRank = stp.getRank();
   // Sparsity annotations.
-  SmallVector<Value, 4> attrs;
-  for (unsigned i = 0; i < sz; i++)
-    attrs.push_back(constantDimLevelTypeEncoding(builder, loc, dlt[i]));
-  params.push_back(genBuffer(builder, loc, attrs));
-  // Dimension sizes array of the enveloping tensor. Useful for either
+  SmallVector<Value, 4> lvlTypes;
+  for (auto dlt : enc.getDimLevelType())
+    lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt));
+  assert(lvlTypes.size() == lvlRank && "Level-rank mismatch");
+  params[kParamLvlTypes] = genBuffer(builder, loc, lvlTypes);
+  // Dimension-sizes array of the enveloping tensor.  Useful for either
   // verification of external data, or for construction of internal data.
-  params.push_back(genBuffer(builder, loc, szs));
-  // Dimension order permutation array. This is the "identity" permutation by
-  // default, or otherwise the "reverse" permutation of a given ordering, so
-  // that indices can be mapped quickly to the right position.
-  SmallVector<Value, 4> rev(sz);
-  for (unsigned i = 0; i < sz; i++)
-    rev[toOrigDim(enc, i)] = constantIndex(builder, loc, i);
-  params.push_back(genBuffer(builder, loc, rev));
+  assert(dimSizes.size() == dimRank && "Dimension-rank mismatch");
+  params[kParamDimSizes] = genBuffer(builder, loc, dimSizes);
+  // The dimension-to-level mapping.  We must preinitialize `dim2lvl`
+  // so that the true branch below can perform random-access `operator[]`
+  // assignment.
+  SmallVector<Value, 4> dim2lvl(dimRank);
+  auto dimOrder = enc.getDimOrdering();
+  if (dimOrder) {
+    assert(dimOrder.isPermutation());
+    for (unsigned l = 0; l < lvlRank; l++) {
+      // The `d`th source variable occurs in the `l`th result position.
+      uint64_t d = dimOrder.getDimPosition(l);
+      dim2lvl[d] = constantIndex(builder, loc, l);
+    }
+  } else {
+    assert(dimRank == lvlRank && "Rank mismatch");
+    for (unsigned i = 0; i < lvlRank; i++)
+      dim2lvl[i] = constantIndex(builder, loc, i);
+  }
+  params[kParamDim2Lvl] = genBuffer(builder, loc, dim2lvl);
   // Secondary and primary types encoding.
-  Type elemTp = stp.getElementType();
-  params.push_back(constantPointerTypeEncoding(builder, loc, enc));
-  params.push_back(constantIndexTypeEncoding(builder, loc, enc));
-  params.push_back(constantPrimaryTypeEncoding(builder, loc, elemTp));
-  // User action.
-  params.push_back(constantAction(builder, loc, action));
-  // Payload pointer.
-  if (!ptr)
-    ptr = builder.create<LLVM::NullOp>(loc, getOpaquePointerType(builder));
-  params.push_back(ptr);
+  setTemplateTypes(enc, stp);
+  // Finally, make note that initialization is complete.
+  assert(isInitialized() && "Initialization failed");
+  // And return `this` for method chaining.
+  return *this;
 }
 
 /// Generates a call to obtain the values array.
@@ -387,14 +468,12 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
       op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
       encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
   SmallVector<Value, 4> srcSizes;
-  SmallVector<Value, 8> params;
   sizesFromPtr(rewriter, srcSizes, loc, encSrc, srcTp, adaptor.getSrc());
-  newParams(rewriter, params, loc, srcTp, noPerm, Action::kToIterator, srcSizes,
-            adaptor.getSrc());
-  Value iter = genNewCall(rewriter, loc, params);
+  NewCallParams params(rewriter, loc);
+  Value iter = params.genBuffers(noPerm, srcSizes, srcTp)
+                   .genNewCall(Action::kToIterator, adaptor.getSrc());
   // Start a new COO for the destination tensor.
   SmallVector<Value, 4> dstSizes;
-  params.clear();
   if (dstTp.hasStaticShape()) {
     sizesFromType(rewriter, dstSizes, loc, dstTp);
   } else {
@@ -402,9 +481,9 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
     genReshapeDstShape(loc, rewriter, dstSizes, srcSizes, dstShape,
                        op.getReassociationIndices());
   }
-  newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, dstSizes);
-  Value coo = genNewCall(rewriter, loc, params);
-  Value dstPerm = params[2];
+  Value coo =
+      params.genBuffers(encDst, dstSizes, dstTp).genNewCall(Action::kEmptyCOO);
+  Value dstPerm = params.getDim2LvlMap();
   // Construct a while loop over the iterator.
   Value srcIdx = genAlloca(rewriter, loc, srcRank, rewriter.getIndexType());
   Value dstIdx = genAlloca(rewriter, loc, dstRank, rewriter.getIndexType());
@@ -426,9 +505,7 @@ genSparse2SparseReshape(ReshapeOp op, typename ReshapeOp::Adaptor adaptor,
   rewriter.create<scf::YieldOp>(loc);
   // Final call to construct sparse tensor storage and free temporary resources.
   rewriter.setInsertionPointAfter(whileOp);
-  params[6] = constantAction(rewriter, loc, Action::kFromCOO);
-  params[7] = coo;
-  Value dst = genNewCall(rewriter, loc, params);
+  Value dst = params.genNewCall(Action::kFromCOO, coo);
   genDelCOOCall(rewriter, loc, elemTp, coo);
   genDelIteratorCall(rewriter, loc, elemTp, iter);
   rewriter.replaceOp(op, dst);
@@ -458,11 +535,10 @@ static void genSparseCOOIterationLoop(
       rewriter.getContext(), enc.getDimLevelType(), AffineMap(), AffineMap(),
       enc.getPointerBitWidth(), enc.getIndexBitWidth());
   SmallVector<Value, 4> sizes;
-  SmallVector<Value, 8> params;
   sizesFromPtr(rewriter, sizes, loc, noPerm, tensorTp, t);
-  newParams(rewriter, params, loc, tensorTp, noPerm, Action::kToIterator, sizes,
-            t);
-  Value iter = genNewCall(rewriter, loc, params);
+  Value iter = NewCallParams(rewriter, loc)
+                   .genBuffers(noPerm, sizes, tensorTp)
+                   .genNewCall(Action::kToIterator, t);
 
   // Construct a while loop over the iterator.
   Value srcIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
@@ -611,12 +687,12 @@ public:
     // Generate the call to construct tensor from ptr. The sizes are
     // inferred from the result type of the new operator.
     SmallVector<Value, 4> sizes;
-    SmallVector<Value, 8> params;
     ShapedType stp = resType.cast<ShapedType>();
     sizesFromType(rewriter, sizes, loc, stp);
     Value ptr = adaptor.getOperands()[0];
-    newParams(rewriter, params, loc, stp, enc, Action::kFromFile, sizes, ptr);
-    rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
+    rewriter.replaceOp(op, NewCallParams(rewriter, loc)
+                               .genBuffers(enc, sizes, stp)
+                               .genNewCall(Action::kFromFile, ptr));
     return success();
   }
 };
@@ -650,10 +726,10 @@ public:
     }
     // Generate the call to construct empty tensor. The sizes are
     // explicitly defined by the arguments to the alloc operator.
-    SmallVector<Value, 8> params;
-    ShapedType stp = resType.cast<ShapedType>();
-    newParams(rewriter, params, loc, stp, enc, Action::kEmpty, sizes);
-    rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
+    rewriter.replaceOp(op,
+                       NewCallParams(rewriter, loc)
+                           .genBuffers(enc, sizes, resType.cast<ShapedType>())
+                           .genNewCall(Action::kEmpty));
     return success();
   }
 };
@@ -690,7 +766,7 @@ public:
         return success();
       }
       SmallVector<Value, 4> sizes;
-      SmallVector<Value, 8> params;
+      NewCallParams params(rewriter, loc);
       ShapedType stp = srcType.cast<ShapedType>();
       sizesFromPtr(rewriter, sizes, loc, encSrc, stp, src);
       bool useDirectConversion;
@@ -708,9 +784,8 @@ public:
         break;
       }
       if (useDirectConversion) {
-        newParams(rewriter, params, loc, stp, encDst, Action::kSparseToSparse,
-                  sizes, src);
-        rewriter.replaceOp(op, genNewCall(rewriter, loc, params));
+        rewriter.replaceOp(op, params.genBuffers(encDst, sizes, stp)
+                                   .genNewCall(Action::kSparseToSparse, src));
       } else { // use via-COO conversion.
         // Set up encoding with right mix of src and dst so that the two
         // method calls can share most parameters, while still providing
@@ -719,13 +794,13 @@ public:
             op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
             encDst.getHigherOrdering(), encSrc.getPointerBitWidth(),
             encSrc.getIndexBitWidth());
-        newParams(rewriter, params, loc, stp, enc, Action::kToCOO, sizes, src);
-        Value coo = genNewCall(rewriter, loc, params);
-        params[3] = constantPointerTypeEncoding(rewriter, loc, encDst);
-        params[4] = constantIndexTypeEncoding(rewriter, loc, encDst);
-        params[6] = constantAction(rewriter, loc, Action::kFromCOO);
-        params[7] = coo;
-        Value dst = genNewCall(rewriter, loc, params);
+        // TODO: This is the only place where `kToCOO` (or `kToIterator`)
+        // is called with a non-identity permutation.  Is there any clean
+        // way to push the permutation over to the `kFromCOO` side instead?
+        Value coo =
+            params.genBuffers(enc, sizes, stp).genNewCall(Action::kToCOO, src);
+        Value dst = params.setTemplateTypes(encDst, stp)
+                        .genNewCall(Action::kFromCOO, coo);
         genDelCOOCall(rewriter, loc, stp.getElementType(), coo);
         rewriter.replaceOp(op, dst);
       }
@@ -743,7 +818,7 @@ public:
       RankedTensorType srcTensorTp = srcType.cast<RankedTensorType>();
       unsigned rank = dstTensorTp.getRank();
       Type elemTp = dstTensorTp.getElementType();
-      // Fabricate a no-permutation encoding for newParams().
+      // Fabricate a no-permutation encoding for NewCallParams
       // The pointer/index types must be those of `src`.
       // The dimLevelTypes aren't actually used by Action::kToIterator.
       encDst = SparseTensorEncodingAttr::get(
@@ -751,11 +826,10 @@ public:
           SmallVector<DimLevelType>(rank, DimLevelType::Dense), AffineMap(),
           AffineMap(), encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
       SmallVector<Value, 4> sizes;
-      SmallVector<Value, 8> params;
       sizesFromPtr(rewriter, sizes, loc, encSrc, srcTensorTp, src);
-      newParams(rewriter, params, loc, dstTensorTp, encDst, Action::kToIterator,
-                sizes, src);
-      Value iter = genNewCall(rewriter, loc, params);
+      Value iter = NewCallParams(rewriter, loc)
+                       .genBuffers(encDst, sizes, dstTensorTp)
+                       .genNewCall(Action::kToIterator, src);
       Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
       Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
       Block *insertionBlock = rewriter.getInsertionBlock();
@@ -817,12 +891,12 @@ public:
     ShapedType stp = resType.cast<ShapedType>();
     unsigned rank = stp.getRank();
     SmallVector<Value, 4> sizes;
-    SmallVector<Value, 8> params;
     sizesFromSrc(rewriter, sizes, loc, src);
-    newParams(rewriter, params, loc, stp, encDst, Action::kEmptyCOO, sizes);
-    Value coo = genNewCall(rewriter, loc, params);
+    NewCallParams params(rewriter, loc);
+    Value coo =
+        params.genBuffers(encDst, sizes, stp).genNewCall(Action::kEmptyCOO);
     Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
-    Value perm = params[2];
+    Value perm = params.getDim2LvlMap();
     Type eltType = stp.getElementType();
     Value elemPtr = genAllocaScalar(rewriter, loc, eltType);
     genDenseTensorOrSparseConstantIterLoop(
@@ -836,9 +910,7 @@ public:
           genAddEltCall(builder, loc, eltType, coo, elemPtr, ind, perm);
         });
     // Final call to construct sparse tensor storage.
-    params[6] = constantAction(rewriter, loc, Action::kFromCOO);
-    params[7] = coo;
-    Value dst = genNewCall(rewriter, loc, params);
+    Value dst = params.genNewCall(Action::kFromCOO, coo);
     genDelCOOCall(rewriter, loc, eltType, coo);
     rewriter.replaceOp(op, dst);
     return success();
@@ -1117,15 +1189,15 @@ public:
     Value offset = constantIndex(rewriter, loc, 0);
 
     SmallVector<Value, 4> sizes;
-    SmallVector<Value, 8> params;
+    NewCallParams params(rewriter, loc);
     concatSizesFromInputs(rewriter, sizes, loc, dstTp, op.getInputs(),
                           concatDim);
 
     if (encDst) {
       // Start a new COO for the destination tensor.
-      newParams(rewriter, params, loc, dstTp, encDst, Action::kEmptyCOO, sizes);
-      dst = genNewCall(rewriter, loc, params);
-      dstPerm = params[2];
+      dst =
+          params.genBuffers(encDst, sizes, dstTp).genNewCall(Action::kEmptyCOO);
+      dstPerm = params.getDim2LvlMap();
       elemPtr = genAllocaScalar(rewriter, loc, elemTp);
       dstIdx = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
     } else {
@@ -1188,11 +1260,9 @@ public:
       offset = rewriter.create<arith::AddIOp>(loc, offset, curDim);
     }
     if (encDst) {
-      params[6] = constantAction(rewriter, loc, Action::kFromCOO);
       // In sparse output case, the destination holds the COO.
       Value coo = dst;
-      params[7] = coo;
-      dst = genNewCall(rewriter, loc, params);
+      dst = params.genNewCall(Action::kFromCOO, coo);
       // Release resources.
       genDelCOOCall(rewriter, loc, elemTp, coo);
       rewriter.replaceOp(op, dst);
@@ -1216,27 +1286,25 @@ public:
     Value src = adaptor.getOperands()[0];
     auto encSrc = getSparseTensorEncoding(srcType);
     SmallVector<Value, 4> sizes;
-    SmallVector<Value, 8> params;
     sizesFromPtr(rewriter, sizes, loc, encSrc, srcType, src);
     auto enc = SparseTensorEncodingAttr::get(
         op->getContext(), encSrc.getDimLevelType(), AffineMap(), AffineMap(),
         encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
-    newParams(rewriter, params, loc, srcType, enc, Action::kToCOO, sizes, src);
-    Value coo = genNewCall(rewriter, loc, params);
+    Value coo = NewCallParams(rewriter, loc)
+                    .genBuffers(enc, sizes, srcType)
+                    .genNewCall(Action::kToCOO, src);
     // Then output the tensor to external file with indices in the externally
     // visible lexicographic index order. A sort is required if the source was
     // not in that order yet (note that the sort can be dropped altogether if
     // external format does not care about the order at all, but here we assume
     // it does).
-    bool sort =
-        encSrc.getDimOrdering() && !encSrc.getDimOrdering().isIdentity();
-    params.clear();
-    params.push_back(coo);
-    params.push_back(adaptor.getOperands()[1]);
-    params.push_back(constantI1(rewriter, loc, sort));
+    Value sort = constantI1(rewriter, loc,
+                            encSrc.getDimOrdering() &&
+                                !encSrc.getDimOrdering().isIdentity());
+    SmallVector<Value, 3> outParams{coo, adaptor.getOperands()[1], sort};
     Type eltType = srcType.getElementType();
     SmallString<18> name{"outSparseTensor", primaryTypeFunctionSuffix(eltType)};
-    createFuncCall(rewriter, loc, name, {}, params, EmitCInterface::Off);
+    createFuncCall(rewriter, loc, name, {}, outParams, EmitCInterface::Off);
     genDelCOOCall(rewriter, loc, eltType, coo);
     rewriter.eraseOp(op);
     return success();