From 8189e6ef908e35d90bee4d13972ce9bb6c8d8966 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 22 Apr 2019 17:35:38 -0700 Subject: [PATCH] Implement lowering of quant.dcast to the fxpmathops and standard dialect. Note that I broke this out as a separate pass because intermediate transformations often produce qcast/dcast ops that are integral to the transformation, and it is typical to want to lower any remaining, unmatched casts at the end of quantization. If this flexibility ends up not being needed, they can be collapsed into the same pass. This is included in the same cpp file because all of the math ops will need to defer to emitting quantize/dequantize logic for cases that they cannot be fully lowered to fixed-point math. Also, the new convertistof op needs to be evaluated for inclusion in StandardOps. -- PiperOrigin-RevId: 244768679 --- mlir/include/mlir/FxpMathOps/FxpMathOps.td | 15 +++ mlir/include/mlir/FxpMathOps/Passes.h | 4 + mlir/include/mlir/Quantization/QuantTypes.h | 3 +- .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 147 ++++++++++++++++++++- .../lib/FxpMathOps/Transforms/UniformKernelUtils.h | 29 ++++ mlir/test/FxpMathOps/lower-uniform-casts.mlir | 64 +++++++++ 6 files changed, 258 insertions(+), 4 deletions(-) create mode 100644 mlir/test/FxpMathOps/lower-uniform-casts.mlir diff --git a/mlir/include/mlir/FxpMathOps/FxpMathOps.td b/mlir/include/mlir/FxpMathOps/FxpMathOps.td index ae50966..c611111 100644 --- a/mlir/include/mlir/FxpMathOps/FxpMathOps.td +++ b/mlir/include/mlir/FxpMathOps/FxpMathOps.td @@ -119,6 +119,21 @@ def fxpmath_ConvertISOp : let results = (outs IntegerLike); } +def fxpmath_ConvertISToFOp : + fxpmath_Op<"convertistof", + [NoSideEffect, SameValueShape]> { + let summary = + "Does an element-wise conversion from a signed integer to a float"; + let description = [{ + Similar to an element-wise static_cast in C++, from a signed integer + element type to a floating point element type, rounding to the nearest + floating point value. + }]; + let arguments = (ins IntegerLike:$arg); + let results = (outs FloatLike); +} + + def fxpmath_VecScalarSaturatingRoundingDoublingHighMulISOp : fxpmath_Op<"vs_saturating_rounding_doubling_high_mulis", [NoSideEffect, SameValueType]> { diff --git a/mlir/include/mlir/FxpMathOps/Passes.h b/mlir/include/mlir/FxpMathOps/Passes.h index e4df24f..b988f13 100644 --- a/mlir/include/mlir/FxpMathOps/Passes.h +++ b/mlir/include/mlir/FxpMathOps/Passes.h @@ -33,6 +33,10 @@ namespace fxpmath { /// floating point form. FunctionPassBase *createLowerUniformRealMathPass(); +/// Creates a pass that lowers uniform-quantized qcast/dcast ops to equivalent +/// operations that perform quantize/dequantize. +FunctionPassBase *createLowerUniformCastsPass(); + } // namespace fxpmath } // namespace mlir diff --git a/mlir/include/mlir/Quantization/QuantTypes.h b/mlir/include/mlir/Quantization/QuantTypes.h index a8e3b04..6e9e97d 100644 --- a/mlir/include/mlir/Quantization/QuantTypes.h +++ b/mlir/include/mlir/Quantization/QuantTypes.h @@ -80,7 +80,8 @@ public: /// Support method to enable LLVM-style type casting. static bool kindof(unsigned kind) { - return kind == QuantizationTypes::UniformQuantized; + return kind == QuantizationTypes::UniformQuantized || + kind == QuantizationTypes::UniformQuantizedPerAxis; } /// Gets the minimum possible stored by a storageType. storageTypeMin must diff --git a/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp index 0eaa22e..d7c0a62 100644 --- a/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -35,6 +35,123 @@ struct LowerUniformRealMathPass void runOnFunction() override; }; +struct LowerUniformCastsPass : public FunctionPass { + void runOnFunction() override; +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Dequantize +//===----------------------------------------------------------------------===// + +static Value *emitUniformPerLayerDequantize(Location loc, Value *input, + UniformQuantizedType elementType, + PatternRewriter &rewriter) { + // Pre-conditions. + if (!elementType.isSigned()) { + // TODO: Support unsigned storage type. + return rewriter.getContext()->emitDiagnostic( + loc, "unimplemented: dequantize signed uniform", + MLIRContext::DiagnosticKind::Warning), + nullptr; + } + + Type storageType = elementType.castToStorageType(input->getType()); + Type realType = elementType.castToExpressedType(input->getType()); + Type intermediateType = + castElementType(storageType, IntegerType::get(32, rewriter.getContext())); + assert(storageType && "cannot cast to storage type"); + assert(realType && "cannot cast to expressed type"); + + // Cast to storage type. + input = rewriter.create(loc, storageType, input); + + // Promote to intermediate type. + input = rewriter.create(loc, intermediateType, input); + + // Apply zero-point offset. + if (elementType.getZeroPoint() != 0) { + Value *negZeroPointConst = rewriter.create( + loc, broadcastScalarConstIntValue(intermediateType, + -elementType.getZeroPoint())); + input = rewriter.create(loc, input, negZeroPointConst); + } + + // Convert to float. + input = rewriter.create(loc, realType, input); + + // Mul by scale. + Value *scaleConst = rewriter.create( + loc, broadcastScalarConstFloatValue(realType, + APFloat(elementType.getScale()))); + return rewriter.create(loc, input, scaleConst); +} + +static Value * +emitUniformPerAxisDequantize(Location loc, Value *input, + UniformQuantizedPerAxisType elementType, + PatternRewriter &rewriter) { + // TODO: Support per-axis dequantize. + return rewriter.getContext()->emitDiagnostic( + loc, "unimplemented: per-axis uniform dequantization", + MLIRContext::DiagnosticKind::Warning), + nullptr; + + return input->getDefiningOp()->emitWarning( + "unimplemented: per-axis uniform dequantization"), + nullptr; +} + +static Value *emitDequantize(Location loc, Value *input, + PatternRewriter &rewriter) { + Type inputType = input->getType(); + QuantizedType qElementType = + QuantizedType::getQuantizedElementType(inputType); + if (auto uperLayerElementType = + qElementType.dyn_cast_or_null()) { + return emitUniformPerLayerDequantize(loc, input, uperLayerElementType, + rewriter); + } else if (auto uperAxisElementType = + qElementType.dyn_cast_or_null()) { + return emitUniformPerAxisDequantize(loc, input, uperAxisElementType, + rewriter); + } else { + return nullptr; + } +} + +namespace { + +struct UniformDequantizePattern : public RewritePattern { + UniformDequantizePattern(MLIRContext *context) + : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {} + + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + auto dcastOp = op->cast(); + Type inputType = dcastOp.arg()->getType(); + Type outputType = dcastOp.getResult()->getType(); + + QuantizedType inputElementType = + QuantizedType::getQuantizedElementType(inputType); + Type expressedOutputType = inputElementType.castToExpressedType(inputType); + if (expressedOutputType != outputType) { + // Not a valid uniform cast. + return matchFailure(); + } + + Value *dequantizedValue = + emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter); + if (!dequantizedValue) { + return matchFailure(); + } + + rewriter.replaceOp(op, dequantizedValue); + return matchSuccess(); + } +}; + } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -256,6 +373,10 @@ struct UniformRealMulEwPattern : public RewritePattern { } // end anonymous namespace +//===----------------------------------------------------------------------===// +// LowerUniformRealMath pass +//===----------------------------------------------------------------------===// + void LowerUniformRealMathPass::runOnFunction() { auto &fn = getFunction(); OwningRewritePatternList patterns; @@ -269,6 +390,26 @@ FunctionPassBase *createLowerUniformRealMathPass() { return new LowerUniformRealMathPass(); } -static PassRegistration - pass("fxpmath-lower-uniform-real-math", - "Lowers uniform-quantized real math ops to integer arithmetic."); +static PassRegistration lowerUniformRealMathPass( + "fxpmath-lower-uniform-real-math", + "Lowers uniform-quantized real math ops to integer arithmetic."); + +//===----------------------------------------------------------------------===// +// LowerUniformCasts pass +//===----------------------------------------------------------------------===// + +void LowerUniformCastsPass::runOnFunction() { + auto &fn = getFunction(); + OwningRewritePatternList patterns; + auto *context = &getContext(); + patterns.push_back(llvm::make_unique(context)); + applyPatternsGreedily(fn, std::move(patterns)); +} + +FunctionPassBase *createLowerUniformCastsPass() { + return new LowerUniformCastsPass(); +} + +static PassRegistration + lowerUniformCastsPass("fxpmath-lower-uniform-casts", + "Lowers uniform-quantized casts."); diff --git a/mlir/lib/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/FxpMathOps/Transforms/UniformKernelUtils.h index 53aa86e..7b25739 100644 --- a/mlir/lib/FxpMathOps/Transforms/UniformKernelUtils.h +++ b/mlir/lib/FxpMathOps/Transforms/UniformKernelUtils.h @@ -196,6 +196,35 @@ inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) { return IntegerAttr::get(integerType, value); } +/// Given an APFloat, converts it to the float semantics that matches the +/// given FloatType, silently ignoring inexact conversions. +inline APFloat convertFloatToType(FloatType ft, APFloat value) { + bool losesInfo; + auto status = value.convert(ft.getFloatSemantics(), + APFloat::rmNearestTiesToEven, &losesInfo); + assert((status & (APFloat::opDivByZero | APFloat::opInvalidOp)) == 0 && + "could not convert to float const"); + return value; +} + +/// Creates an IntegerAttr with a type that matches the shape of 't' (which can +/// be a primitive/vector/tensor). +inline Attribute broadcastScalarConstFloatValue(Type t, APFloat value) { + if (auto vt = t.dyn_cast()) { + FloatType floatElementType = vt.getElementType().dyn_cast(); + assert(floatElementType && + "float broadcast element type must be float like"); + APFloat apValue = convertFloatToType(floatElementType, value); + return SplatElementsAttr::get(vt, + FloatAttr::get(vt.getElementType(), apValue)); + } else { + auto floatType = t.dyn_cast(); + assert(floatType && "float broadcast must be of float type"); + APFloat apValue = convertFloatToType(floatType, value); + return FloatAttr::get(floatType, apValue); + } +} + } // namespace detail } // namespace fxpmath } // namespace mlir diff --git a/mlir/test/FxpMathOps/lower-uniform-casts.mlir b/mlir/test/FxpMathOps/lower-uniform-casts.mlir new file mode 100644 index 0000000..3bd94a4 --- /dev/null +++ b/mlir/test/FxpMathOps/lower-uniform-casts.mlir @@ -0,0 +1,64 @@ +// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-casts | FileCheck %s --dump-input=always + +// ----- +// CHECK-LABEL: dequantize_per_layer_fixedpoint +!type_input = type tensor<4x!quant.uniform> +!type_result = type tensor<4xf32> +func @dequantize_per_layer_fixedpoint(%arg0 : !type_input) -> !type_result { + // CHECK: %cst = constant splat, 6.250000e-02> : tensor<4xf32> + // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform>) -> tensor<4xi8> + // CHECK-NEXT: %1 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32> + // CHECK-NEXT: %2 = "fxpmath.convertistof"(%1) : (tensor<4xi32>) -> tensor<4xf32> + // CHECK-NEXT: %3 = mulf %2, %cst : tensor<4xf32> + // CHECK-NEXT: return %3 : tensor<4xf32> + %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result) + return %0 : !type_result +} + +// ----- +// CHECK-LABEL: dequantize_per_layer_affine +!type_input = type tensor<4x!quant.uniform> +!type_result = type tensor<4xf32> +func @dequantize_per_layer_affine(%arg0 : !type_input) -> !type_result { + // CHECK: %cst = constant splat, 36> : tensor<4xi32> + // CHECK-NEXT: %cst_0 = constant splat, 6.250000e-02> : tensor<4xf32> + // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant.uniform>) -> tensor<4xi8> + // CHECK-NEXT: %1 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32> + // CHECK-NEXT: %2 = addi %1, %cst : tensor<4xi32> + // CHECK-NEXT: %3 = "fxpmath.convertistof"(%2) : (tensor<4xi32>) -> tensor<4xf32> + // CHECK-NEXT: %4 = mulf %3, %cst_0 : tensor<4xf32> + // CHECK-NEXT: return %4 : tensor<4xf32> + %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result) + return %0 : !type_result +} + +// ----- +// CHECK-LABEL: dequantize_per_axis_fixedpoint +!type_input = type tensor<4x!quant.uniform> +!type_result = type tensor<4xf32> +func @dequantize_per_axis_fixedpoint(%arg0 : !type_input) -> !type_result { + // expected-warning@+1 {{unimplemented: per-axis uniform dequantization}} + %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result) + return %0 : !type_result +} + +// ----- +// CHECK-LABEL: dequantize_per_axis_affine +!type_input = type tensor<4x!quant.uniform> +!type_result = type tensor<4xf32> +func @dequantize_per_axis_affine(%arg0 : !type_input) -> !type_result { + // expected-warning@+1 {{unimplemented: per-axis uniform dequantization}} + %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result) + return %0 : !type_result +} + +// ----- +// Noop dequantize should be skipped (will be canonicalized away later). +// CHECK-LABEL: dequantize_noop +!type_input = type tensor<4x!quant.uniform> +!type_result = type tensor<4x!quant.uniform> +func @dequantize_noop(%arg0 : !type_input) -> !type_result { + // CHECK: %0 = "quant.dcast"(%arg0) + %0 = "quant.dcast"(%arg0) : (!type_input) -> (!type_result) + return %0 : !type_result +} -- 2.7.4