[mlir][sparse] Add a helper class to help lowering operations with/without function...
authorPeiming Liu <peiming@google.com>
Fri, 12 May 2023 20:33:49 +0000 (20:33 +0000)
committerPeiming Liu <peiming@google.com>
Tue, 16 May 2023 17:22:44 +0000 (17:22 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D150477

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

index 9e76289..caa549e 100644 (file)
@@ -72,6 +72,72 @@ StringRef primaryTypeFunctionSuffix(Type elemTp);
 // Misc code generators and utilities.
 //===----------------------------------------------------------------------===//
 
+/// A helper class to simplify lowering operations with/without function calls.
+template <class SubClass>
+class FuncCallOrInlineGenerator {
+public:
+  FuncCallOrInlineGenerator(TypeRange retTypes, ValueRange params, bool genCall)
+      : retTypes(retTypes), params(params), genCall(genCall) {}
+
+  // The main API invoked by clients, which abstracts away the details of
+  // creating function calls from clients.
+  SmallVector<Value> genCallOrInline(OpBuilder &builder, Location loc) {
+    if (!genCall)
+      return genImplementation(retTypes, params, builder, loc);
+
+    // Looks up the function.
+    std::string funcName = getMangledFuncName();
+    ModuleOp module = getParentOpOf<ModuleOp>(builder);
+    MLIRContext *context = module.getContext();
+    auto result = SymbolRefAttr::get(context, funcName);
+    auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
+
+    if (!func) {
+      // Create the function if not already exist.
+      OpBuilder::InsertionGuard insertionGuard(builder);
+      builder.setInsertionPoint(getParentOpOf<func::FuncOp>(builder));
+      func = builder.create<func::FuncOp>(
+          loc, funcName,
+          FunctionType::get(context, params.getTypes(), retTypes));
+      func.setPrivate();
+      // Set the insertion point to the body of the function.
+      Block *entryBB = func.addEntryBlock();
+      builder.setInsertionPointToStart(entryBB);
+      ValueRange args = entryBB->getArguments();
+      // Delegates to user to generate the actually implementation.
+      SmallVector<Value> result =
+          genImplementation(retTypes, args, builder, loc);
+      builder.create<func::ReturnOp>(loc, result);
+    }
+    // Returns the CallOp result.
+    func::CallOp call = builder.create<func::CallOp>(loc, func, params);
+    return call.getResults();
+  }
+
+private:
+  template <class OpTp>
+  OpTp getParentOpOf(OpBuilder &builder) {
+    return builder.getInsertionBlock()->getParent()->getParentOfType<OpTp>();
+  }
+
+  // CRTP: get the mangled function name (only called when genCall=true).
+  std::string getMangledFuncName() {
+    return static_cast<SubClass *>(this)->getMangledFuncName();
+  }
+
+  // CRTP: Client implementation.
+  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange params,
+                                       OpBuilder &builder, Location loc) {
+    return static_cast<SubClass *>(this)->genImplementation(retTypes, params,
+                                                            builder, loc);
+  }
+
+private:
+  TypeRange retTypes; // The types of all returned results
+  ValueRange params;  // The values of all input parameters
+  bool genCall;       // Should the implemetantion be wrapped in a function
+};
+
 /// Add type casting between arith and index types when needed.
 Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy);
 
index 88f79bf..e729f72 100644 (file)
@@ -42,8 +42,6 @@ namespace {
 using FuncGeneratorType =
     function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, RankedTensorType)>;
 
-static constexpr const char kInsertFuncNamePrefix[] = "_insert_";
-
 //===----------------------------------------------------------------------===//
 // Helper methods.
 //===----------------------------------------------------------------------===//
@@ -396,134 +394,102 @@ static Value genCompressed(OpBuilder &builder, Location loc,
   return ifOp2.getResult(o);
 }
 
