Quantize attribute values by per axis quantization parameters
authorFeng Liu <fengliuai@google.com>
Thu, 19 Sep 2019 21:11:35 +0000 (14:11 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 19 Sep 2019 21:12:08 +0000 (14:12 -0700)
A new converter with per axis quantization parameters is added to quantize a
dense elements attribute. For each slice along the quantization axis, it
creates an uniform quantized value converter, with different scale and zero
point, and quantizes the values in the slice.

The current implementation doesn't handle sparse elements attributes.

PiperOrigin-RevId: 270121986

mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp
mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
mlir/test/Dialect/QuantOps/convert-const.mlir

index 4236684..0ce76b1 100644 (file)
@@ -67,7 +67,7 @@ struct ExpressedToQuantizedConverter {
 /// placeholder.
 class UniformQuantizedValueConverter {
 public:
-  UniformQuantizedValueConverter(UniformQuantizedType uniformType)
+  explicit UniformQuantizedValueConverter(UniformQuantizedType uniformType)
       : scale(uniformType.getScale()),
         zeroPoint(static_cast<double>(uniformType.getZeroPoint())),
         clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
@@ -78,6 +78,13 @@ public:
     assert(uniformType.getStorageType().isa<IntegerType>());
   }
 
+  UniformQuantizedValueConverter(double scale, double zeroPoint,
+                                 APFloat clampMin, APFloat clampMax,
+                                 uint32_t storageBitWidth, bool isSigned)
+      : scale(scale), zeroPoint(zeroPoint), clampMin(clampMin),
+        clampMax(clampMax), storageBitWidth(storageBitWidth),
+        isSigned(isSigned) {}
+
   virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
     bool lossy;
     expressedValue.convert(scale.getSemantics(), APFloat::rmNearestTiesToEven,
@@ -112,6 +119,52 @@ private:
   const bool isSigned;
 };
 
+/// An utility class to quantize an attribute by the per-axis quantization
+/// parameters. The size of the quantization dim in the converted elements
+/// attribute should matche the size of of scales/zeroPoints vectors in the
+/// quantization parameters.
+class UniformQuantizedPerAxisValueConverter {
+public:
+  explicit UniformQuantizedPerAxisValueConverter(
+      UniformQuantizedPerAxisType uniformType)
+      : scales(uniformType.getScales()),
+        zeroPoints(uniformType.getZeroPoints()),
+        clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
+        clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
+        storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
+        isSigned(uniformType.isSigned()),
+        quantizationDim(uniformType.getQuantizedDimension()) {
+    assert(uniformType.getExpressedType().isa<FloatType>());
+    assert(uniformType.getStorageType().isa<IntegerType>());
+    assert(scales.size() == zeroPoints.size());
+  }
+
+  /// Quantize an Attribute by the quantization parameters. Return nullptr if
+  /// the conversion fails or the input array isn't an ElementsAttr.
+  ElementsAttr convert(Attribute realValue);
+
+private:
+  /// Quantize an DenseFPElementsAttr by the quantization parameters.
+  DenseElementsAttr convert(DenseFPElementsAttr attr);
+
+  /// Get a uniform converter for the index-th chunk along the quantizationDim.
+  /// All the elements in this chunk is quantized by the returned converter.
+  UniformQuantizedValueConverter getPerChunkConverter(int index) const {
+    UniformQuantizedValueConverter converter(scales[index], zeroPoints[index],
+                                             clampMin, clampMax,
+                                             storageBitWidth, isSigned);
+    return converter;
+  }
+
+  const ArrayRef<double> scales;
+  const ArrayRef<int64_t> zeroPoints;
+  const APFloat clampMin;
+  const APFloat clampMax;
+  const uint32_t storageBitWidth;
+  const bool isSigned;
+  int32_t quantizationDim;
+};
+
 } // namespace quant
 } // namespace mlir
 
index 4733e56..e7a1df9 100644 (file)
@@ -135,16 +135,24 @@ Attribute quantizeAttrUniform(Attribute realValue,
 /// On success, stores the converted type in outConvertedType.
 Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
                        Type &outConvertedType) {
-  // Hard-coded to just support UniformQuantizedType. This will need to
-  // be generalized when there is more than one.
-  auto uniformQuantizedType =
-      quantizedElementType.dyn_cast<UniformQuantizedType>();
-  if (!uniformQuantizedType) {
+  if (auto uniformQuantized =
+          quantizedElementType.dyn_cast<UniformQuantizedType>()) {
+    UniformQuantizedValueConverter converter(uniformQuantized);
+    return quantizeAttrUniform(realValue, uniformQuantized, converter,
+                               outConvertedType);
+
+  } else if (auto uniformQuantizedPerAxis =
+                 quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>()) {
+    UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
+    auto converted = converter.convert(realValue);
+    // TODO(fengliuai): why we need this outConvertedType? remove it?
+    if (converted) {
+      outConvertedType = converted.getType();
+    }
+    return converted;
+  } else {
     return nullptr;
   }
-  UniformQuantizedValueConverter converter(uniformQuantizedType);
-  return quantizeAttrUniform(realValue, uniformQuantizedType, converter,
-                             outConvertedType);
 }
 
 } // namespace quant
