Add an "is_signed" attribute to the quant_ConstFakeQuant op
authorFeng Liu <fengliuai@google.com>
Thu, 18 Jul 2019 18:25:53 +0000 (11:25 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 19 Jul 2019 18:39:54 +0000 (11:39 -0700)
Some TensorFlow simulated quantize ops such as QuantizeAndDequantizeV2Op have
attribute for the sign of the quantization, so quant_ConstFakeQuant should be
able to represent it with the new attribute is added.

The method for converting these attributes to an QuantizedType is updated to
handle this new argument.

PiperOrigin-RevId: 258810290

mlir/include/mlir/Dialect/QuantOps/QuantOps.td
mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
mlir/test/Dialect/QuantOps/convert-fakequant.mlir

index c76ad2d..394d3a1 100644 (file)
@@ -122,7 +122,9 @@ def quant_ConstFakeQuant : quant_Op<"const_fake_quant",
     // The bitwidth of the quantization; between 2 and 16, inclusive.
     I64Attr:$num_bits,
     // Quantization range starts from 0 or 1; starts from 1 if true.
-    DefaultValuedAttr<BoolAttr, "false">:$narrow_range
+    DefaultValuedAttr<BoolAttr, "false">:$narrow_range,
+    // The sign of the quantization.
+    DefaultValuedAttr<BoolAttr, "false">:$is_signed
   );
 
   let results = (outs
index 0c93146..32d8c8a 100644 (file)
@@ -70,7 +70,7 @@ public:
     UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
         fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
         fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
-        fqOp.narrow_range(), converter.expressedType);
+        fqOp.narrow_range(), converter.expressedType, fqOp.is_signed());
 
     if (!uniformElementType) {
       // Note that the fakeQuantAttrsToType will have emitted the error.
index 13c622e..2667da9 100644 (file)
@@ -45,9 +45,15 @@ mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
     }
   } else if (numBits <= 16) {
     storageType = IntegerType::get(16, ctx);
-    flags = QuantizationFlags::Signed;
-    qmin = -32768;
-    qmax = 32767;
+    if (isSigned) {
+      flags = QuantizationFlags::Signed;
+      qmin = -32768;
+      qmax = 32767;
+    } else {
+      flags = 0;
+      qmin = 0;
+      qmax = 65535;
+    }
   } else {
     emitError(loc, "unsupported FakeQuant number of bits: ") << numBits;
     return nullptr;
index bd28c2c..61561c5 100644 (file)
@@ -46,6 +46,51 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32
 }
 
 // -----
+// Verifies a qint8 asymmetric 0..1 range.
+// CHECK-LABEL: fakeQuantArgs_Qint8_0_1
+func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>
+  // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+  %0 = "quant.const_fake_quant"(%arg0) {
+    min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, is_signed = true
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verifies a qint8 asymmetric 0..1 range (with narrow_range = true).
+// CHECK_LABEL: fakeQuantArgs_Qint8_NarrowRange
+func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 0.003937007874015748:-127>>
+  // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 0.003937007874015748:-127>>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+  %0 = "quant.const_fake_quant"(%arg0) {
+    min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true, is_signed = true
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verifies a qint8 symmetric range of -1..127/128.
+// CHECK_LABEL: fakeQuantArgs_Qint8_SymmetricRange
+func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 7.812500e-03>>
+  // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i8:f32, 7.812500e-03>>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+  %0 = "quant.const_fake_quant"(%arg0) {
+    min = -1.0 : f32, max = 0.9921875 : f32, num_bits = 8, narrow_range = false, is_signed = true
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
 // Verifies a commonly used -1..1 symmetric 16bit range with a zero point of
 // 0 and range -1.0 .. 32767/32768.
 // CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric
@@ -56,7 +101,7 @@ func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
   // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i16:f32, 3.0517578125E-5>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
-    min = -1.0 : f32, max = 0.999969482 : f32, num_bits = 16
+    min = -1.0 : f32, max = 0.999969482 : f32, num_bits = 16, is_signed = true
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }