From f2832c2295c6076b51a35d0d7b304c08e1b41c29 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 8 Jul 2021 15:17:15 -0700 Subject: [PATCH] [mlir][tosa] Added shape propagation for TOSA pool operations. Pool operations perform the same shape propagation. Included the shape propagation and tests for these avg_pool2d and max_pool2d. Differential Revision: https://reviews.llvm.org/D105665 --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 10 ++++- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 56 +++++++++++++++++++++++++++ mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 48 +++++++++++++++++++++++ 3 files changed, 112 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 76cd66a..eafce2c 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -56,7 +56,10 @@ def Tosa_ArgMaxOp : Tosa_Op<"argmax", [ //===----------------------------------------------------------------------===// // Operator: avg_pool2d //===----------------------------------------------------------------------===// -def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [NoSideEffect]> { +def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [ + DeclareOpInterfaceMethods, + NoSideEffect]> { let summary = "Performs max pooling on the input."; let description = [{ @@ -233,7 +236,10 @@ def Tosa_MatMulOp : Tosa_Op<"matmul", [ //===----------------------------------------------------------------------===// // Operator: max_pool2d //===----------------------------------------------------------------------===// -def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [NoSideEffect]> { +def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [ + DeclareOpInterfaceMethods, + NoSideEffect]> { let summary = "Performs max pooling on the input."; let description = [{ diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 9126f17..75f26f6 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -845,6 +845,62 @@ NARY_SHAPE_INFER(tosa::TanhOp) NARY_SHAPE_INFER(tosa::SigmoidOp) #undef PRED_SHAPE_INFER +static LogicalResult poolingInferReturnTypes( + ValueRange operands, DictionaryAttr attributes, + SmallVectorImpl &inferredReturnShapes) { + RankedTensorType inputTy = operands[0].getType().dyn_cast(); + llvm::SmallVector outputShape; + outputShape.resize(4, -1); + + // We only know the rank if the input type is unranked. + if (!inputTy) { + inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); + return success(); + } + + // Batch and number of channels are identical for pooling layer. + outputShape[0] = inputTy.getDimSize(0); + outputShape[3] = inputTy.getDimSize(3); + + int32_t height = inputTy.getDimSize(1); + int32_t width = inputTy.getDimSize(2); + + llvm::SmallVector kernel; + llvm::SmallVector stride; + llvm::SmallVector pad; + + getI64Values(attributes.get("kernel").cast(), kernel); + getI64Values(attributes.get("stride").cast(), stride); + getI64Values(attributes.get("pad").cast(), pad); + + if (height != -1) { + int32_t padded = height + pad[0] + pad[1] - kernel[0]; + outputShape[1] = padded / stride[0] + 1; + } + + if (width != -1) { + int32_t padded = width + pad[2] + pad[3] - kernel[1]; + outputShape[2] = padded / stride[1] + 1; + } + + inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); + return success(); +} + +LogicalResult AvgPool2dOp::inferReturnTypeComponents( + MLIRContext *context, ::llvm::Optional location, + ValueRange operands, DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnShapes) { + return poolingInferReturnTypes(operands, attributes, inferredReturnShapes); +} + +LogicalResult MaxPool2dOp::inferReturnTypeComponents( + MLIRContext *context, ::llvm::Optional location, + ValueRange operands, DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnShapes) { + return poolingInferReturnTypes(operands, attributes, inferredReturnShapes); +} + //===----------------------------------------------------------------------===// // TOSA Operator Definitions. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index bfbbe07..a5134ac 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -660,3 +660,51 @@ func @scatter_minimum_static(%arg0 : tensor, %arg1 : tensor<3x?xi32>, %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor, tensor<3x?xi32>, tensor) -> (tensor) return } + +// ----- + +// CHECK-LABEL: @test_pool_static +func @test_pool_static(%arg0: tensor<3x5x6x7xf32>) { + // CHECK: -> tensor<3x2x4x7xf32> + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor + + // CHECK: -> tensor<3x2x4x7xf32> + %1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_pool_dynamic_input +func @test_pool_dynamic_input(%arg0: tensor) { + // CHECK: -> tensor + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor) -> tensor + + // CHECK: -> tensor + %1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_pool_padded +func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) { + // CHECK: -> tensor<3x5x11x7xf32> + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [1, 2, 3, 4], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor + + // CHECK: -> tensor<3x5x11x7xf32> + %1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [1, 2, 3, 4], stride = [1, 1]} : (tensor<3x5x6x7xf32>) -> tensor + return +} + +// ----- + +// CHECK-LABEL: @test_pool_stride +func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) { + // CHECK: -> tensor<3x4x4x7xf32> + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [2, 3]} : (tensor<3x11x12x7xf32>) -> tensor + + // CHECK: -> tensor<3x4x4x7xf32> + %1 = "tosa.max_pool2d"(%arg0) {kernel = [4, 3], pad = [0, 0, 0, 0], stride = [2, 3]} : (tensor<3x11x12x7xf32>) -> tensor + return +} -- 2.7.4