// sparse tensor codegen to progressively lower sparse tensors.
//===----------------------------------------------------------------------===//
+def SparseTensor_StorageNewOp : SparseTensor_Op<"storage", []>,
+ Arguments<(ins Variadic<AnyType>:$inputs)>,
+ Results<(outs AnyTuple:$result)> {
+ let summary = "Pack a list of values into one sparse tensor storage value";
+ let description = [{
+ Packs a list of values into a single sparse tensor storage value
+ (represented as a tuple).
+
+ The result tuple elements' type should match the corresponding type in the
+ input array.
+
+ Example:
+
+ ```mlir
+ %0 = sparse_tensor.storage(%1, %2): memref<?xf64>, memref<?xf64>
+ to tuple<memref<?xf64>, memref<?xf64>>
+ ```
+ }];
+
+ let assemblyFormat = " attr-dict `(` $inputs `)``:` type($inputs) `to` type($result)";
+ let hasVerifier = 1;
+}
+
def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
Arguments<(ins AnyTuple:$storage,
IndexAttr:$idx)>,
// Sparse Tensor Storage Operation.
//===----------------------------------------------------------------------===//
+// Verifies that the packed operands are consistent with the result tuple:
+// same arity, and each operand type equal to its tuple element type.
+LogicalResult StorageNewOp::verify() {
+ auto resultTypes = getResult().getType().getTypes();
+ // The tuple must carry exactly one element per packed input.
+ if (resultTypes.size() != getInputs().size())
+ return emitError("The number of inputs is inconsistent with output tuple");
+
+ // Each input must match the corresponding tuple element type exactly.
+ for (auto [input, expectedTy] : llvm::zip(getInputs(), resultTypes))
+ if (input.getType() != expectedTy)
+ return emitError(llvm::formatv("Type mismatch between input (type={0}) "
+ "and output tuple element (type={1})",
+ input.getType(), expectedTy));
+ return success();
+}
+
LogicalResult StorageGetOp::verify() {
uint64_t extractIdx = getIdx().getZExtValue();
auto innerTypeArray = getStorage().getType().getTypes();
// -----
+// Packing three inputs into a two-element tuple must be rejected: the
+// verifier requires one tuple element per input.
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+ tuple<memref<?xf64>, memref<?xf64>> {
+ // expected-error@+1{{The number of inputs is inconsistent with output}}
+ %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+ : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>>
+ return %0 : tuple<memref<?xf64>, memref<?xf64>>
+}
+
+// -----
+
+// Element-wise type check: first input is memref<?xf64> but the first tuple
+// element is memref<?xi64>, so the verifier must report a type mismatch.
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+ tuple<memref<?xi64>, memref<?xf64>, f64> {
+ // expected-error@+1{{Type mismatch between}}
+ %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+ : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xi64>, memref<?xf64>, f64>
+ return %0 : tuple<memref<?xi64>, memref<?xf64>, f64>
+}
+
+// -----
+
func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
// expected-error@+1{{Out-of-bound access}}
%0 = sparse_tensor.storage_get %arg0[3]
// -----
+
+// Round-trip test: a well-formed sparse_tensor.storage op should parse and
+// print back unchanged.
+// CHECK: func @sparse_storage_new(
+// CHECK-SAME: %[[A0:.*0]]: memref<?xf64>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xf64>,
+// CHECK-SAME: %[[A2:.*]]: f64
+// CHECK: %[[TMP_0:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]], %[[A2]])
+// CHECK: return %[[TMP_0]] : tuple<memref<?xf64>, memref<?xf64>, f64>
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+ tuple<memref<?xf64>, memref<?xf64>, f64> {
+ %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+ : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>, f64>
+ return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
+
+// -----
+
// CHECK-LABEL: func @sparse_storage_get(
// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] :
return %0 : memref<?xf64>
}
-// ----
+// -----
// CHECK-LABEL: func @sparse_storage_set(
// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,