-/// Generates code along an insertion path without the need for a "cursor".
-/// This current insertion strategy comes at the expense of some testing
-/// overhead for each insertion. The strategy will be optimized later for
-/// common insertion patterns. The current insertion strategy also assumes
-/// insertions occur in "a reasonable order" that enables building the
-/// storage scheme in an appending/inserting kind of fashion (i.e. no
-/// in-between insertions that need data movement). The implementation
-/// relies on CSE/DCE to clean up all bookkeeping that is not needed.
-///
-/// TODO: better unord/not-unique; also generalize, optimize, specialize!
-///
-static void genInsertBody(OpBuilder &builder, ModuleOp module,
-                          func::FuncOp func, RankedTensorType rtp) {
-  const OpBuilder::InsertionGuard insertionGuard(builder);
-  Block *const entryBlock = func.addEntryBlock();
-  builder.setInsertionPointToStart(entryBlock);
-  const ValueRange args = entryBlock->getArguments();
-  const Location loc = func.getLoc();
-  const SparseTensorType stt(rtp);
-  const Level lvlRank = stt.getLvlRank();
-
-  // Extract fields and coordinates from args.
-  SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
-  MutSparseTensorDescriptor desc(rtp, fields);
-  const SmallVector<Value> coords =
-      llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
-  Value value = args.back();
-  Value parentPos = constantZero(builder, loc, builder.getIndexType());
-  // Generate code for every level.
-  for (Level l = 0; l < lvlRank; l++) {
-    const auto dlt = stt.getLvlType(l);
-    if (isCompressedDLT(dlt)) {
-      // Create:
-      //   if (!present) {
-      //     coordinates[l].push_back(coords[l])
-      //     <update positions and prepare level l + 1>
-      //   }
-      //   positions[l] = coordinates.size() - 1
-      //   <insert @ positions[l] at next level l + 1>
-      parentPos =
-          genCompressed(builder, loc, desc, coords, value, parentPos, l);
-    } else if (isSingletonDLT(dlt)) {
-      // Create:
-      //   coordinates[l].push_back(coords[l])
-      //   positions[l] = positions[l-1]
-      //   <insert @ positions[l] at next level l + 1>
-      createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l,
-                     coords[l]);
-    } else {
-      assert(isDenseDLT(dlt));
-      // Construct the new position as:
-      //   positions[l] = size * positions[l-1] + coords[l]
-      //   <insert @ positions[l] at next level l + 1>
-      Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
-      Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
-      parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
+/// Helper class to help lowering sparse_tensor.insert operation.
+class SparseInsertGenerator
+    : public FuncCallOrInlineGenerator<SparseInsertGenerator> {
+public:
+  SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params,
+                        bool genCall)
+      : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){};
+
+  /// Generates code along an insertion path without the need for a "cursor".
+  /// This current insertion strategy comes at the expense of some testing
+  /// overhead for each insertion. The strategy will be optimized later for
+  /// common insertion patterns. The current insertion strategy also assumes
+  /// insertions occur in "a reasonable order" that enables building the
+  /// storage scheme in an appending/inserting kind of fashion (i.e. no
+  /// in-between insertions that need data movement). The implementation
+  /// relies on CSE/DCE to clean up all bookkeeping that is not needed.
+  ///
+  /// TODO: better unord/not-unique; also generalize, optimize, specialize!
+  SmallVector<Value> genImplementation(TypeRange retTypes, ValueRange args,
+                                       OpBuilder &builder, Location loc) {
+    const SparseTensorType stt(rtp.cast<RankedTensorType>());
+    const Level lvlRank = stt.getLvlRank();
+    // Extract fields and coordinates from args.
+    SmallVector<Value> fields = llvm::to_vector(args.drop_back(lvlRank + 1));
+    MutSparseTensorDescriptor desc(stt, fields);
+    const SmallVector<Value> coords =
+        llvm::to_vector(args.take_back(lvlRank + 1).drop_back());
+    Value value = args.back();
+    Value parentPos = constantZero(builder, loc, builder.getIndexType());
+    // Generate code for every level.
+    for (Level l = 0; l < lvlRank; l++) {
+      const auto dlt = stt.getLvlType(l);
+      if (isCompressedDLT(dlt)) {
+        // Create:
+        //   if (!present) {
+        //     coordinates[l].push_back(coords[l])
+        //     <update positions and prepare level l + 1>
+        //   }
+        //   positions[l] = coordinates.size() - 1
+        //   <insert @ positions[l] at next level l + 1>
+        parentPos =
+            genCompressed(builder, loc, desc, coords, value, parentPos, l);
+      } else if (isSingletonDLT(dlt)) {
+        // Create:
+        //   coordinates[l].push_back(coords[l])
+        //   positions[l] = positions[l-1]
+        //   <insert @ positions[l] at next level l + 1>
+        createPushback(builder, loc, desc, SparseTensorFieldKind::CrdMemRef, l,
+                       coords[l]);
+      } else {
+        assert(isDenseDLT(dlt));
+        // Construct the new position as:
+        //   positions[l] = size * positions[l-1] + coords[l]
+        //   <insert @ positions[l] at next level l + 1>
+        Value size = sizeFromTensorAtLvl(builder, loc, desc, l);
+        Value mult = builder.create<arith::MulIOp>(loc, size, parentPos);
+        parentPos = builder.create<arith::AddIOp>(loc, mult, coords[l]);
+      }
     }
+    // Reached the actual value append/insert.
+    if (!stt.isDenseLvl(lvlRank - 1))
+      createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
+                     std::nullopt, value);
+    else
+      genStore(builder, loc, value, desc.getValMemRef(), parentPos);
+    return fields;
   }
-  // Reached the actual value append/insert.
-  if (!stt.isDenseLvl(lvlRank - 1))
-    createPushback(builder, loc, desc, SparseTensorFieldKind::ValMemRef,
-                   std::nullopt, value);
-  else
-    genStore(builder, loc, value, desc.getValMemRef(), parentPos);
-  builder.create<func::ReturnOp>(loc, fields);
-}
 
-/// Generates a call to a function to perform an insertion operation. If the
-/// function doesn't exist yet, call `createFunc` to generate the function.
-static void genInsertionCallHelper(OpBuilder &builder,
-                                   MutSparseTensorDescriptor desc,
-                                   SmallVectorImpl<Value> &lcvs, Value value,
-                                   func::FuncOp insertPoint,
-                                   StringRef namePrefix,
-                                   FuncGeneratorType createFunc) {
-  // The mangled name of the function has this format:
-  //   <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
-  const SparseTensorType stt(desc.getRankedTensorType());
-  SmallString<32> nameBuffer;
-  llvm::raw_svector_ostream nameOstream(nameBuffer);
-  nameOstream << namePrefix;
-  const Level lvlRank = stt.getLvlRank();
-  assert(lcvs.size() == static_cast<size_t>(lvlRank));
-  for (Level l = 0; l < lvlRank; l++)
-    nameOstream << toMLIRString(stt.getLvlType(l)) << "_";
-  // Static dim sizes are used in the generated code while dynamic sizes are
-  // loaded from the dimSizes buffer. This is the reason for adding the shape
-  // to the function name.
-  for (const auto sh : stt.getDimShape())
-    nameOstream << sh << "_";
-  // Permutation information is also used in generating insertion.
-  if (!stt.isIdentity())
-    nameOstream << stt.getDimToLvlMap() << "_";
-  nameOstream << stt.getElementType() << "_";
-  nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
-
-  // Look up the function.
-  ModuleOp module = insertPoint->getParentOfType<ModuleOp>();
-  MLIRContext *context = module.getContext();
-  auto result = SymbolRefAttr::get(context, nameOstream.str());
-  auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
-
-  // Construct operands: fields, coords, and value.
-  SmallVector<Value> operands = llvm::to_vector(desc.getFields());
-  operands.append(lcvs);
-  operands.push_back(value);
-  Location loc = insertPoint.getLoc();
-
-  if (!func) {
-    // Create the function.
-    OpBuilder::InsertionGuard insertionGuard(builder);
-    builder.setInsertionPoint(insertPoint);
-
-    func = builder.create<func::FuncOp>(
-        loc, nameOstream.str(),
-        FunctionType::get(context, ValueRange(operands).getTypes(),
-                          ValueRange(desc.getFields()).getTypes()));
-    func.setPrivate();
-    createFunc(builder, module, func, stt);
+  std::string getMangledFuncName() {
+    // The mangled name of the function has this format:
+    //   <namePrefix>_<DLT>_<shape>_<ordering>_<eltType>_<crdWidth>_<posWidth>
+    constexpr const char kInsertFuncNamePrefix[] = "_insert_";
+    const SparseTensorType stt(rtp.cast<RankedTensorType>());
+
+    SmallString<32> nameBuffer;
+    llvm::raw_svector_ostream nameOstream(nameBuffer);
+    nameOstream << kInsertFuncNamePrefix;
+    const Level lvlRank = stt.getLvlRank();
+    for (Level l = 0; l < lvlRank; l++)
+      nameOstream << toMLIRString(stt.getLvlType(l)) << "_";
+    // Static dim sizes are used in the generated code while dynamic sizes are
+    // loaded from the dimSizes buffer. This is the reason for adding the shape
+    // to the function name.
+    for (const auto sh : stt.getDimShape())
+      nameOstream << sh << "_";
+    // Permutation information is also used in generating insertion.
+    if (!stt.isIdentity())
+      nameOstream << stt.getDimToLvlMap() << "_";
+    nameOstream << stt.getElementType() << "_";
+    nameOstream << stt.getCrdWidth() << "_" << stt.getPosWidth();
+    return nameOstream.str().str();
   }
 
-  // Generate a call to perform the insertion and update `fields` with values
-  // returned from the call.
-  func::CallOp call = builder.create<func::CallOp>(loc, func, operands);
-  for (size_t i = 0, e = desc.getNumFields(); i < e; i++) {
-    desc.getFields()[i] = call.getResult(i);
-  }
-}
+private:
+  TensorType rtp;
+};
 
 /// Generations insertion finalization code.
 static void genEndInsert(OpBuilder &builder, Location loc,
@@ -936,8 +902,7 @@ public:
     Value count = adaptor.getCount();
     const SparseTensorType dstType(desc.getRankedTensorType());
     Type eltType = dstType.getElementType();
-    // Prepare level-coords.
-    SmallVector<Value> lcvs(adaptor.getLvlCoords());
+
     // If the innermost level is ordered, we need to sort the coordinates
     // in the "added" array prior to applying the compression.
     if (dstType.isOrderedLvl(dstType.getLvlRank() - 1))
@@ -960,16 +925,22 @@ public:
     //    }
     scf::ForOp loop = createFor(rewriter, loc, count, desc.getFields());
     Value i = loop.getInductionVar();
+
     Value crd = genLoad(rewriter, loc, added, i);
     Value value = genLoad(rewriter, loc, values, crd);
-    lcvs.push_back(crd);
-    // TODO: faster for subsequent insertions?
-    auto insertPoint = op->template getParentOfType<func::FuncOp>();
-    genInsertionCallHelper(rewriter, desc, lcvs, value, insertPoint,
-                           kInsertFuncNamePrefix, genInsertBody);
+    SmallVector<Value> params(desc.getFields().begin(), desc.getFields().end());
+    SmallVector<Type> flatSpTensorTps = llvm::to_vector(
+        llvm::map_range(desc.getFields(), [](Value v) { return v.getType(); }));
+    params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
+    params.push_back(crd);
+    params.push_back(value);
+    SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
+                                    params, /*genCall=*/true);
+    SmallVector<Value> insertRet = insertGen.genCallOrInline(rewriter, loc);
     genStore(rewriter, loc, constantZero(rewriter, loc, eltType), values, crd);
     genStore(rewriter, loc, constantI1(rewriter, loc, false), filled, crd);
