[mlir][sparse] Add operator sparse_tensor.indices_buffer.
authorbixia1 <bixia@google.com>
Thu, 5 Jan 2023 00:01:35 +0000 (16:01 -0800)
committerbixia1 <bixia@google.com>
Thu, 5 Jan 2023 17:35:55 +0000 (09:35 -0800)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140762

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir

index 771e97f..6404dc1 100644 (file)
@@ -154,6 +154,34 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [Pure]>,
   let hasVerifier = 1;
 }
 
+def SparseTensor_ToIndicesBufferOp : SparseTensor_Op<"indices_buffer", [Pure]>,
+    Arguments<(ins AnySparseTensor:$tensor)>,
+    Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
+  let summary = "Extracts the linear indices array from a tensor";
+  let description = [{
+    Returns the linear indices array for a sparse tensor with a trailing COO
+    region with at least two dimensions. It is an error if the tensor doesn't
+    contain such a COO region. This is similar to the `bufferization.to_memref`
+    operation in the sense that it provides a bridge between a tensor world view
+    and a bufferized world view. Unlike the `bufferization.to_memref` operation,
+    however, this sparse operation actually lowers into code that extracts the
+    linear indices array from the sparse storage scheme that stores the indices
+    for the COO region as an array of structures. For example, a 2D COO sparse
+    tensor with two non-zero elements at coordinates (1, 3) and (4, 6) are
+    stored in a linear buffer as (1, 4, 3, 6) instead of two buffer as (1, 4)
+    and (3, 6).
+
+    Example:
+
+    ```mlir
+    %1 = sparse_tensor.indices_buffer %0
+    : tensor<64x64xf64, #COO> to memref<?xindex>
+    ```
+  }];
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+  let hasVerifier = 1;
+}
+
 def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [Pure]>,
     Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
index 962353e..4d2c4fb 100644 (file)
@@ -496,6 +496,13 @@ LogicalResult ToIndicesOp::verify() {
   return success();
 }
 
+LogicalResult ToIndicesBufferOp::verify() {
+  auto e = getSparseTensorEncoding(getTensor().getType());
+  if (getCOOStart(e) >= e.getDimLevelType().size())
+    return emitError("expected sparse tensor with a COO region");
+  return success();
+}
+
 LogicalResult ToValuesOp::verify() {
   RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
   MemRefType mtp = getResult().getType().cast<MemRefType>();
index 4482cf2..531a987 100644 (file)
@@ -90,6 +90,24 @@ func.func @invalid_values_dense(%arg0: tensor<1024xf32>) -> memref<?xf32> {
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
+func.func @indices_buffer_noncoo(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xindex> {
+  // expected-error@+1 {{expected sparse tensor with a COO region}}
+  %0 = sparse_tensor.indices_buffer %arg0 : tensor<128xf64, #SparseVector> to memref<?xindex>
+  return %0 : memref<?xindex>
+}
+
+// -----
+
+func.func @indices_buffer_dense(%arg0: tensor<1024xf32>) -> memref<?xindex> {
+  // expected-error@+1 {{must be sparse tensor of any type values}}
+  %0 = sparse_tensor.indices_buffer %arg0 : tensor<1024xf32> to memref<?xindex>
+  return %0 : memref<?xindex>
+}
+
+// -----
+
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
 func.func @mismatch_values_types(%arg0: tensor<?xf64, #SparseVector>) -> memref<?xf32> {
   // expected-error@+1 {{unexpected mismatch in element types}}
   %0 = sparse_tensor.values %arg0 : tensor<?xf64, #SparseVector> to memref<?xf32>
index 67fefa2..58375d6 100644 (file)
@@ -78,6 +78,19 @@ func.func @sparse_pointers(%arg0: tensor<128xf64, #SparseVector>) -> memref<?xin
 
 // -----
 
+#COO = #sparse_tensor.encoding<{dimLevelType = ["compressed-nu", "singleton"]}>
+
+// CHECK-LABEL: func @sparse_indices_buffer(
+//  CHECK-SAME: %[[A:.*]]: tensor<?x?xf64, #{{.*}}>)
+//       CHECK: %[[T:.*]] = sparse_tensor.indices_buffer %[[A]] : tensor<?x?xf64, #{{.*}}> to memref<?xindex>
+//       CHECK: return %[[T]] : memref<?xindex>
+func.func @sparse_indices_buffer(%arg0: tensor<?x?xf64, #COO>) -> memref<?xindex> {
+  %0 = sparse_tensor.indices_buffer %arg0 : tensor<?x?xf64, #COO> to memref<?xindex>
+  return %0 : memref<?xindex>
+}
+
+// -----
+
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 // CHECK-LABEL: func @sparse_indices(