From 90478251c736ce335fe8d45e46a09d9bec889583 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 26 Aug 2021 11:20:58 -0700 Subject: [PATCH] [mlir][tosa] Tosa reverse to linalg supporting dynamic shapes Needed to switch to extract to support tosa.reverse using dynamic shapes. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D108744 --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 50 +++++++++++-------- .../Conversion/TosaToLinalg/tosa-to-linalg.mlir | 56 +++++++++++++++++----- 2 files changed, 74 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index beab93b..74239fe 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2043,40 +2043,48 @@ public: Value input = op.input(); auto inputTy = input.getType().template cast(); auto resultTy = op.getType().template cast(); - auto rank = resultTy.getRank(); auto axis = op.axis(); - if (!inputTy.hasStaticShape()) - return rewriter.notifyMatchFailure( - op, "No initial value found for reduction operation"); + SmallVector dynDims; + for (int i = 0; i < inputTy.getRank(); i++) { + if (inputTy.isDynamicDim(i)) { + dynDims.push_back(rewriter.create(loc, input, i)); + } + } + + Value axisDimSize = rewriter.create(loc, input, axis); // First fill the output buffer with the init value. auto initTensor = rewriter .create( - loc, ArrayRef({}), inputTy.getShape(), - inputTy.getElementType()) + loc, ArrayRef({dynDims}), + inputTy.getShape(), inputTy.getElementType()) .result(); - - SmallVector inputExprs; - inputExprs.resize(resultTy.getRank()); - - for (int i = 0; i < rank; i++) - inputExprs[i] = rewriter.getAffineDimExpr(i); - - inputExprs[axis] = - rewriter.getAffineConstantExpr(inputTy.getDimSize(axis) - 1) - - inputExprs[axis]; - SmallVector affineMaps = { - AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs, - rewriter.getContext()), rewriter.getMultiDimIdentityMap(resultTy.getRank())}; rewriter.replaceOpWithNewOp( - op, resultTy, op.input(), ValueRange{initTensor}, affineMaps, + op, resultTy, ArrayRef({}), ValueRange{initTensor}, affineMaps, getNParallelLoopsAttrs(resultTy.getRank()), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create(op.getLoc(), *args.begin()); + llvm::SmallVector indices; + for (unsigned int i = 0; i < inputTy.getRank(); i++) { + auto index = + rewriter.create(nestedLoc, i).getResult(); + if (i == axis) { + auto one = rewriter.create(nestedLoc, 1); + auto sizeMinusOne = + rewriter.create(nestedLoc, axisDimSize, one); + index = rewriter.create(nestedLoc, sizeMinusOne, index); + } + + indices.push_back(index); + } + + auto extract = nestedBuilder.create( + nestedLoc, input, indices); + nestedBuilder.create(op.getLoc(), + extract.getResult()); }); return success(); } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 8a3cf62..50e4c78 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -881,28 +881,62 @@ func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) { // ----- -// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (-d0 + 4, d1)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 3)> +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: @reverse func @reverse(%arg0: tensor<5x4xi32>) -> () { - // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4] - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) { - // CHECK: ^bb0(%arg1: i32, %arg2: i32): - // CHECK: linalg.yield %arg1 : i32 + // CHECK: %[[C0:.+]] = constant 0 + // CHECK: %[[RDIM:.+]] = tensor.dim %arg0, %[[C0]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, 4] + // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>) + // CHECK-DAG: %[[I0:.+]] = linalg.index 0 + // CHECK-DAG: %[[I1:.+]] = linalg.index 1 + // CHECK-DAG: %[[SUB1:.+]] = constant 1 + // CHECK-DAG: %[[RDIM_MINUS_C1:.+]] = subi %[[RDIM]], %[[SUB1]] + // CHECK-DAG: %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I0]] + // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]], %[[I1]]] : tensor<5x4xi32> + // CHECK: linalg.yield %[[EXTRACT]] %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32> - // CHECK: [[INIT:%.+]] = linalg.init_tensor [5, 4] - // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x4xi32>) outs([[INIT]] : tensor<5x4xi32>) { - // CHECK: ^bb0(%arg1: i32, %arg2: i32): - // CHECK: linalg.yield %arg1 : i32 + // CHECK: %[[C1:.+]] = constant 1 + // CHECK: %[[RDIM:.+]] = tensor.dim %arg0, %[[C1]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [5, 4] + // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel", "parallel"]} outs(%[[INIT]] : tensor<5x4xi32>) + // CHECK-DAG: %[[I0:.+]] = linalg.index 0 + // CHECK-DAG: %[[I1:.+]] = linalg.index 1 + // CHECK-DAG: %[[SUB1:.+]] = constant 1 + // CHECK-DAG: %[[RDIM_MINUS_C1:.+]] = subi %[[RDIM]], %[[SUB1]] + // CHECK-DAG: %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I1]] + // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[I0]], %[[READ_DIM]]] : tensor<5x4xi32> + // CHECK: linalg.yield %[[EXTRACT]] %1 = "tosa.reverse"(%arg0) {axis = 1 : i64} : (tensor<5x4xi32>) -> tensor<5x4xi32> return } // ----- +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: @reverse_dyn +func @reverse_dyn(%arg0: tensor) -> () { + // CHECK: %[[C0_1:.+]] = constant 0 + // CHECK: %[[D0_1:.+]] = tensor.dim %arg0, %[[C0_1]] + // CHECK: %[[C0_2:.+]] = constant 0 + // CHECK: %[[D0_2:.+]] = tensor.dim %arg0, %[[C0_2]] + // CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0_1]]] + // CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]]], iterator_types = ["parallel"]} outs(%[[INIT]] : tensor) + // CHECK-DAG: %[[I0:.+]] = linalg.index 0 + // CHECK-DAG: %[[SUB1:.+]] = constant 1 + // CHECK-DAG: %[[RDIM_MINUS_C1:.+]] = subi %[[D0_2]], %[[SUB1]] + // CHECK-DAG: %[[READ_DIM:.+]] = subi %[[RDIM_MINUS_C1]], %[[I0]] + // CHECK-DAG: %[[EXTRACT:.+]] = tensor.extract %arg0[%[[READ_DIM]]] : tensor + // CHECK: linalg.yield %[[EXTRACT]] + %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor) -> tensor + return +} + +// ----- + // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)> // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -- 2.7.4