From 1f7adf8cb1d77ba35b8fa322c93d0d88a4cdc1f0 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 4 May 2021 14:21:51 -0700 Subject: [PATCH] [mlir][tosa] Fix tosa.concat by inserting linalg.fill after linalg.init All linalg.init operations must be fed into a linalg operation before subtensor. The inserted linalg.fill guarantees it executes correctly. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D101848 --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 7 ++++++- mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir | 8 ++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index ee4f29c..fb34ff5 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1591,9 +1591,14 @@ struct ConcatConverter : public OpConversionPattern { } sizes[axis] = resultDimSize; - Value result = rewriter.create( + Value init = rewriter.create( loc, resultType.getShape(), resultType.getElementType()); + Value zeroVal = rewriter.create( + loc, rewriter.getZeroAttr(resultType.getElementType())); + Value result = + rewriter.create(loc, init, zeroVal).getResult(0); + for (auto arg : args) { sizes[axis] = rewriter.create(loc, arg, axisValue); result = rewriter.create(loc, arg, result, offsets, diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 9bd03f1..a9c91c2 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -637,8 +637,10 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () { // CHECK: [[ARG1_AXIS:%.+]] = memref.dim %arg1, [[AXIS]] // CHECK: [[RESULT_AXIS:%.+]] = addi [[ARG0_DIM0]], [[ARG1_AXIS]] // CHECK: [[INIT:%.+]] = linalg.init_tensor [11, 1] + // CHECK: [[CST:%.+]] = constant 0.0 + // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST]]) // CHECK: [[ARG0_DIM0:%.+]] = memref.dim %arg0, [[AXIS]] - // CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[INIT]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] + // CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] // CHECK: [[NEW_OFFSET:%.+]] = addi [[OFFSET]], [[ARG0_DIM0]] // CHECK: [[ARG1_DIM0:%.+]] = memref.dim %arg1, [[AXIS]] // CHECK: [[INSERT1:%.+]] = subtensor_insert %arg1 into [[INSERT0]]{{\[}}[[NEW_OFFSET]], [[OFFSET]]] {{\[}}[[ARG1_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] @@ -654,8 +656,10 @@ func @concat(%arg0: tensor<5x1xf32>, %arg1: tensor<6x1xf32>) -> () { // CHECK: [[ARG1_AXIS:%.+]] = memref.dim %arg0, [[AXIS]] // CHECK: [[RESULT_AXIS:%.+]] = addi [[ARG0_DIM1]], [[ARG1_AXIS]] // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 2] + // CHECK: [[CST:%.+]] = constant 0.0 + // CHECK: [[FILL:%.+]] = linalg.fill([[INIT]], [[CST]]) // CHECK: [[ARG0_DIM1:%.+]] = memref.dim %arg0, [[AXIS]] - // CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[INIT]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] + // CHECK: [[INSERT0:%.+]] = subtensor_insert %arg0 into [[FILL]]{{\[}}[[OFFSET]], [[OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG0_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] // CHECK: [[NEW_OFFSET:%.+]] = addi [[OFFSET]], [[ARG0_DIM1]] // CHECK: [[ARG1_DIM1:%.+]] = memref.dim %arg0, [[AXIS]] // CHECK: [[INSERT1:%.+]] = subtensor_insert %arg0 into [[INSERT0]]{{\[}}[[OFFSET]], [[NEW_OFFSET]]] {{\[}}[[ARG0_DIM0]], [[ARG1_DIM1]]] {{\[}}[[STRIDE]], [[STRIDE]]] -- 2.7.4