From 95e4b71519e6621a132252b462b9bf9fce63ff61 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 12 Oct 2021 13:02:29 -0700 Subject: [PATCH] [mlir][tosa] Fix tosa average_pool2d to linalg type issue Average pool assumed the same input/output type. Result type for integers is always an i32, should be updated appropriately. Reviewed By: GMNGeoffrey Differential Revision: https://reviews.llvm.org/D111590 --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 2 ++ mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 15 +++++++------ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 20 +++++++++++++++++ .../Conversion/TosaToLinalg/tosa-to-linalg.mlir | 9 ++++---- mlir/test/Dialect/Tosa/ops.mlir | 25 ++++++++++++++++++++-- 5 files changed, 58 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index cc8c8fa..b57e8b2 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -82,6 +82,8 @@ def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [ ); let builders = [Tosa_AvgPool2dOpQuantInfoBuilder]; + + let verifier = [{ return verifyAveragePoolOp(*this); }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index f24a849..bd1769d 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2796,7 +2796,7 @@ public: Type inElementTy = inputTy.getElementType(); ShapedType resultTy = op.getType().template cast(); - Type resultETy = inputTy.getElementType(); + Type resultETy = op.getType().cast().getElementType(); Type accETy = inElementTy.isa() ? rewriter.getI32Type() : inElementTy; @@ -2810,9 +2810,10 @@ public: pad.resize(2, 0); getValuesFromIntArrayAttribute(op.pad(), pad); pad.resize(pad.size() + 2, 0); - Attribute initialAttr = rewriter.getZeroAttr(accETy); - Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); + Attribute padAttr = rewriter.getZeroAttr(inElementTy); + Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter); + Attribute initialAttr = rewriter.getZeroAttr(accETy); Value initialValue = rewriter.create(loc, initialAttr); SmallVector kernel, stride; @@ -2909,8 +2910,7 @@ public: // to be applied. Value poolVal = args[0]; if (accETy.isa()) { - auto countF = - rewriter.create(loc, inElementTy, countI); + auto countF = rewriter.create(loc, accETy, countI); poolVal = rewriter.create(loc, poolVal, countF)->getResult(0); } else { @@ -2974,8 +2974,11 @@ public: auto clamp = clampHelper( loc, scaled, min, max, CmpIPredicate::slt, rewriter); + poolVal = clamp; // Convert type. - poolVal = rewriter.create(loc, resultETy, clamp); + if (resultETy != clamp.getType()) { + poolVal = rewriter.create(loc, resultETy, poolVal); + } } // Cast to output type. diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 3a02543..9c8d4ac 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -342,6 +342,26 @@ static LogicalResult verifyConvOp(T op) { return success(); } +static LogicalResult verifyAveragePoolOp(tosa::AvgPool2dOp op) { + auto inputETy = op.input().getType().cast().getElementType(); + auto resultETy = op.getType().cast().getElementType(); + + if (auto quantType = inputETy.dyn_cast()) + inputETy = quantType.getStorageType(); + + if (auto quantType = resultETy.dyn_cast()) + resultETy = quantType.getStorageType(); + + if (inputETy.isF32() && resultETy.isF32()) + return success(); + if (inputETy.isInteger(8) && resultETy.isInteger(32)) + return success(); + if (inputETy.isInteger(16) && resultETy.isInteger(32)) + return success(); + + return op.emitOpError("input/output element types are incompatible."); +} + //===----------------------------------------------------------------------===// // TOSA Operator Quantization Builders. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 1c81a2a..df66772 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1465,15 +1465,14 @@ func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () { // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false} // CHECK: %[[OUTZP:.+]] = constant -128 // CHECK: %[[OUT:.+]] = addi %[[SCALE]], %[[OUTZP]] - // CHECK: %[[MIN:.+]] = constant -128 - // CHECK: %[[MAX:.+]] = constant 127 + // CHECK: %[[MIN:.+]] = constant -2147483648 + // CHECK: %[[MAX:.+]] = constant 2147483647 // CHECK: %[[CMP_MIN:.+]] = cmpi slt, %[[OUT]], %[[MIN]] // CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]] // CHECK: %[[CMP_MAX:.+]] = cmpi slt, %[[MAX]], %[[OUT]] // CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]] - // CHECK: %[[TRUNC:.+]] = trunci %[[CLMP_MAX]] - // CHECK: linalg.yield %[[TRUNC]] - %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8> + // CHECK: linalg.yield %[[CLMP_MAX]] + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi32> return } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index ec169d0..df79ebb 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -10,13 +10,34 @@ func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> { } // ----- -// CHECK-LABEL: avg_pool2d -func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { +// CHECK-LABEL: avg_pool2d_f32 +func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } // ----- +// CHECK-LABEL: avg_pool2d_i8 +func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi32> { + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi32> + return %0 : tensor<1x7x7x9xi32> +} + +// ----- +// CHECK-LABEL: avg_pool2d_i16 +func @test_avg_pool2d_i16(%arg0: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi32> { + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi32> + return %0 : tensor<1x7x7x9xi32> +} + +// ----- +// CHECK-LABEL: avg_pool2d_q8 +func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> { + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> + return %0 : tensor<1x7x7x9x!quant.uniform> +} + +// ----- // CHECK-LABEL: conv2d func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> -- 2.7.4