From 6c01b5cdaddce8df325020659d73cd7728778392 Mon Sep 17 00:00:00 2001 From: bixia1 Date: Sun, 20 Nov 2022 21:57:30 -0800 Subject: [PATCH] [mlir][sparse] Fix a bug in concatenate operator rewriting. When calculating the dynamic dimensions for the concatenate result, we shouldn't accumulate the sizes for the non-concatenating dimensions. Reviewed By: aartbik, Peiming Differential Revision: https://reviews.llvm.org/D138436 --- .../Transforms/SparseTensorRewriting.cpp | 9 ++- .../SparseTensor/sparse_concat_codegen.mlir | 89 +++++++++++++++++++++- 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index 1472b67..7b9eba9 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -469,9 +469,12 @@ struct ConcatenateRewriter : public OpRewritePattern { Value v = createOrFoldDimOp(rewriter, loc, op.getOperand(0), d.index()); rewriter.create(loc, op.getOperand(0), d.index()); - for (const auto &opnd : op.getOperands().drop_front()) { - Value t = createOrFoldDimOp(rewriter, loc, opnd, d.index()); - v = rewriter.create(loc, v, t); + if (conDim == d.index()) { + // Adding the size of the concatenating dimension. + for (const auto &opnd : op.getOperands().drop_front()) { + Value t = createOrFoldDimOp(rewriter, loc, opnd, d.index()); + v = rewriter.create(loc, v, t); + } } dynSizes.push_back(v); } diff --git a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir index c2a1fd8..9d26f3b 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_concat_codegen.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \ -// RUN: --sparsification | FileCheck %s +// RUN: | FileCheck %s #DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> @@ -87,3 +87,90 @@ func.func @concat_sparse_sparse(%arg0: tensor<2x4xf64, #DCSR>, tensor<4x4xf64, #DCSR> to tensor<9x4xf64, #DCSR> return %0 : tensor<9x4xf64, #DCSR> } + +// CHECK-LABEL: @concat_sparse_sparse_dynamic( +// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<2x4xf64, #sparse_tensor +// CHECK-SAME: %[[TMP_arg1:.*]]: tensor<3x4xf64, #sparse_tensor +// CHECK-SAME: %[[TMP_arg2:.*]]: tensor<4x4xf64, #sparse_tensor +// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[TMP_c5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[TMP_c2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[TMP_c9:.*]] = arith.constant 9 : index +// CHECK-DAG: %[[TMP_c4:.*]] = arith.constant 4 : index +// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor(%[[TMP_c9]], %[[TMP_c4]]) : tensor +// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref +// CHECK: %[[RET_1:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] iter_args(%[[A0:.*]] = %[[TMP_0]]) +// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_2]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_3]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index +// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_3]][%[[TMP_24]]] : memref +// CHECK: %[[RET_4:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A1:.*]] = %[[A0]]) +// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_4]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_5]][%[[TMP_arg4]]] : memref +// CHECK: %[[NEW_1:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A1]][%[[TMP_23]], %[[TMP_27]]] : tensor +// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_8]][%[[TMP_c1]]] : memref +// CHECK: %[[RET_2:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_13]] to %[[TMP_14]] step %[[TMP_c1]] iter_args(%[[A2:.*]] = %[[RET_1]]) +// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_9]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_25:.*]] = memref.load %[[TMP_10]][%[[TMP_arg3]]] : memref +// CHECK-DAG: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index +// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_10]][%[[TMP_24]]] : memref +// CHECK: %[[RET_5:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A3:.*]] = %[[A2]]) +// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_11]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_12]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c2]] : index +// CHECK: %[[NEW_2:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A3]][%[[TMP_29]], %[[TMP_27]]] : tensor +// CHECK: %[[TMP_21:.*]] = memref.load %[[TMP_15]][%[[TMP_c1]]] : memref +// CHECK: %[[RET_3:.*]] = scf.for %[[TMP_arg3:.*]] = %[[TMP_20]] to %[[TMP_21]] step %[[TMP_c1]] iter_args(%[[A4:.*]] = %[[RET_2]]) +// CHECK: %[[TMP_23:.*]] = memref.load %[[TMP_16]][%[[TMP_arg3]]] : memref +// CHECK: %[[TMP_25:.*]] = memref.load %[[TMP_17]][%[[TMP_arg3]]] : memref +// CHECK: %[[TMP_24:.*]] = arith.addi %[[TMP_arg3]], %[[TMP_c1]] : index +// CHECK: %[[TMP_26:.*]] = memref.load %[[TMP_17]][%[[TMP_24]]] : memref +// CHECK: %[[RET_6:.*]] = scf.for %[[TMP_arg4:.*]] = %[[TMP_25]] to %[[TMP_26]] step %[[TMP_c1]] iter_args(%[[A5:.*]] = %[[A4]]) +// CHECK: %[[TMP_27:.*]] = memref.load %[[TMP_18]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_28:.*]] = memref.load %[[TMP_19]][%[[TMP_arg4]]] : memref +// CHECK: %[[TMP_29:.*]] = arith.addi %[[TMP_23]], %[[TMP_c5]] : index +// CHECK: %[[NEW_3:.*]] = sparse_tensor.insert %[[TMP_28]] into %[[A5]][%[[TMP_29]], %[[TMP_27]]] : tensor, + %arg1: tensor<3x4xf64, #DCSR>, + %arg2: tensor<4x4xf64, #DCSR>) + -> tensor { + %0 = sparse_tensor.concatenate %arg0, %arg1, %arg2 {dimension = 0 : index} + : tensor<2x4xf64, #DCSR>, + tensor<3x4xf64, #DCSR>, + tensor<4x4xf64, #DCSR> to tensor + return %0 : tensor +} \ No newline at end of file -- 2.7.4