From 701266c47abab7180d36ae174f19d76a113a77a4 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Thu, 18 Jul 2019 11:25:53 -0700 Subject: [PATCH] Add an "is_signed" attribute to the quant_ConstFakeQuant op Some TensorFlow simulated quantize ops, such as QuantizeAndDequantizeV2Op, have an attribute for the sign of the quantization, so a new attribute is added to quant_ConstFakeQuant so that it can represent this. The method for converting these attributes to a QuantizedType is updated to handle this new argument. PiperOrigin-RevId: 258810290 --- mlir/include/mlir/Dialect/QuantOps/QuantOps.td | 4 +- .../QuantOps/Transforms/ConvertSimQuant.cpp | 2 +- .../Dialect/QuantOps/Utils/FakeQuantSupport.cpp | 12 ++++-- mlir/test/Dialect/QuantOps/convert-fakequant.mlir | 47 +++++++++++++++++++++- 4 files changed, 59 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td index c76ad2d..394d3a1 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td @@ -122,7 +122,9 @@ def quant_ConstFakeQuant : quant_Op<"const_fake_quant", // The bitwidth of the quantization; between 2 and 16, inclusive. I64Attr:$num_bits, // Quantization range starts from 0 or 1; starts from 1 if true. - DefaultValuedAttr:$narrow_range + DefaultValuedAttr:$narrow_range, + // The sign of the quantization. 
+ DefaultValuedAttr:$is_signed ); let results = (outs diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index 0c93146..32d8c8a 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -70,7 +70,7 @@ public: UniformQuantizedType uniformElementType = fakeQuantAttrsToType( fqOp.getLoc(), fqOp.num_bits().getSExtValue(), fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), - fqOp.narrow_range(), converter.expressedType); + fqOp.narrow_range(), converter.expressedType, fqOp.is_signed()); if (!uniformElementType) { // Note that the fakeQuantAttrsToType will have emitted the error. diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp index 13c622e..2667da9 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp @@ -45,9 +45,15 @@ mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin, } } else if (numBits <= 16) { storageType = IntegerType::get(16, ctx); - flags = QuantizationFlags::Signed; - qmin = -32768; - qmax = 32767; + if (isSigned) { + flags = QuantizationFlags::Signed; + qmin = -32768; + qmax = 32767; + } else { + flags = 0; + qmin = 0; + qmax = 65535; + } } else { emitError(loc, "unsupported FakeQuant number of bits: ") << numBits; return nullptr; diff --git a/mlir/test/Dialect/QuantOps/convert-fakequant.mlir b/mlir/test/Dialect/QuantOps/convert-fakequant.mlir index bd28c2c..61561c5 100644 --- a/mlir/test/Dialect/QuantOps/convert-fakequant.mlir +++ b/mlir/test/Dialect/QuantOps/convert-fakequant.mlir @@ -46,6 +46,51 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32 } // ----- +// Verifies a qint8 asymmetric 0..1 range. 
+// CHECK-LABEL: fakeQuantArgs_Qint8_0_1 +func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { +^bb0(%arg0: tensor<8x4x3xf32>): + // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) + // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> + // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform>) + // CHECK-SAME: -> tensor<8x4x3xf32> + %0 = "quant.const_fake_quant"(%arg0) { + min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, is_signed = true + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +// Verifies a qint8 asymmetric 0..1 range (with narrow_range = true). +// CHECK-LABEL: fakeQuantArgs_Qint8_NarrowRange +func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { +^bb0(%arg0: tensor<8x4x3xf32>): + // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) + // CHECK-SAME: -> tensor<8x4x3x!quant.uniform:f32, 0.003937007874015748:-127>> + // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform:f32, 0.003937007874015748:-127>>) + // CHECK-SAME: -> tensor<8x4x3xf32> + %0 = "quant.const_fake_quant"(%arg0) { + min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true, is_signed = true + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +// Verifies a qint8 symmetric range of -1..127/128. 
+// CHECK-LABEL: fakeQuantArgs_Qint8_SymmetricRange +func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { +^bb0(%arg0: tensor<8x4x3xf32>): + // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) + // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> + // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform>) + // CHECK-SAME: -> tensor<8x4x3xf32> + %0 = "quant.const_fake_quant"(%arg0) { + min = -1.0 : f32, max = 0.9921875 : f32, num_bits = 8, narrow_range = false, is_signed = true + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- // Verifies a commonly used -1..1 symmetric 16bit range with a zero point of // 0 and range -1.0 .. 32767/32768. // CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric @@ -56,7 +101,7 @@ func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform>) // CHECK-SAME: -> tensor<8x4x3xf32> %0 = "quant.const_fake_quant"(%arg0) { - min = -1.0 : f32, max = 0.999969482 : f32, num_bits = 16 + min = -1.0 : f32, max = 0.999969482 : f32, num_bits = 16, is_signed = true } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } -- 2.7.4