From a2e08eb384b4d045fade6ef6f8bf6049924ef21b Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 16 Apr 2019 18:36:24 -0700 Subject: [PATCH] Bring naming of some quant ops in alignment with docs and introduce a few necessary additional ops (stats_ref, stats, coupled_ref). -- PiperOrigin-RevId: 243919195 --- mlir/g3doc/Quantization.md | 17 +-- mlir/include/mlir/Quantization/QuantOps.td | 148 +++++++++++++++++---- mlir/lib/Quantization/Transforms/ConvertConst.cpp | 4 +- .../Quantization/Transforms/ConvertSimQuant.cpp | 8 +- mlir/test/Quantization/convert-const.mlir | 44 +++--- mlir/test/Quantization/convert-fakequant.mlir | 20 +-- mlir/test/Quantization/parse-ops-invalid.mlir | 77 +++++++++++ mlir/test/Quantization/parse-ops.mlir | 49 +++++++ 8 files changed, 294 insertions(+), 73 deletions(-) create mode 100644 mlir/test/Quantization/parse-ops-invalid.mlir create mode 100644 mlir/test/Quantization/parse-ops.mlir diff --git a/mlir/g3doc/Quantization.md b/mlir/g3doc/Quantization.md index eeb202f..531f6b9 100644 --- a/mlir/g3doc/Quantization.md +++ b/mlir/g3doc/Quantization.md @@ -226,14 +226,15 @@ TODO : Flesh this section out. ### Instrumentation and constraint ops -TODO : These ops are not defined yet - -* instrument_stats : Assigns a unique id and signals that statistics should be - collected by the runtime when execution passes through this op. -* constrain_uniform : Constrains that for uniform quantization, the solver - should choose a type with certain characteristics such as the number of - fixed-point values, underlying storage type, or whether to constrain to - power of two scales. +* const_fake_quant : Emulates the logic of the historic TensorFlow + fake_quant_with_min_max_args op. +* stats_ref : Declares that statistics should be gathered at this point with a + unique key and made available to future passes of the solver. +* stats : Declares inline statistics (per layer and per axis) for the point in + the computation. stats_ref ops are generally converted to stats ops once + trial runs have been performed. +* coupled_ref : Declares points in the computation to be coupled from a type + inference perspective based on a unique key. ## Integration with simulated quantization at training time diff --git a/mlir/include/mlir/Quantization/QuantOps.td b/mlir/include/mlir/Quantization/QuantOps.td index 09a1a65..13fa6ca 100644 --- a/mlir/include/mlir/Quantization/QuantOps.td +++ b/mlir/include/mlir/Quantization/QuantOps.td @@ -36,45 +36,47 @@ class quant_Op traits> : Op; //===----------------------------------------------------------------------===// -// Quantization barriers +// Quantization casts //===----------------------------------------------------------------------===// -class quant_BarrierOp traits> : - quant_Op, Arguments<(ins quant_RealValueType:$arg)>, - Results<(outs quant_RealValueType)>; - -// A QuantizeBarrier (qbarrier) represents a potential type shift from a -// quantizable type to a quantized type. +// A QuantizeCast (qcast) represents a potential type shift from a quantizable +// type to a quantized type. // -// At runtime, a qbarrier will apply the transformation expressed by its +// At runtime, a qcast will apply the transformation expressed by its // operand and result type. For flexibility during transformation, it is also -// possible to have a qbarrier that performs no transformation (both its +// possible to have a qcast that performs no transformation (both its // operand and result type are quantizable). // -// A qbarrier will typically originate from either: +// A qcast will typically originate from either: // a) An expressed or implied constraint in the source dialect which signals // that a certain level of quantization is possible or required. // b) An inference made by a quantization algorithm indicating that a // quantized representation may be acceptable. // // Especially early in transformation, it is common to have pairs of -// qbarrier/dbarrier at points where a transition to a quantized type is -// required. In addition, it is also common to have an identity qbarrier +// qcast/dcast at points where a transition to a quantized type is +// required. In addition, it is also common to have an identity qcast // (where the operand and result type are not quantized) at all points where // it is legal to use a quantized representation (but is not known to be // acceptable). -def quant_QuantizeBarrierOp : quant_BarrierOp<"qbarrier", [NoSideEffect]>; +def quant_QuantizeCastOp : quant_Op<"qcast", [NoSideEffect]> { + let arguments = (ins quant_RealValueType:$arg); + let results = (outs quant_RealValueType); +} -// A DequantizeBarrier (dbarrier) represents the inverse of a qbarrier, +// A DequantizeCast op (dcast) represents the inverse of a qcast, // converting back from a quantized to quantizable (expressed) type. // -// Like qbarriers, a dbarrier is allowed to have both its operand and result +// Like qcasts, a dcast is allowed to have both its operand and result // as non quantized types. This facilitates transformations and marks edges // where the computation must be carried out in the expressed type. // -// Especially early in transformation, it is common to have dbarriers on +// Especially early in transformation, it is common to have dcasts on // all operands to ops that must operate with the expressed type (typically // math ops prior to lowering to target-specific, quantized kernels). -def quant_DequantizeBarrierOp : quant_BarrierOp<"dbarrier", [NoSideEffect]>; +def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> { + let arguments = (ins quant_RealValueType:$arg); + let results = (outs quant_RealValueType); +} // A StorageCast (scast) represents a cast from or to a type based on the // storage type and a type based on a corresponding quantized type. @@ -87,13 +89,13 @@ def quant_DequantizeBarrierOp : quant_BarrierOp<"dbarrier", [NoSideEffect]>; // i8 -> !quant<"uniform[i8:f32]{1.0}"> // tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> // vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">> -def quant_StorageCastOp : - quant_Op<"scast", [NoSideEffect]>, - Arguments<(ins quant_RealOrStorageValueType:$arg)>, - Results<(outs quant_RealOrStorageValueType)>; +def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> { + let arguments = (ins quant_RealOrStorageValueType:$arg); + let results = (outs quant_RealOrStorageValueType); +} //===----------------------------------------------------------------------===// -// Training integration ops +// Training integration and instrumentation ops //===----------------------------------------------------------------------===// def quant_ConstFakeQuant : quant_Op<"const_fake_quant", @@ -102,11 +104,11 @@ def quant_ConstFakeQuant : quant_Op<"const_fake_quant", "Simulates the effect of uniform quantization with const range."; let description = [{ -Given a const min, max, num_bits and narrow_range attribute, applies the same -uniform quantization simulation as is done by the TensorFlow -fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility -method and the quant-convert-simulated-quantization pass for futher details. -}]; + Given a const min, max, num_bits and narrow_range attribute, applies the + same uniform quantization simulation as is done by the TensorFlow + fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility + method and the quant-convert-simulated-quantization pass for futher details. + }]; let arguments = (ins F32Tensor:$inputs, @@ -123,4 +125,96 @@ method and the quant-convert-simulated-quantization pass for futher details. ); } +def quant_StatisticsRefOp : quant_Op<"stats_ref", []> { + let summary = + "Indicates that statistics are resolved by reference."; + + let description = [{ + This op acts as an identity that, when encountered at runtime, should result + in statistics being collected about about the value of its operand/result. + Such statistics will be stored with the provided key, allowing this node + to later be converted to a 'stats' op if statistics with that key have been + encountered. + }]; + + let arguments = (ins + quant_RealValueType:$arg, + StrAttr:$statsKey + ); + let results = (outs quant_RealValueType); +} + +def quant_StatisticsOp : quant_Op<"stats", []> { + let summary = + "Identity op which associates statistics with the value."; + + let description = [{ + Associates statistics about the runtime ranges of values observed for + evaluations of this node. + + Statistics about the entire type are reported in the 'layerStats' attribute + and those for each axis, in the (optional) `axisStats` attribute. The + interpretation of each is determined by the last dimension of its shape. + Currently, only dim=2 is supported, which is interpreted as [min, max]. + + `layerStats` must be a rank 1 tensor: [2] + `axisStats` must be a rank 2 tensor: [N, 2], where N=the rank of `arg`. + }]; + + let arguments = (ins + quant_RealValueType:$arg, + ElementsAttr:$layerStats, + OptionalAttr:$axisStats); + let results = (outs quant_RealValueType); + + let verifier = [{ + auto tensorArg = arg()->getType().dyn_cast(); + auto argRank = tensorArg ? tensorArg.getRank() : 0; + // Verify layerStats attribute. + { + auto layerStatsType = layerStats().getType(); + if (!layerStatsType.getElementType().isa()) { + return emitOpError( + "layerStats must have a floating point element type"); + } + if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { + return emitOpError("layerStats must have shape [2]"); + } + } + // Verify axisStats (optional) attribute. + if (axisStats()) { + auto axisStatsType = axisStats()->getType(); + if (!axisStatsType.getElementType().isa()) { + return emitOpError("axisStats must have a floating point element type"); + } + if (axisStatsType.getRank() != 2 || + axisStatsType.getDimSize(1) != 2 || + axisStatsType.getDimSize(0) != argRank) { + return emitOpError("axisStats must have shape [N,2] " + "where N = the argument rank"); + } + } + return success(); + }]; +} + +def quant_CoupledRefOp : quant_Op<"coupled_ref", []> { + let summary = + "Indicates that one point of the computation is coupled to another."; + + let description = [{ + Ordinarily, relationships between ops for the purposes of determining + compatible quantized types is explicit based on the use-def chain. However, + in some situations, a use may be separated from its def by arbitrary + external connections. In such a case, during analysis, all coupled_ref + nodes in a module which share a coupledKey will be considered to be + directly connected as via an identity op for the purpose of type inference. + }]; + + let arguments = (ins + quant_RealValueType:$arg, + StrAttr:$coupledKey); + let results = (outs quant_RealValueType); +} + #endif // QUANT_OPS diff --git a/mlir/lib/Quantization/Transforms/ConvertConst.cpp b/mlir/lib/Quantization/Transforms/ConvertConst.cpp index ec947f2..f0501c2 100644 --- a/mlir/lib/Quantization/Transforms/ConvertConst.cpp +++ b/mlir/lib/Quantization/Transforms/ConvertConst.cpp @@ -44,7 +44,7 @@ public: }; QuantizedConstRewrite(MLIRContext *context) - : RewritePattern(QuantizeBarrierOp::getOperationName(), 1, context) {} + : RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {} PatternMatchResult match(Operation *op) const override; void rewrite(Operation *op, std::unique_ptr baseState, @@ -59,7 +59,7 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const { State state; // Is the operand a constant? - auto qbarrier = op->cast(); + auto qbarrier = op->cast(); if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) { return matchFailure(); } diff --git a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp index 2d007e2b..7137424 100644 --- a/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Quantization/Transforms/ConvertSimQuant.cpp @@ -82,10 +82,10 @@ public: // TODO: Map to a qbarrier with an attribute like [Forced] to signal that // this is a forced/hard-coded constraint. - auto qbarrier = rewriter.create( - op->getLoc(), quantizedType, fqOp.inputs()); - rewriter.replaceOpWithNewOp(op, converter.inputType, - qbarrier.getResult()); + auto qbarrier = rewriter.create(op->getLoc(), quantizedType, + fqOp.inputs()); + rewriter.replaceOpWithNewOp(op, converter.inputType, + qbarrier.getResult()); return false; } diff --git a/mlir/test/Quantization/convert-const.mlir b/mlir/test/Quantization/convert-const.mlir index d0ac5d7..21aa66d 100644 --- a/mlir/test/Quantization/convert-const.mlir +++ b/mlir/test/Quantization/convert-const.mlir @@ -14,8 +14,8 @@ func @constant_splat_tensor_u8_affine() -> tensor<4xf32> { // CHECK: %cst = constant splat, -64> : tensor<4xi8> // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">> %cst = constant splat, 0.5> : tensor<4xf32> - %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">> - %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) -> (tensor<4xf32>) + %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">> + %2 = "quant.dcast"(%1) : (tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) -> (tensor<4xf32>) return %2 : tensor<4xf32> } @@ -26,8 +26,8 @@ func @constant_splat_tensor_i8_affine() -> tensor<4xf32> { // CHECK: %cst = constant splat, 63> : tensor<4xi8> // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">> %cst = constant splat, 0.5> : tensor<4xf32> - %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">> - %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>) -> (tensor<4xf32>) + %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">> + %2 = "quant.dcast"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>) -> (tensor<4xf32>) return %2 : tensor<4xf32> } @@ -38,8 +38,8 @@ func @const_splat_tensor_i8_fixedpoint() -> tensor<4xf32> { // CHECK: %cst = constant splat, 64> : tensor<4xi8> // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">> %cst = constant splat, 0.5> : tensor<4xf32> - %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">> - %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>) + %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">> + %2 = "quant.dcast"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>) return %2 : tensor<4xf32> } @@ -49,8 +49,8 @@ func @const_splat_tensor_i8_fixedpoint() -> tensor<4xf32> { func @const_splat_tensor_i8_fixedpoint_neg() -> tensor<4xf32> { // CHECK: %cst = constant splat, -64> : tensor<4xi8> %cst = constant splat, -0.5> : tensor<4xf32> - %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">> - %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>) + %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">> + %2 = "quant.dcast"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>) return %2 : tensor<4xf32> } @@ -60,8 +60,8 @@ func @const_splat_tensor_i8_fixedpoint_neg() -> tensor<4xf32> { func @const_dense_tensor_i8_fixedpoint() -> tensor<7xf32> { // CHECK: %cst = constant dense, [-128, -128, -64, 0, 64, 127, 127]> : tensor<7xi8> %cst = constant dense, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> - %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">> - %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7xf32>) + %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">> + %2 = "quant.dcast"(%1) : (tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7xf32>) return %2 : tensor<7xf32> } @@ -74,8 +74,8 @@ func @const_sparse_tensor_i8_fixedpoint() -> tensor<7x2xf32> { %cst = constant sparse, [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7x2xf32> - %1 = "quant.qbarrier"(%cst) : (tensor<7x2xf32>) -> tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">> - %2 = "quant.dbarrier"(%1) : (tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7x2xf32>) + %1 = "quant.qcast"(%cst) : (tensor<7x2xf32>) -> tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">> + %2 = "quant.dcast"(%1) : (tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7x2xf32>) return %2 : tensor<7x2xf32> } @@ -86,8 +86,8 @@ func @const_primitive_float_i8_fixedpoint() -> f32 { // CHECK: %c64_i8 = constant 64 : i8 // CHECK-NEXT: %0 = "quant.scast"(%c64_i8) : (i8) -> !quant<"uniform[i8:f32]{7.812500e-03}"> %cst = constant 0.5 : f32 - %1 = "quant.qbarrier"(%cst) : (f32) -> !quant<"uniform[i8:f32]{7.812500e-03}"> - %2 = "quant.dbarrier"(%1) : (!quant<"uniform[i8:f32]{7.812500e-03}">) -> (f32) + %1 = "quant.qcast"(%cst) : (f32) -> !quant<"uniform[i8:f32]{7.812500e-03}"> + %2 = "quant.dcast"(%1) : (!quant<"uniform[i8:f32]{7.812500e-03}">) -> (f32) return %2 : f32 } @@ -98,8 +98,8 @@ func @const_dense_tensor_u4_affine() -> tensor<7xf32> { // NOTE: Unsigned quantities printed by MLIR as signed. // CHECK: %cst = constant dense, [0, 0, 4, -8, -4, -1, -1]> : tensor<7xi4> %cst = constant dense, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> - %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">> - %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">>) -> (tensor<7xf32>) + %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">> + %2 = "quant.dcast"(%1) : (tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">>) -> (tensor<7xf32>) return %2 : tensor<7xf32> } @@ -110,8 +110,8 @@ func @const_dense_tensor_i4_affine() -> tensor<7xf32> { // NOTE: Unsigned quantities printed by MLIR as signed. // CHECK: %cst = constant dense, [-8, -8, -5, -1, 3, 7, 7]> : tensor<7xi4> %cst = constant dense, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> - %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">> - %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">>) -> (tensor<7xf32>) + %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">> + %2 = "quant.dcast"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">>) -> (tensor<7xf32>) return %2 : tensor<7xf32> } @@ -121,8 +121,8 @@ func @const_dense_tensor_i4_affine() -> tensor<7xf32> { func @const_dense_tensor_i4_fixedpoint() -> tensor<7xf32> { // CHECK: %cst = constant dense, [-8, -8, -4, 0, 4, 7, 7]> : tensor<7xi4> %cst = constant dense, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> - %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">> - %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">>) -> (tensor<7xf32>) + %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">> + %2 = "quant.dcast"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">>) -> (tensor<7xf32>) return %2 : tensor<7xf32> } @@ -134,7 +134,7 @@ func @const_dense_tensor_i4_fixedpoint() -> tensor<7xf32> { func @const_custom_storage_range_i8_fixedpoint() -> tensor<7xf32> { // CHECK: %cst = constant dense, [-100, -100, -64, 0, 64, 100, 100]> : tensor<7xi8> %cst = constant dense, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> - %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">> - %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">>) -> (tensor<7xf32>) + %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">> + %2 = "quant.dcast"(%1) : (tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">>) -> (tensor<7xf32>) return %2 : tensor<7xf32> } diff --git a/mlir/test/Quantization/convert-fakequant.mlir b/mlir/test/Quantization/convert-fakequant.mlir index 8acabd5..38b2a6c 100644 --- a/mlir/test/Quantization/convert-fakequant.mlir +++ b/mlir/test/Quantization/convert-fakequant.mlir @@ -5,9 +5,9 @@ // CHECK-LABEL: fakeQuantArgs_Quint8_0_1 func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { ^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>) + // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">> - // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>) + // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>) // CHECK-SAME: -> tensor<8x4x3xf32> %0 = "quant.const_fake_quant"(%arg0) { min: 0.0 : f32, max: 1.0 : f32, num_bits: 8 @@ -20,9 +20,9 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // CHECK_LABEL: fakeQuantArgs_Quint8_NarrowRange func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { ^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>) + // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">> - // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>) + // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>) // CHECK-SAME: -> tensor<8x4x3xf32> %0 = "quant.const_fake_quant"(%arg0) { min: 0.0 : f32, max: 1.0 : f32, num_bits: 8, narrow_range: true @@ -35,9 +35,9 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // CHECK_LABEL: fakeQuantArgs_Quint8_SymmetricRange func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { ^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>) + // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">> - // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) + // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) // CHECK-SAME: -> tensor<8x4x3xf32> %0 = "quant.const_fake_quant"(%arg0) { min: -1.0 : f32, max: 0.9921875 : f32, num_bits: 8, narrow_range: false @@ -51,9 +51,9 @@ func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32 // CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { ^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>) + // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[i16:f32]{3.0517578125E-5}">> - // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.0517578125E-5}">>) + // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.0517578125E-5}">>) // CHECK-SAME: -> tensor<8x4x3xf32> %0 = "quant.const_fake_quant"(%arg0) { min: -1.0 : f32, max: 0.999969482 : f32, num_bits: 16 @@ -66,9 +66,9 @@ func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // CHECK-LABEL: fakeQuantArgs_UnrankedTensor func @fakeQuantArgs_UnrankedTensor(tensor) -> tensor { ^bb0(%arg0: tensor): - // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor) + // CHECK: %0 = "quant.qcast"(%arg0) : (tensor) // CHECK-SAME: -> tensor> - // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor>) + // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor>) // CHECK-SAME: -> tensor %0 = "quant.const_fake_quant"(%arg0) { min: 0.0 : f32, max: 1.0 : f32, num_bits: 8 diff --git a/mlir/test/Quantization/parse-ops-invalid.mlir b/mlir/test/Quantization/parse-ops-invalid.mlir new file mode 100644 index 0000000..d7c15c8 --- /dev/null +++ b/mlir/test/Quantization/parse-ops-invalid.mlir @@ -0,0 +1,77 @@ +// RUN: mlir-opt %s -split-input-file -verify + +// ----- +func @invalidStatisticsMismatchedLayerType(%arg0: tensor<8x4x3xf32>) -> + tensor<8x4x3xf32> { + // expected-error@+1 {{layerStats must have a floating point element type}} + %0 = "quant.stats"(%arg0) { + layerStats: dense, [-1, 1]> + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +func @invalidStatisticsMismatchedLayerRank(%arg0: tensor<8x4x3xf32>) -> + tensor<8x4x3xf32> { + // expected-error@+1 {{layerStats must have shape [2]}} + %0 = "quant.stats"(%arg0) { + layerStats: dense, [[-1.0, 1.0]]> + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +func @invalidStatisticsMismatchedLayerShape(%arg0: tensor<8x4x3xf32>) -> + tensor<8x4x3xf32> { + // expected-error@+1 {{layerStats must have shape [2]}} + %0 = "quant.stats"(%arg0) { + layerStats: dense, [-1.0, 1.0, 2.0]> + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +// CHECK-LABEL: validStatistics +func @invalidStatisticsMismatchedAxisType(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { + // expected-error@+1 {{axisStats must have a floating point element type}} + %0 = "quant.stats"(%0) { + layerStats: dense, [-1.0, 1.0]>, + axisStats: dense, [ + [-1, 1], + [-8, 8], + [-1, 0] + ]> + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +func @invalidStatisticsMismatchedAxisRank(%arg0: tensor<8x4x3xf32>) -> + tensor<8x4x3xf32> { + // expected-error@+1 {{axisStats must have shape [N,2] where N = the argument rank}} + %0 = "quant.stats"(%arg0) { + layerStats: dense, [-1.0, 1.0]>, + axisStats: dense, [ + [-1.0, 1.0], + [-8.0, 8.0], + [-0.5, 0.5], + [-2.0, 3.5] + ]> + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +func @invalidStatisticsMismatchedAxisShape(%arg0: tensor<8x4x3xf32>) -> + tensor<8x4x3xf32> { + // expected-error@+1 {{axisStats must have shape [N,2] where N = the argument rank}} + %0 = "quant.stats"(%arg0) { + layerStats: dense, [-1.0, 1.0]>, + axisStats: dense, [ + [-1.0, 1.0, 1.0], + [-8.0, 8.0, 1.0], + [-0.5, 0.5, 1.0] + ]> + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} diff --git a/mlir/test/Quantization/parse-ops.mlir b/mlir/test/Quantization/parse-ops.mlir new file mode 100644 index 0000000..87534fa --- /dev/null +++ b/mlir/test/Quantization/parse-ops.mlir @@ -0,0 +1,49 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// ----- +// CHECK-LABEL: validConstFakeQuant +func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { + %0 = "quant.const_fake_quant"(%arg0) { + min: 0.0 : f32, max: 1.0 : f32, num_bits: 8, narrow_range: true + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + %1 = "quant.const_fake_quant"(%0) { + min: 0.0 : f32, max: 1.0 : f32, num_bits: 8, narrow_range: false + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + %2 = "quant.const_fake_quant"(%1) { + min: 0.0 : f32, max: 1.0 : f32, num_bits: 8 + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %2 : tensor<8x4x3xf32> +} + +// ----- +// CHECK-LABEL: validStatisticsRef +func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { + %0 = "quant.stats_ref"(%arg0) { statsKey: "foobar" } : + (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} + +// ----- +// CHECK-LABEL: validStatistics +func @validStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { + %0 = "quant.stats"(%arg0) { + layerStats: dense, [-1.0, 1.0]> + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + %1 = "quant.stats"(%0) { + layerStats: dense, [-1.0, 1.0]>, + axisStats: dense, [ + [-1.0, 1.0], + [-8.0, 8.0], + [-0.5, 0.5] + ]> + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %1 : tensor<8x4x3xf32> +} + +// ----- +// CHECK-LABEL: validCoupledRef +func @validCoupledRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { + %0 = "quant.coupled_ref"(%arg0) { coupledKey: "foobar" } : + (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +} -- 2.7.4