[mlir][sparse] Add sparse_tensor.select operation
authorJim Kitchen <jim22k@gmail.com>
Tue, 13 Sep 2022 20:22:53 +0000 (15:22 -0500)
committerJim Kitchen <jim22k@gmail.com>
Tue, 13 Sep 2022 20:22:53 +0000 (15:22 -0500)
The new select operation allows filtering of sparse tensors
by conditionally keeping or removing each element. This
can be used to remove negative values or select the upper
triangle of a matrix.

The select op has a single region which operates on a single
value and must return a boolean True to keep or False to drop.

Reviewed by: aartbik

Differential Revision: https://reviews.llvm.org/D133569

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir

index 28401da..ed1943f 100644 (file)
@@ -604,11 +604,72 @@ def SparseTensor_ReduceOp : SparseTensor_Op<"reduce", [NoSideEffect, SameOperand
   let hasVerifier = 1;
 }
 
+def SparseTensor_SelectOp : SparseTensor_Op<"select", [NoSideEffect, SameOperandsAndResultType]>,
+    Arguments<(ins AnyType:$x)>,
+    Results<(outs AnyType:$output)> {
+  let summary = "Select operation utilized within linalg.generic";
+  let description = [{
+      Defines an evaluation within a `linalg.generic` operation that takes a single
+      operand and decides whether or not to keep that operand in the output.
+
+      A single region must contain exactly one block taking one argument. The block
+      must end with a sparse_tensor.yield and the output type must be boolean.
+
+      Value threshold is an obvious usage of the select operation. However, by using
+      `linalg.index`, other useful selection can be achieved, such as selecting the
+      upper triangle of a matrix.
+
+      Example of selecting A >= 4.0:
+
+      ```mlir
+      %C = bufferization.alloc_tensor...
+      %0 = linalg.generic #trait
+         ins(%A: tensor<?xf64, #SparseVector>)
+        outs(%C: tensor<?xf64, #SparseVector>) {
+        ^bb0(%a: f64, %c: f64) :
+          %result = sparse_tensor.select %a : f64 {
+              ^bb0(%arg0: f64):
+                %cf4 = arith.constant 4.0 : f64
+                %keep = arith.cmpf "uge", %arg0, %cf4 : f64
+                sparse_tensor.yield %keep : i1
+            }
+          linalg.yield %result : f64
+      } -> tensor<?xf64, #SparseVector>
+      ```
+
+      Example of selecting lower triangle of a matrix:
+
+      ```mlir
+      %C = bufferization.alloc_tensor...
+      %0 = linalg.generic #trait
+         ins(%A: tensor<?x?xf64, #CSR>)
+        outs(%C: tensor<?x?xf64, #CSR>) {
+        ^bb0(%a: f64, %c: f64) :
+          %row = linalg.index 0 : index
+          %col = linalg.index 1 : index
+          %result = sparse_tensor.select %a : f64 {
+              ^bb0(%arg0: f64):
+                %keep = arith.cmpf "olt", %col, %row : f64
+                sparse_tensor.yield %keep : i1
+            }
+          linalg.yield %result : f64
+      } -> tensor<?x?xf64, #CSR>
+      ```
+  }];
+
+  let regions = (region SizedRegion<1>:$region);
+  let assemblyFormat = [{
+         $x attr-dict `:` type($x) $region
+  }];
+  let hasVerifier = 1;
+}
+
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
     Arguments<(ins AnyType:$result)> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
-      Yields a value from within a `binary` or `unary` block.
+      Yields a value from within a `binary`, `unary`, `reduce`,
+      or `select` block.
 
       Example:
 
