Add tests to verify 0.0 is quantized correctly
authorFeng Liu <fengliuai@google.com>
Thu, 29 Aug 2019 17:08:46 +0000 (10:08 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 29 Aug 2019 17:09:22 +0000 (10:09 -0700)
We should consider both signed and narrow_range cases.

PiperOrigin-RevId: 266167366

mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
mlir/test/Dialect/QuantOps/convert-const.mlir
mlir/test/Dialect/QuantOps/convert-fakequant.mlir

index eeb69b7..637f6a0 100644 (file)
@@ -76,7 +76,7 @@ mlir::quant::fakeQuantAttrsToType(Location loc, unsigned numBits, double rmin,
   // points and dequantized to 0.0.
   if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
     return UniformQuantizedType::getChecked(flags, storageType, expressedType,
-                                            1.0, 0, qmin, qmax, loc);
+                                            1.0, qmin, qmin, qmax, loc);
   }
 
   // Determine the scale.
index 0fe2941..87619df 100644 (file)
@@ -138,3 +138,36 @@ func @const_custom_storage_range_i8_fixedpoint() -> tensor<7xf32> {
   %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform<i8<-100:100>:f32, 7.812500e-03>>) -> (tensor<7xf32>)
   return %2 : tensor<7xf32>
 }
+
+// -----
+// Verifies quantization results of all-0.0 tensors are quantized to zero points.
+// CHECK-LABEL: zero_tensors_to_zero_points
+func @zero_tensors_to_zero_points() -> (tensor<7xf32>, tensor<7xf32>, tensor<7xf32>, tensor<7xf32>) {
+
+// CHECK: %[[cst:.*]] = constant dense<-127> : tensor<7xi8>
+// CHECK: %[[cst0:.*]] = constant dense<0> : tensor<7xi8>
+// CHECK: %[[cst1:.*]] = constant dense<1> : tensor<7xi8>
+// CHECK: "quant.scast"(%[[cst0]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform<i8:f32, 1.000000e+00>>
+// CHECK: "quant.scast"(%[[cst]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:-127>>
+// CHECK: "quant.scast"(%[[cst0]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform<u8:f32, 1.000000e+00>>
+// CHECK: "quant.scast"(%[[cst1]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>
+
+  %cst = constant dense<0.0> : tensor<7xf32>
+  %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform<i8:f32, 1.0>>
+  %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform<i8:f32, 1.0>>) -> (tensor<7xf32>)
+
+  %cst0 = constant dense<0.0> : tensor<7xf32>
+  %3 = "quant.qcast"(%cst0) : (tensor<7xf32>) -> tensor<7x!quant.uniform<i8<-127:127>:f32, 1.0:-127>>
+  %4 = "quant.dcast"(%3) : (tensor<7x!quant.uniform<i8<-127:127>:f32, 1.0:-127>>) -> (tensor<7xf32>)
+
+  %cst1 = constant dense<0.0> : tensor<7xf32>
+  %5 = "quant.qcast"(%cst1) : (tensor<7xf32>) -> tensor<7x!quant.uniform<u8:f32, 1.0>>
+  %6 = "quant.dcast"(%5) : (tensor<7x!quant.uniform<u8:f32, 1.0>>) -> (tensor<7xf32>)
+
+  %cst2 = constant dense<0.0> : tensor<7xf32>
+  %7 = "quant.qcast"(%cst2) : (tensor<7xf32>) -> tensor<7x!quant.uniform<u8<1:255>:f32, 1.0:1>>
+  %8 = "quant.dcast"(%7) : (tensor<7x!quant.uniform<u8<1:255>:f32, 1.0:1>>) -> (tensor<7xf32>)
+
+  return %2, %4, %6, %8 : tensor<7xf32>, tensor<7xf32>, tensor<7xf32>, tensor<7xf32>
+}
+
index 61561c5..15de088 100644 (file)
@@ -1,6 +1,36 @@
 // RUN: mlir-opt %s -split-input-file -quant-convert-simulated-quantization | FileCheck %s --dump-input=fail
 
 // -----
+// Verifies a quint8 single point.
+// CHECK-LABEL: fakeQuantArgs_Quint8_0
+func @fakeQuantArgs_Quint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 1.000000e+00>>
+  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<u8:f32, 1.000000e+00>>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+  %0 = "quant.const_fake_quant"(%arg0) {
+    min = 0.0 : f32, max = 0.0 : f32, num_bits = 8
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verifies a quint8 single point (with narrow_range = true).
+// CHECK-LABEL: fakeQuantArgs_Quint8_0_NarrowRange
+func @fakeQuantArgs_Quint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>
+  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+  %0 = "quant.const_fake_quant"(%arg0) {
+    min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, narrow_range = true
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
 // Verifies a quint8 asymmetric 0..1 range.
 // CHECK-LABEL: fakeQuantArgs_Quint8_0_1
 func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
@@ -46,6 +76,36 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32
 }
 
 // -----
+// Verifies a qint8 single point.
+// CHECK-LABEL: fakeQuantArgs_Qint8_0
+func @fakeQuantArgs_Qint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 1.000000e+00:-128>>
+  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 1.000000e+00:-128>>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+  %0 = "quant.const_fake_quant"(%arg0) {
+    min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, is_signed = true
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verifies a qint8 single point (with narrow_range = true).
+// CHECK-LABEL: fakeQuantArgs_Qint8_0_NarrowRange
+func @fakeQuantArgs_Qint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // CHECK: %[[qc:.*]]  = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:-127>>
+  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:-127>>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+  %0 = "quant.const_fake_quant"(%arg0) {
+    min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, narrow_range = true, is_signed = true
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
 // Verifies a qint8 asymmetric 0..1 range.
 // CHECK-LABEL: fakeQuantArgs_Qint8_0_1
 func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {