[mlir][sparse] Implement sparse_tensor.select
authorJim Kitchen <jim22k@gmail.com>
Mon, 3 Oct 2022 19:34:53 +0000 (14:34 -0500)
committerJim Kitchen <jim22k@gmail.com>
Mon, 3 Oct 2022 19:39:26 +0000 (14:39 -0500)
The region within sparse_tensor.select is used as the runtime criteria
for whether to keep the existing value in the sparse tensor.

While the sparse element is provided to the comparison, indices may also
be used to decide on whether to keep the original value. This allows, for
example, to only keep the upper triangle of a matrix.

Reviewed by: aartbik

Differential Revision: https://reviews.llvm.org/D134761

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir [new file with mode: 0644]
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

index f9be476..a376d9a 100644 (file)
@@ -76,6 +76,7 @@ enum Kind {
   kBitCast,
   kBinaryBranch, // semiring unary branch created from a binary op
   kUnary,        // semiring unary op
+  kSelect,       // custom selection criteria
   // Binary operations.
   kMulF,
   kMulC,
@@ -129,8 +130,8 @@ struct TensorExp {
   /// this field may be used to cache "hoisted" loop invariant tensor loads.
   Value val;
 
-  /// Code blocks used by semirings. For the case of kUnary, kBinary, and
-  /// kReduce, this holds the original operation with all regions. For
+  /// Code blocks used by semirings. For the case of kUnary, kBinary, kReduce,
+  /// and kSelect, this holds the original operation with all regions. For
   /// kBinaryBranch, this holds the YieldOp for the left or right half
   /// to be merged into a nested scf loop.
   Operation *op;
index 88bd885..7cd9f7f 100644 (file)
@@ -878,6 +878,15 @@ static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder,
       // Only unary and binary are allowed to return uninitialized rhs
       // to indicate missing output.
       assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary);
+    } else if (merger.exp(exp).kind == kSelect) {
+      scf::IfOp ifOp = builder.create<scf::IfOp>(loc, rhs);
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      // Existing value was preserved to be used here.
+      assert(merger.exp(exp).val);
+      Value v0 = merger.exp(exp).val;
+      genInsertionStore(codegen, builder, op, t, v0);
+      merger.exp(exp).val = Value();
+      builder.setInsertionPointAfter(ifOp);
     } else {
       genInsertionStore(codegen, builder, op, t, rhs);
     }
@@ -1037,9 +1046,15 @@ static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter,
   if (ee && (merger.exp(exp).kind == Kind::kUnary ||
              merger.exp(exp).kind == Kind::kBinary ||
              merger.exp(exp).kind == Kind::kBinaryBranch ||
-             merger.exp(exp).kind == Kind::kReduce))
+             merger.exp(exp).kind == Kind::kReduce ||
+             merger.exp(exp).kind == Kind::kSelect))
     ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx);
 
+  if (merger.exp(exp).kind == kSelect) {
+    assert(!merger.exp(exp).val);
+    merger.exp(exp).val = v0; // Preserve value for later use.
+  }
+
   if (merger.exp(exp).kind == Kind::kReduce) {
     assert(codegen.redCustom != -1u);
     codegen.redCustom = -1u;
index 7f132f9..3791971 100644 (file)
@@ -78,6 +78,7 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
     children.e1 = y;
     break;
   case kBinaryBranch:
+  case kSelect:
     assert(x != -1u && y == -1u && !v && o);
     children.e0 = x;
     children.e1 = y;
@@ -212,7 +213,7 @@ unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
 }
 
 unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
-  assert(kAbsF <= kind && kind <= kUnary);
+  assert(kAbsF <= kind && kind <= kSelect);
   unsigned s = addSet();
   for (unsigned p : latSets[s0]) {
     unsigned e = addExp(kind, latPoints[p].exp, v, op);
@@ -265,9 +266,8 @@ BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
   BitVector simple = latPoints[p0].bits;
   bool reset = isSingleton && hasAnySparse(simple);
   for (unsigned b = 0, be = simple.size(); b < be; b++) {
-    if (simple[b] &&
-        (!isDimLevelType(b, DimLvlType::kCompressed) &&
-         !isDimLevelType(b, DimLvlType::kSingleton))) {
+    if (simple[b] && (!isDimLevelType(b, DimLvlType::kCompressed) &&
+                      !isDimLevelType(b, DimLvlType::kSingleton))) {
       if (reset)
         simple.reset(b);
       reset = true;
@@ -338,6 +338,7 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
     return isSingleCondition(t, tensorExps[e].children.e0);
   case kBinaryBranch:
   case kUnary:
+  case kSelect:
     return false;
   // Binary operations.
   case kDivF: // note: x / c only
@@ -449,6 +450,8 @@ static const char *kindToOpSymbol(Kind kind) {
     return "binary_branch";
   case kUnary:
     return "unary";
+  case kSelect:
+    return "select";
   // Binary operations.
   case kMulF:
   case kMulC:
@@ -537,6 +540,7 @@ void Merger::dumpExp(unsigned e) const {
   case kBitCast:
   case kBinaryBranch:
   case kUnary:
+  case kSelect:
     llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
     dumpExp(tensorExps[e].children.e0);
     break;
@@ -684,6 +688,7 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                   tensorExps[e].val);
   case kBinaryBranch:
+  case kSelect:
     // The left or right half of a binary operation which has already
     // been split into separate operations for each region.
     return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
@@ -978,6 +983,10 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
             isAdmissableBranch(unop, unop.getAbsentRegion()))
           return addExp(kUnary, e, Value(), def);
       }
+      if (auto selop = dyn_cast<sparse_tensor::SelectOp>(def)) {
+        if (isAdmissableBranch(selop, selop.getRegion()))
+          return addExp(kSelect, e, Value(), def);
+      }
     }
   }
   // Construct binary operations if subexpressions can be built.
@@ -1228,6 +1237,9 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                          *tensorExps[e].op->getBlock()->getParent(), {v0});
   case kUnary:
     return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
+  case kSelect:
+    return insertYieldOp(rewriter, loc,
+                         cast<SelectOp>(tensorExps[e].op).getRegion(), {v0});
   case kBinary:
     return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
   case kReduce: {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_select.mlir
new file mode 100644 (file)
index 0000000..bbe94e9
--- /dev/null
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s --sparse-compiler | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+#CSR = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
+#CSC = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  dimOrdering = affine_map<(i,j) -> (j,i)>
+}>
+
+//
+// Traits for tensor operations.
+//
+#trait_vec_select = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // A
+    affine_map<(i) -> (i)>  // C (out)
+  ],
+  iterator_types = ["parallel"]
+}
+
+#trait_mat_select = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A (in)
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+module {
+  func.func @vecSelect(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
+    %c0 = arith.constant 0 : index
+    %cf1 = arith.constant 1.0 : f64
+    %d0 = tensor.dim %arga, %c0 : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d0): tensor<?xf64, #SparseVector>
+    %0 = linalg.generic #trait_vec_select
+      ins(%arga: tensor<?xf64, #SparseVector>)
+      outs(%xv: tensor<?xf64, #SparseVector>) {
+        ^bb(%a: f64, %b: f64):
+          %1 = sparse_tensor.select %a : f64 {
+              ^bb0(%x: f64):
+                %keep = arith.cmpf "oge", %x, %cf1 : f64
+                sparse_tensor.yield %keep : i1
+            }
+          linalg.yield %1 : f64
+    } -> tensor<?xf64, #SparseVector>
+    return %0 : tensor<?xf64, #SparseVector>
+  }
+
+  func.func @matUpperTriangle(%arga: tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = tensor.dim %arga, %c0 : tensor<?x?xf64, #CSR>
+    %d1 = tensor.dim %arga, %c1 : tensor<?x?xf64, #CSR>
+    %xv = bufferization.alloc_tensor(%d0, %d1): tensor<?x?xf64, #CSR>
+    %0 = linalg.generic #trait_mat_select
+      ins(%arga: tensor<?x?xf64, #CSR>)
+      outs(%xv: tensor<?x?xf64, #CSR>) {
+        ^bb(%a: f64, %b: f64):
+          %row = linalg.index 0 : index
+          %col = linalg.index 1 : index
+          %1 = sparse_tensor.select %a : f64 {
+              ^bb0(%x: f64):
+                %keep = arith.cmpi "ugt", %col, %row : index
+                sparse_tensor.yield %keep : i1
+            }
+          linalg.yield %1 : f64
+    } -> tensor<?x?xf64, #CSR>
+    return %0 : tensor<?x?xf64, #CSR>
+  }
+
+  // Dumps a sparse vector of type f64.
+  func.func @dump_vec(%arg0: tensor<?xf64, #SparseVector>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<8xf64>
+    vector.print %1 : vector<8xf64>
+    // Dump the dense vector to verify structure is correct.
+    %dv = sparse_tensor.convert %arg0 : tensor<?xf64, #SparseVector> to tensor<?xf64>
+    %2 = vector.transfer_read %dv[%c0], %d0: tensor<?xf64>, vector<16xf64>
+    vector.print %2 : vector<16xf64>
+    return
+  }
+
+  // Dump a sparse matrix.
+  func.func @dump_mat(%arg0: tensor<?x?xf64, #CSR>) {
+    // Dump the values array to verify only sparse contents are stored.
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = sparse_tensor.values %arg0 : tensor<?x?xf64, #CSR> to memref<?xf64>
+    %1 = vector.transfer_read %0[%c0], %d0: memref<?xf64>, vector<16xf64>
+    vector.print %1 : vector<16xf64>
+    %dm = sparse_tensor.convert %arg0 : tensor<?x?xf64, #CSR> to tensor<?x?xf64>
+    %2 = vector.transfer_read %dm[%c0, %c0], %d0: tensor<?x?xf64>, vector<5x5xf64>
+    vector.print %2 : vector<5x5xf64>
+    return
+  }
+
+  // Driver method to call and verify vector kernels.
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+
+    // Setup sparse matrices.
+    %v1 = arith.constant sparse<
+        [ [1], [3], [5], [7], [9] ],
+        [ 1.0, 2.0, -4.0, 0.0, 5.0 ]
+    > : tensor<10xf64>
+    %m1 = arith.constant sparse<
+        [ [0, 3], [1, 4], [2, 1], [2, 3], [3, 3], [3, 4], [4, 2] ],
+        [ 1., 2., 3., 4., 5., 6., 7.]
+    > : tensor<5x5xf64>
+    %sv1 = sparse_tensor.convert %v1 : tensor<10xf64> to tensor<?xf64, #SparseVector>
+    %sm1 = sparse_tensor.convert %m1 : tensor<5x5xf64> to tensor<?x?xf64, #CSR>
+
+    // Call sparse matrix kernels.
+    %1 = call @vecSelect(%sv1) : (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
+    %2 = call @matUpperTriangle(%sm1) : (tensor<?x?xf64, #CSR>) -> tensor<?x?xf64, #CSR>
+
+    //
+    // Verify the results.
+    //
+    // CHECK:      ( 1, 2, -4, 0, 5, -1, -1, -1 )
+    // CHECK-NEXT: ( 0, 1, 0, 2, 0, -4, 0, 0, 0, 5, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1, 2, 3, 4, 5, 6, 7, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 0, 0, 0, 1, 0 ), ( 0, 0, 0, 0, 2 ), ( 0, 3, 0, 4, 0 ), ( 0, 0, 0, 5, 6 ), ( 0, 0, 7, 0, 0 ) )
+    // CHECK-NEXT: ( 1, 2, 5, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 0, 1, 0, 2, 0, 0, 0, 0, 0, 5, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( 1, 2, 4, 6, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
+    // CHECK-NEXT: ( ( 0, 0, 0, 1, 0 ), ( 0, 0, 0, 0, 2 ), ( 0, 0, 0, 4, 0 ), ( 0, 0, 0, 0, 6 ), ( 0, 0, 0, 0, 0 ) )
+    //
+    call @dump_vec(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_mat(%sm1) : (tensor<?x?xf64, #CSR>) -> ()
+    call @dump_vec(%1) : (tensor<?xf64, #SparseVector>) -> ()
+    call @dump_mat(%2) : (tensor<?x?xf64, #CSR>) -> ()
+
+    // Release the resources.
+    bufferization.dealloc_tensor %sv1 : tensor<?xf64, #SparseVector>
+    bufferization.dealloc_tensor %sm1 : tensor<?x?xf64, #CSR>
+    bufferization.dealloc_tensor %1 : tensor<?xf64, #SparseVector>
+    bufferization.dealloc_tensor %2 : tensor<?x?xf64, #CSR>
+    return
+  }
+}
index 8d41558..c0e75dc 100644 (file)
@@ -259,6 +259,7 @@ protected:
     case kCIm:
     case kCRe:
     case kBitCast:
+    case kSelect:
     case kBinaryBranch:
     case kUnary:
       return compareExpression(tensorExp.children.e0, pattern->e0);