void runOnFunction() override;
};
+// Function pass that lowers uniform-quantized cast ops (quant.dcast) to
+// primitive math; see the runOnFunction definition below.
+struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
+ void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Dequantize
+//===----------------------------------------------------------------------===//
+
+/// Emits the op sequence that dequantizes a per-layer uniform-quantized
+/// 'input' to its expressed (floating point) type:
+///   real = (storage - zeroPoint) * scale
+/// On unsupported inputs, emits a warning at 'loc' and returns nullptr so the
+/// caller can fail the rewrite.
+static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
+                                            UniformQuantizedType elementType,
+                                            PatternRewriter &rewriter) {
+  // Pre-conditions.
+  if (!elementType.isSigned()) {
+    // TODO: Support unsigned storage type.
+    // This branch is taken for *unsigned* element types, so the diagnostic
+    // says "unsigned" (the previous message incorrectly said "signed").
+    return rewriter.getContext()->emitDiagnostic(
+               loc, "unimplemented: dequantize unsigned uniform",
+               MLIRContext::DiagnosticKind::Warning),
+           nullptr;
+  }
+
+  Type storageType = elementType.castToStorageType(input->getType());
+  Type realType = elementType.castToExpressedType(input->getType());
+  // Check the casts before using storageType below (previously the asserts
+  // ran after storageType had already been passed to castElementType).
+  assert(storageType && "cannot cast to storage type");
+  assert(realType && "cannot cast to expressed type");
+  // Promote to a 32-bit integer form for the integer arithmetic below.
+  Type intermediateType =
+      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
+
+  // Cast to storage type.
+  input = rewriter.create<StorageCastOp>(loc, storageType, input);
+
+  // Promote to intermediate type.
+  input = rewriter.create<ConvertISOp>(loc, intermediateType, input);
+
+  // Apply zero-point offset (skipped for fixed-point, i.e. zeroPoint == 0).
+  if (elementType.getZeroPoint() != 0) {
+    Value *negZeroPointConst = rewriter.create<ConstantOp>(
+        loc, broadcastScalarConstIntValue(intermediateType,
+                                          -elementType.getZeroPoint()));
+    input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
+  }
+
+  // Convert to float.
+  input = rewriter.create<ConvertISToFOp>(loc, realType, input);
+
+  // Mul by scale.
+  Value *scaleConst = rewriter.create<ConstantOp>(
+      loc, broadcastScalarConstFloatValue(realType,
+                                          APFloat(elementType.getScale())));
+  return rewriter.create<MulFOp>(loc, input, scaleConst);
+}
+
+/// Per-axis uniform dequantization is not implemented yet: emits a warning
+/// at 'loc' and returns nullptr so the caller can fail the rewrite.
+/// (A second, duplicate warning emission after the return was unreachable
+/// dead code and has been removed.)
+static Value *
+emitUniformPerAxisDequantize(Location loc, Value *input,
+                             UniformQuantizedPerAxisType elementType,
+                             PatternRewriter &rewriter) {
+  // TODO: Support per-axis dequantize.
+  return rewriter.getContext()->emitDiagnostic(
+             loc, "unimplemented: per-axis uniform dequantization",
+             MLIRContext::DiagnosticKind::Warning),
+         nullptr;
+}
+
+/// Dispatches dequantization emission based on the quantized element type of
+/// 'input': per-layer uniform and per-axis uniform types are routed to their
+/// respective emitters; any other (or non-quantized) type yields nullptr so
+/// callers can fail the match.
+static Value *emitDequantize(Location loc, Value *input,
+                             PatternRewriter &rewriter) {
+ Type inputType = input->getType();
+ QuantizedType qElementType =
+ QuantizedType::getQuantizedElementType(inputType);
+ // dyn_cast_or_null handles the case where inputType is not quantized at all.
+ if (auto uperLayerElementType =
+ qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
+ return emitUniformPerLayerDequantize(loc, input, uperLayerElementType,
+ rewriter);
+ } else if (auto uperAxisElementType =
+ qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
+ return emitUniformPerAxisDequantize(loc, input, uperAxisElementType,
+ rewriter);
+ } else {
+ return nullptr;
+ }
+}
+
+namespace {
+
+/// Rewrites quant.dcast (DequantizeCastOp) of a uniform-quantized value into
+/// primitive math ops by delegating to emitDequantize. Fails the match for
+/// non-quantized inputs, invalid casts, and unsupported element types.
+struct UniformDequantizePattern : public RewritePattern {
+  UniformDequantizePattern(MLIRContext *context)
+      : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    auto dcastOp = op->cast<DequantizeCastOp>();
+    Type inputType = dcastOp.arg()->getType();
+    Type outputType = dcastOp.getResult()->getType();
+
+    QuantizedType inputElementType =
+        QuantizedType::getQuantizedElementType(inputType);
+    if (!inputElementType) {
+      // Input is not a quantized type; nothing to lower. (Previously the
+      // null case fell straight into castToExpressedType.)
+      return matchFailure();
+    }
+    Type expressedOutputType = inputElementType.castToExpressedType(inputType);
+    if (expressedOutputType != outputType) {
+      // Not a valid uniform cast.
+      return matchFailure();
+    }
+
+    Value *dequantizedValue =
+        emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
+    if (!dequantizedValue) {
+      // Unsupported form (e.g. per-axis); a warning was already emitted.
+      return matchFailure();
+    }
+
+    rewriter.replaceOp(op, dequantizedValue);
+    return matchSuccess();
+  }
+};
+
} // end anonymous namespace
//===----------------------------------------------------------------------===//
} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// LowerUniformRealMath pass
+//===----------------------------------------------------------------------===//
+
void LowerUniformRealMathPass::runOnFunction() {
auto &fn = getFunction();
OwningRewritePatternList patterns;
return new LowerUniformRealMathPass();
}
-static PassRegistration<LowerUniformRealMathPass>
- pass("fxpmath-lower-uniform-real-math",
- "Lowers uniform-quantized real math ops to integer arithmetic.");
+// Registers the pass under the -fxpmath-lower-uniform-real-math flag
+// (renamed from the generic 'pass' to a descriptive identifier).
+static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
+ "fxpmath-lower-uniform-real-math",
+ "Lowers uniform-quantized real math ops to integer arithmetic.");
+
+//===----------------------------------------------------------------------===//
+// LowerUniformCasts pass
+//===----------------------------------------------------------------------===//
+
+/// Collects the uniform-cast lowering patterns and applies them greedily to
+/// the current function.
+void LowerUniformCastsPass::runOnFunction() {
+ auto &fn = getFunction();
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ patterns.push_back(llvm::make_unique<UniformDequantizePattern>(context));
+ applyPatternsGreedily(fn, std::move(patterns));
+}
+
+/// Factory used by pipeline builders; the caller takes ownership.
+FunctionPassBase *createLowerUniformCastsPass() {
+ return new LowerUniformCastsPass();
+}
+
+// Registers the pass under the -fxpmath-lower-uniform-casts flag.
+static PassRegistration<LowerUniformCastsPass>
+ lowerUniformCastsPass("fxpmath-lower-uniform-casts",
+ "Lowers uniform-quantized casts.");
return IntegerAttr::get(integerType, value);
}
+/// Given an APFloat, converts it to the float semantics that matches the
+/// given FloatType, silently ignoring inexact conversions (only div-by-zero
+/// and invalid-op statuses are treated as failures, and only via assert).
+inline APFloat convertFloatToType(FloatType ft, APFloat value) {
+  bool losesInfo;
+  auto status = value.convert(ft.getFloatSemantics(),
+                              APFloat::rmNearestTiesToEven, &losesInfo);
+  // 'status' is only read inside the assert; silence the unused-variable
+  // warning in NDEBUG builds.
+  (void)status;
+  assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 &&
+         "could not convert to float const");
+  return value;
+}
+
+/// Creates a FloatAttr with a type that matches the shape of 't' (which can
+/// be a primitive/vector/tensor), converting 'value' to the element type's
+/// float semantics first. (The comment previously said "IntegerAttr", which
+/// was a copy-paste error from the integer sibling above.)
+inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) {
+ if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
+ FloatType floatElementType = vt.getElementType().dyn_cast<FloatType>();
+ assert(floatElementType &&
+ "float broadcast element type must be float like");
+ APFloat apValue = convertFloatToType(floatElementType, value);
+ // Shaped types broadcast the converted scalar as a splat attribute.
+ return SplatElementsAttr::get(vt,
+ FloatAttr::get(vt.getElementType(), apValue));
+ } else {
+ auto floatType = t.dyn_cast<FloatType>();
+ assert(floatType && "float broadcast must be of float type");
+ APFloat apValue = convertFloatToType(floatType, value);
+ return FloatAttr::get(floatType, apValue);
+ }
+}
+
} // namespace detail
} // namespace fxpmath
} // namespace mlir
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-casts | FileCheck %s --dump-input=always
+// Tests for the -fxpmath-lower-uniform-casts pass: per-layer fixed-point and
+// affine dequantizations lower to scast/convertis/addi/convertistof/mulf;
+// per-axis forms warn as unimplemented and are left untouched, as are no-op
+// casts (handled by later canonicalization).
+
+// -----
+// CHECK-LABEL: dequantize_per_layer_fixedpoint
+!type_input = type tensor<4x!quant.uniform<i8:f32, 6.25e-2>>
+!type_result = type tensor<4xf32>
+func @dequantize_per_layer_fixedpoint(%arg0 : !type_input) -> !type_result {
+ // CHECK: %cst = constant splat<tensor<4xf32>, 6.250000e-02> : tensor<4xf32>
+ // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02>>) -> tensor<4xi8>
+ // CHECK-NEXT: %1 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32>
+ // CHECK-NEXT: %2 = "fxpmath.convertistof"(%1) : (tensor<4xi32>) -> tensor<4xf32>
+ // CHECK-NEXT: %3 = mulf %2, %cst : tensor<4xf32>
+ // CHECK-NEXT: return %3 : tensor<4xf32>
+ %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
+ return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: dequantize_per_layer_affine
+!type_input = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-36>>
+!type_result = type tensor<4xf32>
+func @dequantize_per_layer_affine(%arg0 : !type_input) -> !type_result {
+ // CHECK: %cst = constant splat<tensor<4xi32>, 36> : tensor<4xi32>
+ // CHECK-NEXT: %cst_0 = constant splat<tensor<4xf32>, 6.250000e-02> : tensor<4xf32>
+ // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform<i8:f32, 6.250000e-02:-36>>) -> tensor<4xi8>
+ // CHECK-NEXT: %1 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32>
+ // CHECK-NEXT: %2 = addi %1, %cst : tensor<4xi32>
+ // CHECK-NEXT: %3 = "fxpmath.convertistof"(%2) : (tensor<4xi32>) -> tensor<4xf32>
+ // CHECK-NEXT: %4 = mulf %3, %cst_0 : tensor<4xf32>
+ // CHECK-NEXT: return %4 : tensor<4xf32>
+ %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
+ return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: dequantize_per_axis_fixedpoint
+!type_input = type tensor<4x!quant.uniform<i8:f32:0, {6.25e-2,3.26e-2,4.25e-2,1.23e-2}>>
+!type_result = type tensor<4xf32>
+func @dequantize_per_axis_fixedpoint(%arg0 : !type_input) -> !type_result {
+ // expected-warning@+1 {{unimplemented: per-axis uniform dequantization}}
+ %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
+ return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: dequantize_per_axis_affine
+!type_input = type tensor<4x!quant.uniform<i8:f32:0, {6.25e-2,3.26e-2,4.25e-2,1.23e-2}>>
+!type_result = type tensor<4xf32>
+func @dequantize_per_axis_affine(%arg0 : !type_input) -> !type_result {
+ // expected-warning@+1 {{unimplemented: per-axis uniform dequantization}}
+ %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
+ return %0 : !type_result
+}
+
+// -----
+// Noop dequantize should be skipped (will be canonicalized away later).
+// CHECK-LABEL: dequantize_noop
+!type_input = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-36>>
+!type_result = type tensor<4x!quant.uniform<i8:f32, 6.25e-2:-36>>
+func @dequantize_noop(%arg0 : !type_input) -> !type_result {
+ // CHECK: %0 = "quant.dcast"(%arg0)
+ %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result)
+ return %0 : !type_result
+}