index 3c55364..c647b0b 100644 (file)
@@ -458,12 +458,27 @@ LogicalResult ReduceOp::verify() {
 
   // Check correct number of block arguments and return type.
   Region &formula = getRegion();
-  if (!formula.empty()) {
-    regionResult = verifyNumBlockArgs(
-        this, formula, "reduce", TypeRange{inputType, inputType}, inputType);
-    if (failed(regionResult))
-      return regionResult;
-  }
+  regionResult = verifyNumBlockArgs(this, formula, "reduce",
+                                    TypeRange{inputType, inputType}, inputType);
+  if (failed(regionResult))
+    return regionResult;
+
+  return success();
+}
+
+LogicalResult SelectOp::verify() {
+  Builder b(getContext());
+
+  Type inputType = getX().getType();
+  Type boolType = b.getI1Type();
+  LogicalResult regionResult = success();
+
+  // Check correct number of block arguments and return type.
+  Region &formula = getRegion();
+  regionResult = verifyNumBlockArgs(this, formula, "select",
+                                    TypeRange{inputType}, boolType);
+  if (failed(regionResult))
+    return regionResult;
 
   return success();
 }
@@ -472,11 +487,11 @@ LogicalResult YieldOp::verify() {
   // Check for compatible parent.
   auto *parentOp = (*this)->getParentOp();
   if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
-      isa<ReduceOp>(parentOp))
+      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp))
     return success();
 
-  return emitOpError(
-      "expected parent op to be sparse_tensor unary, binary, or reduce");
+  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
+                     "reduce, or select");
 }
 
 //===----------------------------------------------------------------------===//
index c8fd0dd..c607dd2 100644 (file)
@@ -355,6 +355,40 @@ func.func @invalid_reduce_wrong_yield(%arg0: f64, %arg1: f64) -> f64 {
 
 // -----
 
+func.func @invalid_select_num_args_mismatch(%arg0: f64) -> f64 {
+  // expected-error@+1 {{select region must have exactly 1 arguments}}
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64, %y: f64):
+        %ret = arith.constant 1 : i1
+        sparse_tensor.yield %ret : i1
+    }
+  return %r : f64
+}
+
+// -----
+
+func.func @invalid_select_return_type_mismatch(%arg0: f64) -> f64 {
+  // expected-error@+1 {{select region yield type mismatch}}
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64):
+        sparse_tensor.yield %x : f64
+    }
+  return %r : f64
+}
+
+// -----
+
+func.func @invalid_select_wrong_yield(%arg0: f64) -> f64 {
+  // expected-error@+1 {{select region must end with sparse_tensor.yield}}
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64):
+        tensor.yield %x : f64
+    }
+  return %r : f64
+}
+
+// -----
+
 #DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
 func.func @invalid_concat_less_inputs(%arg: tensor<9x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
   // expected-error@+1 {{Need at least two tensors to concatenate.}}
index b795f54..5c22ffb 100644 (file)
@@ -291,6 +291,30 @@ func.func @sparse_reduce_2d_to_1d(%arg0: f64, %arg1: f64) -> f64 {
 
 #SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
 
+// CHECK-LABEL: func @sparse_select(
+//  CHECK-SAME:   %[[A:.*]]: f64) -> f64 {
+//       CHECK:   %[[Z:.*]] = arith.constant 0.000000e+00 : f64
+//       CHECK:   %[[C1:.*]] = sparse_tensor.select %[[A]] : f64 {
+//       CHECK:       ^bb0(%[[A1:.*]]: f64):
+//       CHECK:         %[[B1:.*]] = arith.cmpf ogt, %[[A1]], %[[Z]] : f64
+//       CHECK:         sparse_tensor.yield %[[B1]] : i1
+//       CHECK:     }
+//       CHECK:   return %[[C1]] : f64
+//       CHECK: }
+func.func @sparse_select(%arg0: f64) -> f64 {
+  %cf0 = arith.constant 0.0 : f64
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64):
+        %cmp = arith.cmpf "ogt", %x, %cf0 : f64
+        sparse_tensor.yield %cmp : i1
+    }
+  return %r : f64
+}
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
 // CHECK-LABEL: func @concat_sparse_sparse(
 //  CHECK-SAME:   %[[A0:.*]]: tensor<2x4xf64
 //  CHECK-SAME:   %[[A1:.*]]: tensor<3x4xf64