let hasVerifier = 1;
}
+def SparseTensor_SelectOp : SparseTensor_Op<"select", [NoSideEffect, SameOperandsAndResultType]>,
+ Arguments<(ins AnyType:$x)>,
+ Results<(outs AnyType:$output)> {
+ let summary = "Select operation utilized within linalg.generic";
+ let description = [{
+ Defines an evaluation within a `linalg.generic` operation that takes a single
+ operand and decides whether or not to keep that operand in the output.
+
+ A single region must contain exactly one block taking one argument. The block
+ must end with a sparse_tensor.yield and the output type must be boolean.
+
+ Value threshold is an obvious usage of the select operation. However, by using
+ `linalg.index`, other useful selection can be achieved, such as selecting the
+ upper triangle of a matrix.
+
+ Example of selecting A >= 4.0:
+
+ ```mlir
+ %C = bufferization.alloc_tensor...
+ %0 = linalg.generic #trait
+ ins(%A: tensor<?xf64, #SparseVector>)
+ outs(%C: tensor<?xf64, #SparseVector>) {
+ ^bb0(%a: f64, %c: f64) :
+ %result = sparse_tensor.select %a : f64 {
+ ^bb0(%arg0: f64):
+ %cf4 = arith.constant 4.0 : f64
+ %keep = arith.cmpf "uge", %arg0, %cf4 : f64
+ sparse_tensor.yield %keep : i1
+ }
+ linalg.yield %result : f64
+ } -> tensor<?xf64, #SparseVector>
+ ```
+
+ Example of selecting lower triangle of a matrix:
+
+ ```mlir
+ %C = bufferization.alloc_tensor...
+ %0 = linalg.generic #trait
+ ins(%A: tensor<?x?xf64, #CSR>)
+ outs(%C: tensor<?x?xf64, #CSR>) {
+ ^bb0(%a: f64, %c: f64) :
+ %row = linalg.index 0 : index
+ %col = linalg.index 1 : index
+ %result = sparse_tensor.select %a : f64 {
+ ^bb0(%arg0: f64):
+ %keep = arith.cmpf "olt", %col, %row : f64
+ sparse_tensor.yield %keep : i1
+ }
+ linalg.yield %result : f64
+ } -> tensor<?x?xf64, #CSR>
+ ```
+ }];
+
+ let regions = (region SizedRegion<1>:$region);
+ let assemblyFormat = [{
+ $x attr-dict `:` type($x) $region
+ }];
+ let hasVerifier = 1;
+}
+
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
Arguments<(ins AnyType:$result)> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
- Yields a value from within a `binary` or `unary` block.
+ Yields a value from within a `binary`, `unary`, `reduce`,
+ or `select` block.
Example:
// Check correct number of block arguments and return type.
Region &formula = getRegion();
- if (!formula.empty()) {
- regionResult = verifyNumBlockArgs(
- this, formula, "reduce", TypeRange{inputType, inputType}, inputType);
- if (failed(regionResult))
- return regionResult;
- }
+ regionResult = verifyNumBlockArgs(this, formula, "reduce",
+ TypeRange{inputType, inputType}, inputType);
+ if (failed(regionResult))
+ return regionResult;
+
+ return success();
+}
+
+LogicalResult SelectOp::verify() {
+ Builder b(getContext());
+
+ Type inputType = getX().getType();
+ Type boolType = b.getI1Type();
+ LogicalResult regionResult = success();
+
+ // Check correct number of block arguments and return type.
+ Region &formula = getRegion();
+ regionResult = verifyNumBlockArgs(this, formula, "select",
+ TypeRange{inputType}, boolType);
+ if (failed(regionResult))
+ return regionResult;
return success();
}
// Check for compatible parent.
auto *parentOp = (*this)->getParentOp();
if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
- isa<ReduceOp>(parentOp))
+ isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp))
return success();
- return emitOpError(
- "expected parent op to be sparse_tensor unary, binary, or reduce");
+ return emitOpError("expected parent op to be sparse_tensor unary, binary, "
+ "reduce, or select");
}
//===----------------------------------------------------------------------===//
// -----
+func.func @invalid_select_num_args_mismatch(%arg0: f64) -> f64 {
+ // expected-error@+1 {{select region must have exactly 1 arguments}}
+ %r = sparse_tensor.select %arg0 : f64 {
+ ^bb0(%x: f64, %y: f64):
+ %ret = arith.constant 1 : i1
+ sparse_tensor.yield %ret : i1
+ }
+ return %r : f64
+}
+
+// -----
+
+func.func @invalid_select_return_type_mismatch(%arg0: f64) -> f64 {
+ // expected-error@+1 {{select region yield type mismatch}}
+ %r = sparse_tensor.select %arg0 : f64 {
+ ^bb0(%x: f64):
+ sparse_tensor.yield %x : f64
+ }
+ return %r : f64
+}
+
+// -----
+
+func.func @invalid_select_wrong_yield(%arg0: f64) -> f64 {
+ // expected-error@+1 {{select region must end with sparse_tensor.yield}}
+ %r = sparse_tensor.select %arg0 : f64 {
+ ^bb0(%x: f64):
+ tensor.yield %x : f64
+ }
+ return %r : f64
+}
+
+// -----
+
#DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
func.func @invalid_concat_less_inputs(%arg: tensor<9x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
// expected-error@+1 {{Need at least two tensors to concatenate.}}
#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+// CHECK-LABEL: func @sparse_select(
+// CHECK-SAME: %[[A:.*]]: f64) -> f64 {
+// CHECK: %[[Z:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK: %[[C1:.*]] = sparse_tensor.select %[[A]] : f64 {
+// CHECK: ^bb0(%[[A1:.*]]: f64):
+// CHECK: %[[B1:.*]] = arith.cmpf ogt, %[[A1]], %[[Z]] : f64
+// CHECK: sparse_tensor.yield %[[B1]] : i1
+// CHECK: }
+// CHECK: return %[[C1]] : f64
+// CHECK: }
+func.func @sparse_select(%arg0: f64) -> f64 {
+ %cf0 = arith.constant 0.0 : f64
+ %r = sparse_tensor.select %arg0 : f64 {
+ ^bb0(%x: f64):
+ %cmp = arith.cmpf "ogt", %x, %cf0 : f64
+ sparse_tensor.yield %cmp : i1
+ }
+ return %r : f64
+}
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
// CHECK-LABEL: func @concat_sparse_sparse(
// CHECK-SAME: %[[A0:.*]]: tensor<2x4xf64
// CHECK-SAME: %[[A1:.*]]: tensor<3x4xf64