[mlir][sparse] proper in-place SDDMM with spy function
authorAart Bik <ajcbik@google.com>
Wed, 14 Jun 2023 22:18:00 +0000 (15:18 -0700)
committerAart Bik <ajcbik@google.com>
Thu, 15 Jun 2023 20:59:38 +0000 (13:59 -0700)
This specific operation matches the cuSPARSE SDDMM semantics exactly.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D152969

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/spy_sddmm.mlir [new file with mode: 0755]

index 428bc49..881e02e 100644 (file)
@@ -1136,10 +1136,26 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
 /// inlined cloned code.
 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
                           Value e, LoopId ldx) {
-  if (Operation *def = e.getDefiningOp()) {
+  if (auto arg = dyn_cast<BlockArgument>(e)) {
+    // Direct arguments of the original linalg op must be converted
+    // into dense tensor loads. Note that we should not encounter
+    // anything else. This needs to be verified by semi-ring ops.
+    linalg::GenericOp op = env.op();
+    if (arg.getOwner()->getParentOp() == op) {
+      const TensorId tid = env.makeTensorId(arg.getArgNumber());
+      OpOperand *t = &op->getOpOperand(tid);
+      assert(!getSparseTensorType(t->get()).hasEncoding()); // dense!
+      SmallVector<Value> args;
+      Value ptr = genSubscript(env, rewriter, t, args);
+      return rewriter.create<memref::LoadOp>(op.getLoc(), ptr, args);
+    }
+  } else if (Operation *def = e.getDefiningOp()) {
+    // Handle index computation.
     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
       return env.getLoopVar(env.makeLoopId(indexOp.getDim()));
+    // When still defined in new body, recurse into operands.
     if (def->getBlock() == block) {
+      rewriter.setInsertionPoint(def);
       for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
         rewriter.updateRootInPlace(def, [&]() {
           def->setOperand(
@@ -1195,8 +1211,10 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
     if (ee &&
         (kind == TensorExp::Kind::kUnary || kind == TensorExp::Kind::kBinary ||
          kind == TensorExp::Kind::kBinaryBranch ||
-         kind == TensorExp::Kind::kReduce || kind == TensorExp::Kind::kSelect))
+         kind == TensorExp::Kind::kReduce || kind == TensorExp::Kind::kSelect)) {
+      OpBuilder::InsertionGuard guard(rewriter);
       ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+    }
   }
 
   if (kind == TensorExp::Kind::kReduce)
diff --git a/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir b/mlir/test/Dialect/SparseTensor/spy_sddmm.mlir
new file mode 100755 (executable)
index 0000000..8bc405a
--- /dev/null
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+//
+// A SDDMM implementation with "spy" function and
+// in-place update of the sampling sparse matrix.
+//
+
+#SM = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
+
+#trait_sampled_dense_dense = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,k)>,  // A
+    affine_map<(i,j,k) -> (k,j)>,  // B
+    affine_map<(i,j,k) -> (i,j)>   // S
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "S(i,j) += spy[S(i,j)] x SUM_k A(i,k) B(k,j)"
+}
+
+// CHECK-LABEL: func.func @sparse_sampled_dd(
+// CHECK-SAME:    %[[VAL_0:.*0]]: tensor<8x8xf64>,
+// CHECK-SAME:    %[[VAL_1:.*1]]: tensor<8x8xf64>,
+// CHECK-SAME:    %[[VAL_2:.*2]]: tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> {
+// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[VAL_6:.*]] = bufferization.to_memref %[[VAL_0]] : memref<8x8xf64>
+// CHECK-DAG:     %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK-DAG:     %[[VAL_8:.*]] = sparse_tensor.positions %[[VAL_2]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_9:.*]] = sparse_tensor.coordinates %[[VAL_2]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:     %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_2]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:         scf.for %[[VAL_11:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK:           scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] {
+// CHECK:             %[[VAL_13:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_11]]] : memref<?xindex>
+// CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_11]], %[[VAL_5]] : index
+// CHECK:             %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_14]]] : memref<?xindex>
+// CHECK:             scf.for %[[VAL_16:.*]] = %[[VAL_13]] to %[[VAL_15]] step %[[VAL_5]] {
+// CHECK:               %[[VAL_17:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK:               %[[VAL_18:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<?xf64>
+// CHECK:               %[[VAL_19:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]], %[[VAL_12]]] : memref<8x8xf64>
+// CHECK:               %[[VAL_20:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]], %[[VAL_17]]] : memref<8x8xf64>
+// CHECK:               %[[VAL_21:.*]] = arith.mulf %[[VAL_19]], %[[VAL_20]] : f64
+// CHECK:               %[[VAL_22:.*]] = arith.addf %[[VAL_18]], %[[VAL_21]] : f64
+// CHECK:               memref.store %[[VAL_22]], %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<?xf64>
+// CHECK:             } {"Emitted from" = "linalg.generic"}
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:         } {"Emitted from" = "linalg.generic"}
+// CHECK:         %[[VAL_23:.*]] = sparse_tensor.load %[[VAL_2]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:           return %[[VAL_23]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK:       }
+func.func @sparse_sampled_dd(%argA: tensor<8x8xf64>,
+                             %argB: tensor<8x8xf64>,
+                             %argS: tensor<8x8xf64, #SM>) -> tensor<8x8xf64, #SM> {
+  %f0 = arith.constant 0.0 : f64
+  %result = linalg.generic #trait_sampled_dense_dense
+    ins(%argA, %argB: tensor<8x8xf64>, tensor<8x8xf64>) outs(%argS: tensor<8x8xf64, #SM>) {
+      ^bb(%a: f64, %b: f64, %s: f64):
+         %u = sparse_tensor.unary %s : f64 to f64
+             present={
+                ^bb0(%p: f64):
+                  %mul = arith.mulf %a, %b : f64
+                  sparse_tensor.yield %mul : f64
+             }
+             absent={}
+         %r = sparse_tensor.reduce %s, %u, %f0 : f64 {
+              ^bb0(%p: f64, %q: f64):
+                %add = arith.addf %p, %q : f64
+                sparse_tensor.yield %add : f64
+            }
+         linalg.yield %r : f64
+  } -> tensor<8x8xf64, #SM>
+  return %result : tensor<8x8xf64, #SM>
+}