From f4ae4762bf7d64d7ca46d05206955c9b44cedc49 Mon Sep 17 00:00:00 2001
From: Feng Liu <fengliuai@google.com>
Date: Mon, 9 Sep 2019 15:42:07 -0700
Subject: [PATCH] Add quant.const_fake_quant_per_axis op

Compared to the existing quant.const_fake_quant op, the min and max
attributes of this new op are for each channel of the last dimension of
the input.

PiperOrigin-RevId: 268093722
---
 mlir/include/mlir/Dialect/QuantOps/QuantOps.td | 32 ++++++++++++++++++++++++++
 mlir/test/Dialect/QuantOps/parse-ops.mlir      | 15 ++++++++++++
 2 files changed, 47 insertions(+)

diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
index 394d3a1..d95b452 100644
--- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
+++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td
@@ -132,6 +132,38 @@ def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
   );
 }
 
+def quant_ConstFakeQuantPerAxis : quant_Op<"const_fake_quant_per_axis",
+    [SameOperandsAndResultType, NoSideEffect]> {
+  let summary =
+      "Simulates the effect of per axis uniform quantization with const range.";
+
+  let description = [{
+    Given a const min, max, num_bits and narrow_range attribute, applies the
+    same per axis uniform quantization simulation as is done by the TensorFlow
+    fake_quant_with_min_max_vars_per_channel op. See the fakeQuantAttrsToType()
+    utility method and the quant-convert-simulated-quantization pass for further
+    details.
+  }];
+
+  let arguments = (ins
+    F32Tensor:$inputs,
+    F32ArrayAttr:$min,
+    F32ArrayAttr:$max,
+    // The quantized dimension of the inputs tensor.
+    I64Attr:$axis,
+    // The bitwidth of the quantization; between 2 and 16, inclusive.
+    I64Attr:$num_bits,
+    // Quantization range starts from 0 or 1; starts from 1 if true.
+    DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
+    // The sign of the quantization.
+    DefaultValuedAttr<BoolAttr, "false">:$is_signed
+  );
+
+  let results = (outs
+    F32Tensor:$outputs
+  );
+}
+
 def quant_StatisticsRefOp : quant_Op<"stats_ref",
     [SameOperandsAndResultType]> {
   let summary = "Indicates that statistics are resolved by reference.";
diff --git a/mlir/test/Dialect/QuantOps/parse-ops.mlir b/mlir/test/Dialect/QuantOps/parse-ops.mlir
index 77968f8..7d6d1ab 100644
--- a/mlir/test/Dialect/QuantOps/parse-ops.mlir
+++ b/mlir/test/Dialect/QuantOps/parse-ops.mlir
@@ -16,6 +16,21 @@ func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 }
 
 // -----
+// CHECK-LABEL: validConstFakeQuantPerAxis
+func @validConstFakeQuantPerAxis(%arg0: tensor<8x4x2xf32>) -> tensor<8x4x2xf32> {
+  %0 = "quant.const_fake_quant_per_axis"(%arg0) {
+    min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = true
+  } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
+  %1 = "quant.const_fake_quant_per_axis"(%0) {
+    min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = false
+  } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
+  %2 = "quant.const_fake_quant_per_axis"(%1) {
+    min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8
+  } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32>
+  return %2 : tensor<8x4x2xf32>
+}
+
+// -----
 // CHECK-LABEL: validStatisticsRef
 func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
   %0 = "quant.stats_ref"(%arg0) { statsKey = "foobar" } :
-- 
2.7.4