-    rewriter.create<scf::YieldOp>(loc, desc.getFields());
+    rewriter.create<scf::YieldOp>(loc, insertRet);
+
     rewriter.setInsertionPointAfter(loop);
     Value result = genTuple(rewriter, loc, dstType, loop->getResults());
     // Deallocate the buffers on exit of the full loop nest.
@@ -991,17 +962,18 @@ public:
   LogicalResult
   matchAndRewrite(InsertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> fields;
-    auto desc = getMutDescriptorFromTensorTuple(adaptor.getTensor(), fields);
-    SmallVector<Value> lcvs(adaptor.getLvlCoords());
-    // Generate insertion.
-    Value value = adaptor.getValue();
-    auto insertPoint = op->template getParentOfType<func::FuncOp>();
-    genInsertionCallHelper(rewriter, desc, lcvs, value, insertPoint,
-                           kInsertFuncNamePrefix, genInsertBody);
-
+    Location loc = op.getLoc();
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    TypeRange flatSpTensorTps = desc.getFields().getTypes();
+    SmallVector<Value> params = llvm::to_vector(desc.getFields());
+    params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
+    params.push_back(adaptor.getValue());
+    SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
+                                    params, /*genCall=*/true);
+    SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+    rewriter.replaceOp(op,
+                       genTuple(rewriter, loc, op.getTensor().getType(), ret));
     return success();
   }
 };