index aec45d4..34e767d 100644 (file)
@@ -17,6 +17,7 @@
 
 #include "mlir/Dialect/QuantOps/UniformSupport.h"
 #include "mlir/IR/StandardTypes.h"
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::quant;
@@ -70,3 +71,41 @@ Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
                            elementalType);
   }
 }
+
+ElementsAttr
+UniformQuantizedPerAxisValueConverter::convert(Attribute realValue) {
+  if (auto attr = realValue.dyn_cast<DenseFPElementsAttr>()) {
+    return convert(attr);
+  }
+  // TODO(fengliuai): handles sparse elements attribute
+  return nullptr;
+}
+
+DenseElementsAttr
+UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) {
+  // Creates the converter for each chunk. Normally the size of the
+  // quantization dim is 3, so we can cache all the converters.
+  ShapedType type = attr.getType();
+  size_t dimSize = type.getDimSize(quantizationDim);
+  if (dimSize != scales.size()) {
+    return {};
+  }
+  SmallVector<UniformQuantizedValueConverter, 4> converters;
+  converters.reserve(dimSize);
+  for (int i = 0, e = dimSize; i != e; ++i) {
+    converters.push_back(getPerChunkConverter(i));
+  }
+
+  // Scan the elements of the dense elements attributes and quantize them by
+  // using the right quantization parameters.
+  int64_t flattenIndex = 0;
+  auto shape = type.getShape();
+  int64_t chunkSize =
+      std::accumulate(std::next(shape.begin(), quantizationDim + 1),
+                      shape.end(), 1, std::multiplies<int64_t>());
+  Type newElementType = IntegerType::get(storageBitWidth, attr.getContext());
+  return attr.mapValues(newElementType, [&](const APFloat &old) {
+    int chunkIndex = (flattenIndex++) / chunkSize;
+    return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
+  });
+}
index 87619df..36b6cb1 100644 (file)
@@ -171,3 +171,23 @@ func @zero_tensors_to_zero_points() -> (tensor<7xf32>, tensor<7xf32>, tensor<7xf
   return %2, %4, %6, %8 : tensor<7xf32>, tensor<7xf32>, tensor<7xf32>, tensor<7xf32>
 }
 
+// -----
+// Verifies per-axis quantization results for dense.
+// CHECK-LABEL: per_axis_dense_quantization
+func @per_axis_dense_quantization() -> (tensor<2x3xf32>, tensor<2x3xf32>) {
+
+// CHECK-NEXT: %[[cst:.*]] = constant dense<{{\[}}[-128, 64, 127], [0, 1, 2]]> : tensor<2x3xi8>
+// CHECK-NEXT: %[[cst0:.*]] = constant dense<{{\[}}[-128, 0, 1], [127, 1, 3]]> : tensor<2x3xi8>
+// CHECK: "quant.scast"(%[[cst]]) : (tensor<2x3xi8>) -> tensor<2x3x!quant.uniform<i8:f32:0, {7.812500e-03:128,1.000000e+00}>>
+// CHECK: "quant.scast"(%cst_0) : (tensor<2x3xi8>) -> tensor<2x3x!quant.uniform<i8:f32:1, {7.812500e-03:128,1.000000e+00,1.000000e+00:1}>>
+
+  %cst = constant dense<[[-2.0, -0.5, 0.0], [0.0, 1.0, 2.0]]> : tensor<2x3xf32>
+  %1 = "quant.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform<i8:f32:0, {7.812500e-03:128, 1.0}>>
+  %2 = "quant.dcast"(%1) : (tensor<2x3x!quant.uniform<i8:f32:0, {7.812500e-03:128, 1.0}>>) -> (tensor<2x3xf32>)
+
+  %cst0 = constant dense<[[-2.0, -0.5, 0.0], [0.0, 1.0, 2.0]]> : tensor<2x3xf32>
+  %3 = "quant.qcast"(%cst0) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform<i8:f32:1, {7.812500e-03:128, 1.0, 1.0:1}>>
+  %4 = "quant.dcast"(%3) : (tensor<2x3x!quant.uniform<i8:f32:1, {7.812500e-03:128, 1.0, 1.0:1}>>) -> (tensor<2x3xf32>)
+
+  return %2, %4 : tensor<2x3xf32>, tensor<2x3xf32>
+}