Add `axis` attribute to the quant.stats op
authorFeng Liu <fengliuai@google.com>
Fri, 4 Oct 2019 03:28:40 +0000 (20:28 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 4 Oct 2019 03:29:08 +0000 (20:29 -0700)
The first dim length of the axisStats attribute should equals to the slice size
of the input argument when splitted by the axis dimension.

PiperOrigin-RevId: 272798042

mlir/include/mlir/Dialect/QuantOps/QuantOps.td
mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
mlir/test/Dialect/QuantOps/parse-ops-invalid.mlir
mlir/test/Dialect/QuantOps/parse-ops.mlir

index d95b452..761f6ce 100644 (file)
@@ -197,18 +197,23 @@ def quant_StatisticsOp : quant_Op<"stats", [SameOperandsAndResultType]> {
     Currently, only dim=2 is supported, which is interpreted as [min, max].
 
     `layerStats` must be a rank 1 tensor: [2]
-    `axisStats` must be a rank 2 tensor: [N, 2], where N=the rank of `arg`.
+    `axisStats` must be a rank 2 tensor: [N, 2], where N=the slice size
+      splitted by the `axis` dimension. For example:
+      <?x?x3x2>, axis=3 => N=2
+      <?x?x3x2>, axis=2 => N=6
   }];
 
   let arguments = (ins
     quant_RealValueType:$arg,
     ElementsAttr:$layerStats,
-    OptionalAttr<ElementsAttr>:$axisStats);
+    OptionalAttr<ElementsAttr>:$axisStats,
+    OptionalAttr<I64Attr>:$axis);
   let results = (outs quant_RealValueType);
 
   let verifier = [{
     auto tensorArg = arg()->getType().dyn_cast<TensorType>();
-    auto argRank = tensorArg ? tensorArg.getRank() : 0;
+    if (!tensorArg) return emitOpError("arg needs to be tensor type.");
+
     // Verify layerStats attribute.
     {
       auto layerStatsType = layerStats().getType();
@@ -222,15 +227,21 @@ def quant_StatisticsOp : quant_Op<"stats", [SameOperandsAndResultType]> {
     }
     // Verify axisStats (optional) attribute.
     if (axisStats()) {
+      if (!axis()) return emitOpError("axis must be specified for axisStats");
+
+      auto shape = tensorArg.getShape();
+      auto argSliceSize = std::accumulate(std::next(shape.begin(),
+        axis()->getSExtValue()), shape.end(), 1, std::multiplies<int64_t>());
+
       auto axisStatsType = axisStats()->getType();
       if (!axisStatsType.getElementType().isa<FloatType>()) {
         return emitOpError("axisStats must have a floating point element type");
       }
       if (axisStatsType.getRank() != 2 ||
           axisStatsType.getDimSize(1) != 2 ||
-          axisStatsType.getDimSize(0) != argRank) {
+          axisStatsType.getDimSize(0) != argSliceSize) {
         return emitOpError("axisStats must have shape [N,2] "
-                           "where N = the argument rank");
+                           "where N = the slice size defined by the axis dim");
       }
     }
     return success();
index 3bd49d4..b618ac0 100644 (file)
@@ -26,6 +26,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/MathExtras.h"
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::quant;
index 696c1e2..a82a288 100644 (file)
@@ -82,8 +82,8 @@ void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
     APFloat maxValue(1.0f);
     ElementsAttr layerStats = DenseFPElementsAttr::get(
         b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
-    auto statsOp =
-        b.create<StatisticsOp>(func.getLoc(), arg, layerStats, nullptr);
+    auto statsOp = b.create<StatisticsOp>(func.getLoc(), arg, layerStats,
+                                          nullptr, nullptr);
     arg->replaceAllUsesWith(statsOp);
 
     // StatsOp contained a use to 'arg' so make sure to reset it after replacing
@@ -109,7 +109,7 @@ void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
     ElementsAttr layerStats = DenseFPElementsAttr::get(
         b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
     auto statsOp = b.create<StatisticsOp>(op->getLoc(), op->getResult(0),
-                                          layerStats, nullptr);
+                                          layerStats, nullptr, nullptr);
     originalResult->replaceAllUsesWith(statsOp);
 
     // StatsOp contained a use to 'op' so make sure to reset it after replacing
index 7a9b96b..272c530 100644 (file)
@@ -40,15 +40,15 @@ func @invalidStatisticsMismatchedAxisType(%arg0: tensor<8x4x3xf32>) -> tensor<8x
       [-1, 1],
       [-8, 8],
       [-1, 0]
-    ]> : tensor<3x2xi8>
+    ]> : tensor<3x2xi8>, axis = 3 : i64
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
-func @invalidStatisticsMismatchedAxisRank(%arg0: tensor<8x4x3xf32>) ->
+func @invalidStatisticsMismatchedAxisSize(%arg0: tensor<8x4x3xf32>) ->
     tensor<8x4x3xf32> {
-  // expected-error@+1 {{axisStats must have shape [N,2] where N = the argument rank}}
+  // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}}
   %0 = "quant.stats"(%arg0) {
     layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
     axisStats = dense<[
@@ -56,7 +56,7 @@ func @invalidStatisticsMismatchedAxisRank(%arg0: tensor<8x4x3xf32>) ->
       [-8.0, 8.0],
       [-0.5, 0.5],
       [-2.0, 3.5]
-    ]> : tensor<4x2xf32>
+    ]> : tensor<4x2xf32>, axis = 3 : i64
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
@@ -64,14 +64,30 @@ func @invalidStatisticsMismatchedAxisRank(%arg0: tensor<8x4x3xf32>) ->
 // -----
 func @invalidStatisticsMismatchedAxisShape(%arg0: tensor<8x4x3xf32>) ->
     tensor<8x4x3xf32> {
-  // expected-error@+1 {{axisStats must have shape [N,2] where N = the argument rank}}
+  // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}}
   %0 = "quant.stats"(%arg0) {
     layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
     axisStats = dense<[
       [-1.0, 1.0, 1.0],
       [-8.0, 8.0, 1.0],
       [-0.5, 0.5, 1.0]
-    ]> : tensor<3x3xf32>
+    ]> : tensor<3x3xf32>, axis = 3 : i64
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
+
+// -----
+func @axisIsRequiredForAxisStats(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+  // expected-error@+1 {{axis must be specified for axisStats}}
+  %1 = "quant.stats"(%arg0) {
+    layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
+    axisStats = dense<[
+      [-1.0, 1.0],
+      [-8.0, 8.0],
+      [-0.5, 0.5]
+    ]> : tensor<3x2xf32>
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %1 : tensor<8x4x3xf32>
+}
+
+// -----
index 7d6d1ab..bdcd751 100644 (file)
@@ -50,7 +50,7 @@ func @validStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
       [-1.0, 1.0],
       [-8.0, 8.0],
       [-0.5, 0.5]
-    ]> : tensor<3x2xf32>
+    ]> : tensor<3x2xf32>, axis = 2 : i64
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %1 : tensor<8x4x3xf32>
 }