[mlir][sparse] Introduce sparse_tensor.storage operator to create a sparse tensor...
authorPeiming Liu <peiming@google.com>
Fri, 2 Sep 2022 20:31:47 +0000 (20:31 +0000)
committerPeiming Liu <peiming@google.com>
Sat, 3 Sep 2022 00:08:29 +0000 (00:08 +0000)
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133231

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/test/Dialect/SparseTensor/invalid.mlir
mlir/test/Dialect/SparseTensor/roundtrip.mlir

index 3e1564f..9272e2e 100644 (file)
@@ -629,6 +629,29 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
 // sparse tensor codegen to progressively lower sparse tensors.
 //===----------------------------------------------------------------------===//
 
+def SparseTensor_StorageNewOp : SparseTensor_Op<"storage", []>,
+    Arguments<(ins Variadic<AnyType>:$inputs)>,
+    Results<(outs AnyTuple:$result)> {
+  let summary = "Pack a list of value into one sparse tensor storage value";
+  let description = [{
+     Pack a list of value into one sparse tensor storage value (represented as
+     a tuple) at the given index.
+
+     The result tuple elements' type should match the corresponding type in the
+     input array.
+
+     Example:
+
+     ```mlir
+     %0 = sparse_tensor.storage(%1, %2): memref<?xf64>, memref<?xf64>
+                                to tuple<memref<?xf64>, memref<?xf64>>
+     ```
+   }];
+
+  let assemblyFormat = " attr-dict `(` $inputs `)``:` type($inputs) `to` type($result)";
+  let hasVerifier = 1;
+}
+
 def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
     Arguments<(ins AnyTuple:$storage,
                    IndexAttr:$idx)>,
index 1c76f7e..22cf768 100644 (file)
@@ -486,6 +486,23 @@ LogicalResult YieldOp::verify() {
 // Sparse Tensor Storage Operation.
 //===----------------------------------------------------------------------===//
 
+LogicalResult StorageNewOp::verify() {
+  auto retTypes = getResult().getType().getTypes();
+  if (retTypes.size() != getInputs().size())
+    return emitError("The number of inputs is inconsistent with output tuple");
+
+  for (auto pair : llvm::zip(getInputs(), retTypes)) {
+    auto input = std::get<0>(pair);
+    auto retTy = std::get<1>(pair);
+
+    if (input.getType() != retTy)
+      return emitError(llvm::formatv("Type mismatch between input (type={0}) "
+                                     "and output tuple element (type={1})",
+                                     input.getType(), retTy));
+  }
+  return success();
+}
+
 LogicalResult StorageGetOp::verify() {
   uint64_t extractIdx = getIdx().getZExtValue();
   auto innerTypeArray = getStorage().getType().getTypes();
index 805f959..b9555e8 100644 (file)
@@ -445,6 +445,26 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>,
 
 // -----
 
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+                               tuple<memref<?xf64>, memref<?xf64>> {
+  // expected-error@+1{{The number of inputs is inconsistent with output}}
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>>
+}
+
+// -----
+
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+                               tuple<memref<?xi64>, memref<?xf64>, f64> {
+  // expected-error@+1{{Type mismatch between}}
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xi64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xi64>, memref<?xf64>, f64>
+}
+
+// -----
+
 func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
   // expected-error@+1{{Out-of-bound access}}
   %0 = sparse_tensor.storage_get %arg0[3]
index 4b97277..c37b4e7 100644 (file)
@@ -317,6 +317,22 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>,
 
 // -----
 
+
+// CHECK: func @sparse_storage_new(
+//  CHECK-SAME: %[[A0:.*0]]: memref<?xf64>,
+//  CHECK-SAME: %[[A1:.*1]]: memref<?xf64>,
+//  CHECK-SAME: %[[A2:.*]]: f64
+//       CHECK: %[[TMP_0:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]], %[[A2]])
+//       CHECK: return %[[TMP_0]] : tuple<memref<?xf64>, memref<?xf64>, f64>
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+                               tuple<memref<?xf64>, memref<?xf64>, f64> {
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+       : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
+
+// -----
+
 // CHECK-LABEL: func @sparse_storage_get(
 //  CHECK-SAME:   %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
 //       CHECK:   %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] :
@@ -329,7 +345,7 @@ func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -
   return %0 : memref<?xf64>
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @sparse_storage_set(
 //  CHECK-SAME:   %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,