Currently, only dim=2 is supported, which is interpreted as [min, max].
`layerStats` must be a rank 1 tensor: [2]
- `axisStats` must be a rank 2 tensor: [N, 2], where N=the rank of `arg`.
+ `axisStats` must be a rank 2 tensor: [N, 2], where N is the slice size
+ obtained by splitting `arg` at the `axis` dimension, i.e. the product of
+ the dimension sizes from `axis` onward. For example:
+ <?x?x3x2>, axis=3 => N=2
+ <?x?x3x2>, axis=2 => N=6
}];
let arguments = (ins
quant_RealValueType:$arg,
ElementsAttr:$layerStats,
- OptionalAttr<ElementsAttr>:$axisStats);
+ OptionalAttr<ElementsAttr>:$axisStats,
+ OptionalAttr<I64Attr>:$axis);
let results = (outs quant_RealValueType);
let verifier = [{
auto tensorArg = arg()->getType().dyn_cast<TensorType>();
- auto argRank = tensorArg ? tensorArg.getRank() : 0;
+ if (!tensorArg) return emitOpError("arg needs to be tensor type.");
+
// Verify layerStats attribute.
{
auto layerStatsType = layerStats().getType();
}
// Verify axisStats (optional) attribute.
if (axisStats()) {
+ if (!axis()) return emitOpError("axis must be specified for axisStats");
+
+ // N is the product of the dimension sizes from `axis` to the end of the
+ // shape. `axis` may equal the rank, which yields an empty product (N = 1);
+ // any value outside [0, rank] would move the iterator out of the shape.
+ // NOTE(review): a dynamic dimension inside the slice would make this
+ // product meaningless; confirm whether that case needs its own check.
+ auto shape = tensorArg.getShape();
+ int64_t axisValue = axis()->getSExtValue();
+ if (axisValue < 0 || axisValue > static_cast<int64_t>(shape.size()))
+ return emitOpError("axis must be in the range [0, rank of arg]");
+ // Seed with int64_t so the running product is not narrowed to int.
+ auto argSliceSize = std::accumulate(std::next(shape.begin(), axisValue),
+ shape.end(), int64_t{1}, std::multiplies<int64_t>());
+
auto axisStatsType = axisStats()->getType();
if (!axisStatsType.getElementType().isa<FloatType>()) {
return emitOpError("axisStats must have a floating point element type");
}
if (axisStatsType.getRank() != 2 ||
axisStatsType.getDimSize(1) != 2 ||
- axisStatsType.getDimSize(0) != argRank) {
+ axisStatsType.getDimSize(0) != argSliceSize) {
return emitOpError("axisStats must have shape [N,2] "
- "where N = the argument rank");
+ "where N = the slice size defined by the axis dim");
}
}
return success();
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"
+#include <numeric>
using namespace mlir;
using namespace mlir::quant;
APFloat maxValue(1.0f);
ElementsAttr layerStats = DenseFPElementsAttr::get(
b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
- auto statsOp =
- b.create<StatisticsOp>(func.getLoc(), arg, layerStats, nullptr);
+ auto statsOp = b.create<StatisticsOp>(func.getLoc(), arg, layerStats,
+ nullptr, nullptr);
arg->replaceAllUsesWith(statsOp);
// StatsOp contained a use to 'arg' so make sure to reset it after replacing
ElementsAttr layerStats = DenseFPElementsAttr::get(
b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
auto statsOp = b.create<StatisticsOp>(op->getLoc(), op->getResult(0),
- layerStats, nullptr);
+ layerStats, nullptr, nullptr);
originalResult->replaceAllUsesWith(statsOp);
// StatsOp contained a use to 'op' so make sure to reset it after replacing
[-1, 1],
[-8, 8],
[-1, 0]
- ]> : tensor<3x2xi8>
+ ]> : tensor<3x2xi8>, axis = 3 : i64
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
-func @invalidStatisticsMismatchedAxisRank(%arg0: tensor<8x4x3xf32>) ->
+func @invalidStatisticsMismatchedAxisSize(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
- // expected-error@+1 {{axisStats must have shape [N,2] where N = the argument rank}}
+ // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}}
%0 = "quant.stats"(%arg0) {
layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
axisStats = dense<[
[-8.0, 8.0],
[-0.5, 0.5],
[-2.0, 3.5]
- ]> : tensor<4x2xf32>
+ ]> : tensor<4x2xf32>, axis = 3 : i64
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
// -----
func @invalidStatisticsMismatchedAxisShape(%arg0: tensor<8x4x3xf32>) ->
tensor<8x4x3xf32> {
- // expected-error@+1 {{axisStats must have shape [N,2] where N = the argument rank}}
+ // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}}
%0 = "quant.stats"(%arg0) {
layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
axisStats = dense<[
[-1.0, 1.0, 1.0],
[-8.0, 8.0, 1.0],
[-0.5, 0.5, 1.0]
- ]> : tensor<3x3xf32>
+ ]> : tensor<3x3xf32>, axis = 3 : i64
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}
+
+// -----
+func @axisIsRequiredForAxisStats(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+ // expected-error@+1 {{axis must be specified for axisStats}}
+ %1 = "quant.stats"(%arg0) {
+ layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>,
+ axisStats = dense<[
+ [-1.0, 1.0],
+ [-8.0, 8.0],
+ [-0.5, 0.5]
+ ]> : tensor<3x2xf32>
+ } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+ return %1 : tensor<8x4x3xf32>
+}
+
+// -----
[-1.0, 1.0],
[-8.0, 8.0],
[-0.5, 0.5]
- ]> : tensor<3x2xf32>
+ ]> : tensor<3x2xf32>, axis = 2 : i64
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %1 : tensor<8x4x3xf32>
}