[mlir][sparse] move loop boundary method to codegenenv
authorAart Bik <ajcbik@google.com>
Thu, 22 Dec 2022 20:10:03 +0000 (12:10 -0800)
committerAart Bik <ajcbik@google.com>
Thu, 22 Dec 2022 20:40:45 +0000 (12:40 -0800)
Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D140578

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

index 0be15d6..1310913 100644 (file)
@@ -37,6 +37,27 @@ void CodegenEnv::startEmit(OpOperand *so, unsigned lv,
   }
 }
 
+Optional<Operation *> CodegenEnv::genLoopBoundary(
+    function_ref<Optional<Operation *>(MutableArrayRef<Value> parameters)>
+        callback) {
+  SmallVector<Value> params;
+  if (isReduc())
+    params.push_back(redVal);
+  if (isExpand())
+    params.push_back(expCount);
+  if (insChain != nullptr)
+    params.push_back(insChain);
+  auto r = callback(params); // may update parameters
+  unsigned i = 0;
+  if (isReduc())
+    updateReduc(params[i++]);
+  if (isExpand())
+    updateExpandCount(params[i++]);
+  if (insChain != nullptr)
+    updateInsertionChain(params[i]);
+  return r;
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation environment topological sort methods
 //===----------------------------------------------------------------------===//
index cb5ba99..47ce70c 100644 (file)
@@ -49,6 +49,12 @@ public:
 
   void startEmit(OpOperand *so, unsigned lv, SparseTensorLoopEmitter *le);
 
+  /// Generates loop boundary statements (entering/exiting loops). The function
+  /// passes and updates the passed-in parameters.
+  Optional<Operation *> genLoopBoundary(
+      function_ref<Optional<Operation *>(MutableArrayRef<Value> parameters)>
+          callback);
+
   //
   // Merger delegates.
   //
index eb71c4c..462dd7d 100644 (file)
@@ -604,34 +604,6 @@ static bool isAdmissibleTensorExp(CodegenEnv &env, unsigned exp,
 // Sparse compiler synthesis methods (statements and expressions).
 //===----------------------------------------------------------------------===//
 
-/// Generates loop boundary statements (entering/exiting loops). The function
-/// passes and updates the reduction value.
-static Optional<Operation *> genLoopBoundary(
-    CodegenEnv &env,
-    function_ref<Optional<Operation *>(MutableArrayRef<Value> reduc)>
-        callback) {
-  SmallVector<Value> reduc;
-  if (env.isReduc())
-    reduc.push_back(env.getReduc());
-  if (env.isExpand())
-    reduc.push_back(env.getExpandCount());
-  if (env.getInsertionChain())
-    reduc.push_back(env.getInsertionChain());
-
-  auto r = callback(reduc);
-
-  // Callback should do in-place update on reduction value vector.
-  unsigned i = 0;
-  if (env.isReduc())
-    env.updateReduc(reduc[i++]);
-  if (env.isExpand())
-    env.updateExpandCount(reduc[i++]);
-  if (env.getInsertionChain())
-    env.updateInsertionChain(reduc[i]);
-
-  return r;
-}
-
 /// Local bufferization of all dense and sparse data structures.
 static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
   linalg::GenericOp op = env.op();
@@ -1066,7 +1038,7 @@ static Operation *genFor(CodegenEnv &env, OpBuilder &builder, bool isOuter,
       isCompressedDLT(env.dlt(tid, idx)) || isSingletonDLT(env.dlt(tid, idx));
   bool isParallel = isParallelFor(env, isOuter, isSparse);
 
-  Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
+  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     if (env.merger().isFilterLoop(idx)) {
       // extraTids/extraDims must be empty because filter loops only
       // corresponding to the one and only sparse tensor level.
@@ -1092,7 +1064,7 @@ static Operation *genWhile(CodegenEnv &env, OpBuilder &builder, unsigned idx,
                            ArrayRef<size_t> condDims,
                            ArrayRef<size_t> extraTids,
                            ArrayRef<size_t> extraDims) {
-  Operation *loop = *genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
+  Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     // Construct the while-loop with a parameter for each
     // index.
     return env.emitter()->enterCoIterationOverTensorsAtDims(
@@ -1425,7 +1397,7 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
     needsUniv = false;
   }
 
-  genLoopBoundary(env, [&](MutableArrayRef<Value> reduc) {
+  env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     env.emitter()->exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
     return std::nullopt;
   });