From 7ea643c06d8977045d0cf79507f36d828773378c Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Wed, 31 Aug 2022 20:44:41 +0000 Subject: [PATCH] [mlir][sparse] Introduce new sparse_tensor.storage_get/set to access memory that stores the handle of a sparse tensor Introduce new sparse_tensor.storage_get/set to access memory that stores the handle of a sparse tensor. The sparse tensor storage are represented as a tuple, these operation will later be eliminated and the tuple will be flattened after sparse tensor codegen Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D133049 --- .../Dialect/SparseTensor/IR/SparseTensorOps.td | 53 ++++++++++++++++++++++ .../SparseTensor/IR/SparseTensorDialect.cpp | 42 +++++++++++++++++ mlir/test/Dialect/SparseTensor/invalid.mlir | 39 ++++++++++++++++ mlir/test/Dialect/SparseTensor/roundtrip.mlir | 31 +++++++++++++ 4 files changed, 165 insertions(+) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 39af4a8..25bc16f 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -623,4 +623,57 @@ def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>, let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// Sparse Tensor Storage Operation. These operations are used internally by +// sparse tensor codegen to progressively lower sparse tensors. +//===----------------------------------------------------------------------===// + +def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>, + Arguments<(ins AnyTuple:$storage, + IndexAttr:$idx)>, + Results<(outs AnyType:$result)> { + let summary = "Get the data stored in the sparse tensor storage at the given index"; + let description = [{ + Get the data stored in the sparse tensor storage (represented as a tuple) + at the given index. + + The result type should match the corresponding element type in the tuple. + + Example: + + ```mlir + %0 = sparse_tensor.storage_get %arg0[0] : tuple, memref, f64> to memref + ``` + }]; + + let assemblyFormat = " $storage attr-dict `[`$idx`]` `:` type($storage) `to` type($result)"; + let hasVerifier = 1; +} + +def SparseTensor_StorageSetOp : SparseTensor_Op<"storage_set", []>, + Arguments<(ins AnyTuple:$storage, + AnyType:$value, + IndexAttr:$idx)>, + Results<(outs AnyTuple:$result)> { + let summary = "Set the data stored in the sparse tensor storage at given index"; + let description = [{ + Set the data stored in the sparse tensor storage (represented as a tuple) + at the given index. Return a new SSA value with the corresponding element + updated (others remain unchanged). + + The result type should match the original tuple type with only the updated + element type changed accordingly. + + Example: + + ```mlir + %0 = sparse_tensor.storage_set %arg0, %arg1 at 0 : tuple, memref, f64>, memref to tuple, memref, f64> + ``` + }]; + + let assemblyFormat = " $storage attr-dict `[`$idx`]``,` $value `:` type($storage) `,` type($value) `to` type($result)"; + let hasVerifier = 1; +} + + #endif // SPARSETENSOR_OPS diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 8691b94..1c76f7e 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -483,6 +483,48 @@ LogicalResult YieldOp::verify() { } //===----------------------------------------------------------------------===// +// Sparse Tensor Storage Operation. +//===----------------------------------------------------------------------===// + +LogicalResult StorageGetOp::verify() { + uint64_t extractIdx = getIdx().getZExtValue(); + auto innerTypeArray = getStorage().getType().getTypes(); + if (extractIdx >= innerTypeArray.size()) + return emitError(llvm::formatv( + "Out-of-bound access with index={0} on tuple with length={1}", + extractIdx, innerTypeArray.size())); + + auto expectedTy = getStorage().getType().getType(extractIdx); + auto returnTy = getResult().getType(); + if (expectedTy != returnTy) + return emitError(llvm::formatv( + "Type mismatch between the returning type (type={0}) and the " + "corresponding element type at index {1} (type={2})", + expectedTy, extractIdx, returnTy)); + return success(); +} + +LogicalResult StorageSetOp::verify() { + uint64_t setIdx = getIdx().getZExtValue(); + SmallVector expectedElemTy(getStorage().getType().getTypes()); + if (setIdx >= expectedElemTy.size()) + return emitError(llvm::formatv( + "Out-of-bound access with index = {0} on tuple with length={1}", setIdx, + expectedElemTy.size())); + + // Updates the element type after storage_set. + expectedElemTy[setIdx] = getValue().getType(); + auto expectedTy = TupleType::get(getContext(), expectedElemTy); + auto returnTy = getResult().getType(); + if (expectedTy != returnTy) + return emitError( + llvm::formatv("Type mismatch between the returning type " + "(type={0}) and the expected type (type={1})", + returnTy, expectedTy)); + return success(); +} + +//===----------------------------------------------------------------------===// // TensorDialect Methods. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir index d9b48fe..805f959 100644 --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -443,3 +443,42 @@ func.func @invalid_concat_size_mismatch(%arg0: tensor<2x4xf64, #DC>, return %0 : tensor<9x4xf64, #DC> } +// ----- + +func.func @sparse_storage_get(%arg0: tuple, memref, f64>) -> memref { + // expected-error@+1{{Out-of-bound access}} + %0 = sparse_tensor.storage_get %arg0[3] + : tuple, memref, f64> to + memref + return %0 : memref +} + +// ----- + +func.func @sparse_storage_get(%arg0: tuple, memref, f64>) -> memref { + // expected-error@+1{{Type mismatch}} + %0 = sparse_tensor.storage_get %arg0[2] + : tuple, memref, f64> to + memref + return %0 : memref +} + +// ----- + +func.func @sparse_storage_set(%arg0: tuple, memref, f64>, %arg1: memref) -> tuple, memref, f64> { + // expected-error@+1{{Out-of-bound access}} + %0 = sparse_tensor.storage_set %arg0[3], %arg1 + : tuple, memref, f64>, memref to + tuple, memref, f64> + return %0 : tuple, memref, f64> +} + +// ----- + +func.func @sparse_storage_set(%arg0: tuple, memref, f64>, %arg1: memref) -> tuple, memref, f64> { + // expected-error@+1{{Type mismatch}} + %0 = sparse_tensor.storage_set %arg0[2], %arg1 + : tuple, memref, f64>, memref to + tuple, memref, f64> + return %0 : tuple, memref, f64> +} diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir index 5edc977..4b97277 100644 --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -314,3 +314,34 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #SparseMatrix>, tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix> return %0 : tensor<9x4xf64, #SparseMatrix> } + +// ----- + +// CHECK-LABEL: func @sparse_storage_get( +// CHECK-SAME: %[[A0:.*]]: tuple, memref, f64> +// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] : +// CHECK-SAME: tuple, memref, f64> +// CHECK-SAME: to memref +// CHECK: return %[[TMP0]] : memref +func.func @sparse_storage_get(%arg0: tuple, memref, f64>) -> memref { + %0 = sparse_tensor.storage_get %arg0[0] + : tuple, memref, f64> to memref + return %0 : memref +} + +// ---- + +// CHECK-LABEL: func @sparse_storage_set( +// CHECK-SAME: %[[A0:.*]]: tuple, memref, f64>, +// CHECK-SAME: %[[A1:.*]]: memref +// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_set %[[A0]][0], %[[A1]] : +// CHECK-SAME: tuple, memref, f64>, +// CHECK-SAME: memref +// CHECK-SAME: to tuple, memref, f64> +// CHECK: return %0 : tuple, memref, f64> +func.func @sparse_storage_set(%arg0: tuple, memref, f64>, %arg1: memref) -> tuple, memref, f64> { + %0 = sparse_tensor.storage_set %arg0[0], %arg1 + : tuple, memref, f64>, memref to + tuple, memref, f64> + return %0 : tuple, memref, f64> +} -- 2.7.4