From 648dfdfc2481bf0205181991f6eb9be13a3d9174 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 19 Apr 2021 13:40:33 -0700 Subject: [PATCH] [mlir][tosa] Add tosa.avg_pool2d lowering Added the float lowerings for avg pool with corresponding tests. Differential Revision: https://reviews.llvm.org/D100793 --- mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 42 +++++++++++++++------- .../Conversion/TosaToLinalg/tosa-to-linalg.mlir | 15 ++++++++ 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index de27fec..8ef186e 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1626,18 +1626,19 @@ public: } }; -class MaxPool2dConverter : public OpRewritePattern { +template +class Pool2dConverter : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, + LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); Value input = op.input(); ShapedType inputTy = input.getType().cast(); Type inElementTy = inputTy.getElementType(); - ShapedType resultTy = op.getType().cast(); + ShapedType resultTy = op.getType().template cast(); Type outElementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); @@ -1646,17 +1647,20 @@ public: // Determine what the initial value needs to be for the max pool op. Attribute initialAttr; - if (outElementTy.isF32()) + if (isa(op) && outElementTy.isF32()) initialAttr = rewriter.getFloatAttr( outElementTy, APFloat::getLargest( outElementTy.cast().getFloatSemantics(), true)); - if (outElementTy.isa()) + if (isa(op) && outElementTy.isa()) initialAttr = rewriter.getIntegerAttr( outElementTy, APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth())); + if (isa(op) && outElementTy.isa()) + initialAttr = rewriter.getZeroAttr(outElementTy); + if (!initialAttr) return rewriter.notifyMatchFailure( op, "Unsupported initial value for tosa.maxpool_2d op"); @@ -1670,6 +1674,7 @@ public: Attribute strideAttr = rewriter.getI64VectorAttr(stride); Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); + int64_t kernelSize = kernel[0] * kernel[1]; // If non-zero padding we need to pad the input if (llvm::any_of(pad, [](int64_t v) { return v != 0; })) { @@ -1716,34 +1721,46 @@ public: .getOperation()); }; - if (inElementTy.isF32()) { + if (isa(op) && inElementTy.isF32()) { linalg::LinalgOp poolingOp = createOp(static_cast(nullptr)); rewriter.replaceOp(op, poolingOp->getResult(0)); return success(); } - if (inElementTy.isInteger(8)) { + if (isa(op) && inElementTy.isInteger(8)) { linalg::LinalgOp poolingOp = createOp(static_cast(nullptr)); rewriter.replaceOp(op, poolingOp->getResult(0)); return success(); } - if (inElementTy.isInteger(16)) { + if (isa(op) && inElementTy.isInteger(16)) { linalg::LinalgOp poolingOp = createOp(static_cast(nullptr)); rewriter.replaceOp(op, poolingOp->getResult(0)); return success(); } - if (inElementTy.isInteger(32)) { + if (isa(op) && inElementTy.isInteger(32)) { linalg::LinalgOp poolingOp = createOp(static_cast(nullptr)); rewriter.replaceOp(op, poolingOp->getResult(0)); return success(); } + if (isa(op) && inElementTy.isF32()) { + linalg::LinalgOp poolingOp = + createOp(static_cast(nullptr)); + auto constAttr = DenseElementsAttr::get( + resultTy, static_cast(1.0 / kernelSize)); + auto constant = rewriter.create(loc, constAttr); + auto mul = rewriter.create( + loc, resultTy, poolingOp->getResult(0), constant, 0); + rewriter.replaceOp(op, mul.output()); + return success(); + } + return failure(); } }; @@ -1805,7 +1822,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns( TileConverter, TransposeConverter, MatMulConverter, - MaxPool2dConverter, + Pool2dConverter, + Pool2dConverter, FullyConnectedConverter>(patterns->getContext()); - // clang-format on + // clang-format on } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 119ceea..09392e6 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -923,6 +923,21 @@ func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () { %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>) return } +// ----- + +// CHECK-LABEL: @avg_pool +func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> () { + // CHECK-DAG: [[CONST:%.+]] = constant 0 + // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 3, 31, 62] + // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]]) + // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [4, 4] + // CHECK: linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x3x31x62xf32>) + // CHECK: constant dense<6.250000e-02> + // CHECK: linalg.generic + // CHECK: mulf + %0 = "tosa.avg_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [4, 4], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x3x31x62xf32>) + return +} // ----- -- 2.7.4