[mlir][sparse] add support for direct prod/and/min/max reductions
authorAart Bik <ajcbik@google.com>
Fri, 9 Jun 2023 19:48:45 +0000 (12:48 -0700)
committerAart Bik <ajcbik@google.com>
Mon, 12 Jun 2023 16:27:47 +0000 (09:27 -0700)
We recently fixed a bug in "sparsifying" such reductions, since
it incorrectly changed this into reductions over stored elements
only , which only works for add/sub/or/xor. However, we still want
to be able to "sparsify" the reductions even in the general case,
and this is a first step by rewriting them into a custom reduction
that feeds in the implicit zeros. NOTE HOWEVER, that in the long run
we want to do this better and feed in any implicit zero only ONCE
for efficiency.

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D152580

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_min.mlir [new file with mode: 0644]

index de0f2f7..2c5289a 100644 (file)
@@ -387,6 +387,90 @@ public:
   }
 };
 
+/// Rewrites a sparse reduction that would not sparsify directly since
+/// doing so would only iterate over the stored elements, ignoring the
+/// implicit zeros, into a semi-ring. Applies to all prod/and/min/max
+/// (note that reductions like add/sub/or/xor can directly be sparsified
+/// since the implicit zeros do not contribute to the final result).
+/// Note that prod/and are still included since, even though they often
+/// are nullified in sparse data, they may still occur for special
+/// situations in which e.g. some rows in a sparse matrix are fully
+/// dense. For min/max, including the implicit zeros is a much more
+/// common situation.
+///
+/// TODO: this essentially "densifies" the operation; we want to implement
+///       this much more efficiently by performing the reduction over the
+///       stored values, and feed in the zero once if there were *any*
+///       implicit zeros as well; but for now, at least we provide
+///       the functionality
+///
+struct GenSemiRingReduction : public OpRewritePattern<GenericOp> {
+public:
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    // Reject non-reductions.
+    if (!op.hasTensorSemantics() || op.getNumDpsInputs() != 1 ||
+        op.getNumReductionLoops() == 0 || op.getNumResults() != 1)
+      return failure();
+    auto inp = op.getDpsInputOperand(0);
+    auto init = op.getDpsInitOperand(0);
+    if (!isSparseTensor(inp))
+      return failure();
+    // Look for direct x = x OP y for semi-ring ready reductions.
+    auto red = cast<linalg::YieldOp>(op.getRegion().front().getTerminator())
+                   .getOperand(0)
+                   .getDefiningOp();
+    if (!isa<arith::AndIOp, arith::MulIOp, arith::MulFOp, arith::MinFOp,
+             arith::MinSIOp, arith::MinUIOp, arith::MaxFOp, arith::MaxSIOp,
+             arith::MaxUIOp>(red))
+      return failure();
+    Value s0 = op.getBlock()->getArgument(0);
+    Value s1 = op.getBlock()->getArgument(1);
+    if ((red->getOperand(0) != s0 || red->getOperand(1) != s1) &&
+        (red->getOperand(0) != s1 || red->getOperand(1) != s0))
+      return failure();
+    // Identity.
+    Location loc = op.getLoc();
+    Value identity =
+        rewriter.create<tensor::ExtractOp>(loc, init->get(), ValueRange());
+    // Unary {
+    //    present -> value
+    //    absent  -> zero.
+    // }
+    Type rtp = s0.getType();
+    rewriter.setInsertionPointToStart(&op.getRegion().front());
+    auto semiring = rewriter.create<sparse_tensor::UnaryOp>(loc, rtp, s0);
+    Block *present =
+        rewriter.createBlock(&semiring.getPresentRegion(), {}, rtp, loc);
+    rewriter.setInsertionPointToStart(&semiring.getPresentRegion().front());
+    rewriter.create<sparse_tensor::YieldOp>(loc, present->getArgument(0));
+    rewriter.createBlock(&semiring.getAbsentRegion(), {}, {}, {});
+    rewriter.setInsertionPointToStart(&semiring.getAbsentRegion().front());
+    auto zero =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(rtp));
+    rewriter.create<sparse_tensor::YieldOp>(loc, zero);
+    rewriter.setInsertionPointAfter(semiring);
+    // CustomReduce {
+    //    x = x REDUC y, identity
+    // }
+    auto custom = rewriter.create<sparse_tensor::ReduceOp>(
+        loc, rtp, semiring.getResult(), s1, identity);
+    Block *region =
+        rewriter.createBlock(&custom.getRegion(), {}, {rtp, rtp}, {loc, loc});
+    rewriter.setInsertionPointToStart(&custom.getRegion().front());
+    IRMapping irMap;
+    irMap.map(red->getOperand(0), region->getArgument(0));
+    irMap.map(red->getOperand(1), region->getArgument(1));
+    auto cloned = rewriter.clone(*red, irMap);
+    rewriter.create<sparse_tensor::YieldOp>(loc, cloned->getResult(0));
+    rewriter.setInsertionPointAfter(custom);
+    rewriter.replaceOp(red, custom.getResult());
+    return success();
+  }
+};
+
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
 struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
 public:
@@ -1262,8 +1346,8 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
-  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast>(
-      patterns.getContext());
+  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd, FuseTensorCast,
+               GenSemiRingReduction>(patterns.getContext());
 }
 
 void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_min.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_reductions_min.mlir
new file mode 100644 (file)
index 0000000..d624f15
--- /dev/null
@@ -0,0 +1,128 @@
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void  \
+// DEFINE:  -shared-libs=%mlir_c_runner_utils | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{option} = "enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
+// RUN: %{command}
+
+#SV = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ] }>
+
+#trait_reduction = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> ()>    // x (scalar out)
+  ],
+  iterator_types = ["reduction"],
+  doc = "x += MIN_i a(i)"
+}
+
+// Examples of sparse vector MIN reductions.
+module {
+
+  // Custom MIN reduction: stored i32 elements only.
+  func.func @min1(%arga: tensor<32xi32, #SV>, %argx: tensor<i32>) -> tensor<i32> {
+    %c = tensor.extract %argx[] : tensor<i32>
+    %0 = linalg.generic #trait_reduction
+      ins(%arga: tensor<32xi32, #SV>)
+      outs(%argx: tensor<i32>) {
+        ^bb(%a: i32, %b: i32):
+          %1 = sparse_tensor.reduce %a, %b, %c : i32 {
+            ^bb0(%x: i32, %y: i32):
+             %m = arith.minsi %x, %y : i32
+              sparse_tensor.yield %m : i32
+          }
+          linalg.yield %1 : i32
+    } -> tensor<i32>
+    return %0 : tensor<i32>
+  }
+
+  // Regular MIN reduction: stored i32 elements AND implicit zeros.
+  // Note that dealing with the implicit zeros is taken care of
+  // by the sparse compiler to preserve semantics of the "original".
+  func.func @min2(%arga: tensor<32xi32, #SV>, %argx: tensor<i32>) -> tensor<i32> {
+    %c = tensor.extract %argx[] : tensor<i32>
+    %0 = linalg.generic #trait_reduction
+      ins(%arga: tensor<32xi32, #SV>)
+      outs(%argx: tensor<i32>) {
+        ^bb(%a: i32, %b: i32):
+         %m = arith.minsi %a, %b : i32
+          linalg.yield %m : i32
+    } -> tensor<i32>
+    return %0 : tensor<i32>
+  }
+
+  func.func @dump_i32(%arg0 : tensor<i32>) {
+    %v = tensor.extract %arg0[] : tensor<i32>
+    vector.print %v : i32
+    return
+  }
+
+  func.func @entry() {
+    %ri = arith.constant dense<999> : tensor<i32>
+
+    // Vectors with a few zeros.
+    %c_0_i32 = arith.constant dense<[
+      2, 2, 7, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2,
+      2, 2, 2, 2, 3, 0, 9, 2, 2, 2, 2, 0, 5, 1, 7, 3
+    ]> : tensor<32xi32>
+
+    // Vectors with no zeros.
+    %c_1_i32 = arith.constant dense<[
+      2, 2, 7, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2,
+      2, 2, 2, 2, 3, 2, 7, 2, 2, 2, 2, 2, 2, 1, 7, 3
+    ]> : tensor<32xi32>
+
+    // Convert constants to annotated tensors. Note that this
+    // particular conversion only stores nonzero elements,
+    // so we will have no explicit zeros, only implicit zeros.
+    %sv0 = sparse_tensor.convert %c_0_i32
+      : tensor<32xi32> to tensor<32xi32, #SV>
+    %sv1 = sparse_tensor.convert %c_1_i32
+      : tensor<32xi32> to tensor<32xi32, #SV>
+
+    // Special case, construct a sparse vector with an explicit zero.
+    %v = arith.constant sparse< [ [1], [7] ], [ 0, 22 ] > : tensor<32xi32>
+    %sv2 = sparse_tensor.convert %v: tensor<32xi32> to tensor<32xi32, #SV>
+
+    // Call the kernels.
+    %0 = call @min1(%sv0, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
+    %1 = call @min1(%sv1, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
+    %2 = call @min1(%sv2, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
+    %3 = call @min2(%sv0, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
+    %4 = call @min2(%sv1, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
+    %5 = call @min2(%sv2, %ri) : (tensor<32xi32, #SV>, tensor<i32>) -> tensor<i32>
+
+    // Verify results.
+    //
+    // CHECK: 1
+    // CHECK: 1
+    // CHECK: 0
+    // CHECK: 0
+    // CHECK: 1
+    // CHECK: 0
+    //
+    call @dump_i32(%0) : (tensor<i32>) -> ()
+    call @dump_i32(%1) : (tensor<i32>) -> ()
+    call @dump_i32(%2) : (tensor<i32>) -> ()
+    call @dump_i32(%3) : (tensor<i32>) -> ()
+    call @dump_i32(%4) : (tensor<i32>) -> ()
+    call @dump_i32(%5) : (tensor<i32>) -> ()
+
+    // Release the resources.
+    bufferization.dealloc_tensor %sv0 : tensor<32xi32, #SV>
+    bufferization.dealloc_tensor %sv1 : tensor<32xi32, #SV>
+    bufferization.dealloc_tensor %sv2 : tensor<32xi32, #SV>
+
+    return
+  }
+}