From e8d551e2bda03928307460cb0ff151b9b5b312a0 Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Thu, 18 Apr 2019 17:06:05 -0700
Subject: [PATCH] Implement lowering of element-wise fixed point add and mul
 to the standard dialect.

This also does the following:
- Removes the proof-of-concept power-of-two (POT) add implementation in
  favor of a version that does not rescale.
- Adds a handful of FxpMathOps which are needed (these are open for comment,
  and we may want to move them to the StandardOps dialect).
- Adds a canonicalizer to the StorageCastOp, which removes some cruft once
  conversions have been done.
- Adds a couple of predicates to OpBase.

--
PiperOrigin-RevId: 244287706
---
 mlir/include/mlir/FxpMathOps/FxpMathOps.td         |  67 +++--
 mlir/include/mlir/IR/OpBase.td                     |  16 +-
 mlir/include/mlir/Quantization/QuantOps.td         |   1 +
 mlir/include/mlir/StandardOps/Ops.td               |   4 +
 .../FxpMathOps/Transforms/LowerUniformRealMath.cpp | 329 +++++++++++----------
 .../lib/FxpMathOps/Transforms/UniformKernelUtils.h | 203 +++++++++++++
 mlir/lib/Quantization/IR/QuantOps.cpp              |  37 +++
 .../FxpMathOps/lower-uniform-real-math-addew.mlir  |  86 +++---
 .../FxpMathOps/lower-uniform-real-math-mulew.mlir  |  94 ++++++
 mlir/test/Quantization/canonicalize.mlir           |  24 ++
 10 files changed, 633 insertions(+), 228 deletions(-)
 create mode 100644 mlir/lib/FxpMathOps/Transforms/UniformKernelUtils.h
 create mode 100644 mlir/test/FxpMathOps/lower-uniform-real-math-mulew.mlir
 create mode 100644 mlir/test/Quantization/canonicalize.mlir

diff --git a/mlir/include/mlir/FxpMathOps/FxpMathOps.td b/mlir/include/mlir/FxpMathOps/FxpMathOps.td
index 083f3d9..ae50966 100644
--- a/mlir/include/mlir/FxpMathOps/FxpMathOps.td
+++ b/mlir/include/mlir/FxpMathOps/FxpMathOps.td
@@ -90,18 +90,60 @@ class fxpmath_Op<string mnemonic, list<OpTrait> traits> :
 //===----------------------------------------------------------------------===//
 // Fixed-point (fxp) arithmetic ops used by kernels.
+// Some of these are temporary pending inclusion into a more core dialect.
 //===----------------------------------------------------------------------===//
 
-def fxpmath_RoundingDivideByPotFxpOp :
-    fxpmath_Op<"rounding_divide_by_poti", [NoSideEffect, SameValueType]>,
-    Arguments<(ins quant_StorageValueType:$x, I32Attr:$exponent)>,
-    Results<(outs quant_StorageValueType:$y)> {
+def fxpmath_ClampISOp : fxpmath_Op<"clampis", [NoSideEffect, SameValueType]> {
+  let summary =
+      "Clamps a signed-integer-like argument to a min/max range.";
+  let description = [{
+    Element-wise equivalent to:
+      r = std::min(clamp_max, std::max(e, clamp_min))
+  }];
+  let arguments = (ins IntegerLike:$arg,
+                   APIntAttr:$clamp_min,
+                   APIntAttr:$clamp_max);
+  let results = (outs IntegerLike);
+}
+
+def fxpmath_ConvertISOp :
+    fxpmath_Op<"convertis",
+               [NoSideEffect, SameValueShape]> {
+  let summary =
+      "Does an element-wise conversion from one signed integer type to another";
+  let description = [{
+    Similar to an element-wise static_cast in C++, from one signed integer
+    element type to another.
+  }];
+  let arguments = (ins IntegerLike:$arg);
+  let results = (outs IntegerLike);
+}
+
+def fxpmath_VecScalarSaturatingRoundingDoublingHighMulISOp :
+    fxpmath_Op<"vs_saturating_rounding_doubling_high_mulis",
+               [NoSideEffect, SameValueType]> {
+  let summary = "Implements equivalent functionality to ARMv7 NEON VQRDMULH";
+  let description = [{
+    Equivalent to the ARMv7 NEON VQRDMULH instruction.
+    See gemmlowp::SaturatingRoundingDoublingHighMul for a reference
+    implementation.
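+
+    A minimal C++ sketch of the per-element semantics, adapted from the
+    gemmlowp reference (illustrative only; the gemmlowp implementation is
+    normative):
+
+      int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
+        // The only input pair whose doubled product cannot be represented
+        // is (INT32_MIN, INT32_MIN); saturate in that case.
+        bool overflow = a == b && a == std::numeric_limits<int32_t>::min();
+        int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
+        // Round to nearest when taking the high 32 bits of 2 * a * b.
+        int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
+        int32_t high = static_cast<int32_t>((ab + nudge) / (1ll << 31));
+        return overflow ? std::numeric_limits<int32_t>::max() : high;
+      }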
+  }];
+  let arguments = (ins IntegerLike:$a, APIntAttr:$b);
+  let results = (outs IntegerLike);
+}
+
+def fxpmath_RoundingDivideByPotISOp :
+    fxpmath_Op<"rounding_divide_by_potis", [NoSideEffect, SameValueType]> {
+  let summary = [{
+    Computes a rounding arithmetic right shift.
+  }];
   let description = [{
     Computes integer division by a power-of-two, correctly rounded-to-nearest.
     Also known as a rounding arithmetic right shift. See
     gemmlowp::RoundingDivideByPOT for a reference implementation.
   }];
-
+  let arguments = (ins IntegerLike:$x, APIntAttr:$exponent);
+  let results = (outs IntegerLike:$y);
   let verifier = [{
     auto verifyExponent = exponent().getSExtValue();
     if (verifyExponent < 0 || verifyExponent > 31) {
@@ -111,21 +153,6 @@ def fxpmath_RoundingDivideByPotFxpOp :
   }];
 }
 
-def fxpmath_SaturatingAddFxpOp :
-    fxpmath_Op<"saturating_addi", [NoSideEffect, SameValueType]>,
-    Arguments<(ins quant_StorageValueType:$x,
-                   quant_StorageValueType:$y,
-                   I32Attr:$clamp_min,
-                   I32Attr:$clamp_max)>,
-    Results<(outs quant_StorageValueType:$sum)> {
-  let description = [{
-    Computes saturating addition of two operands, saturating to the given min
-    and max value. The implementation is responsible for choosing an
-    intermediate register size appropriate to carry out the operation without
-    overflow. See gemmlowp::SaturatingAdd for a reference implementation.
-  }];
-}
-
 //===----------------------------------------------------------------------===//
 // Real math ops.
 //
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 42c16ed..9c00505 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -336,11 +336,11 @@ class TypedStaticShapeTensor :
     Type.predicate, IsStaticShapeTensorTypePred ]>,
          "statically shaped tensor">;
 
-def I1Tensor : TypedTensor;
-def I8Tensor : TypedTensor;
-def I16Tensor : TypedTensor;
-def I32Tensor : TypedTensor;
-def I64Tensor : TypedTensor;
+def I1Tensor : TypedTensor;
+def I8Tensor : TypedTensor;
+def I16Tensor : TypedTensor;
+def I32Tensor : TypedTensor;
+def I64Tensor : TypedTensor;
 
 def BF16Tensor : TypedTensor;
 def F16Tensor : TypedTensor;
@@ -503,6 +503,12 @@ class IntegerAttrBase :
   let returnType = [{ APInt }];
 }
 
+def APIntAttr : Attr<CPred<"$_self.isa<IntegerAttr>()">,
+                     "arbitrary integer attribute"> {
+  let storageType = [{ IntegerAttr }];
+  let returnType = [{ APInt }];
+}
+
 def I32Attr : IntegerAttrBase;
 def I64Attr : IntegerAttrBase;
 
diff --git a/mlir/include/mlir/Quantization/QuantOps.td b/mlir/include/mlir/Quantization/QuantOps.td
index 13fa6ca..ec394757 100644
--- a/mlir/include/mlir/Quantization/QuantOps.td
+++ b/mlir/include/mlir/Quantization/QuantOps.td
@@ -92,6 +92,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> {
 def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> {
   let arguments = (ins quant_RealOrStorageValueType:$arg);
   let results = (outs quant_RealOrStorageValueType);
+  let hasCanonicalizer = 0b1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/StandardOps/Ops.td b/mlir/include/mlir/StandardOps/Ops.td
index 9302e14..0b78e14 100644
--- a/mlir/include/mlir/StandardOps/Ops.td
+++ b/mlir/include/mlir/StandardOps/Ops.td
@@ -131,6 +131,10 @@ def RemIUOp : IntArithmeticOp<"std.remiu"> {
   let hasConstantFolder = 0b1;
 }
 
+def ShlISOp : IntArithmeticOp<"std.shlis"> {
+  let summary = "signed integer shift left";
+}
+
 def SubFOp : FloatArithmeticOp<"std.subf"> {
   let summary = "floating point subtraction operation";
   let hasConstantFolder = 0b1;
diff --git
a/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp
index 2ae0902..0eaa22e 100644
--- a/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp
+++ b/mlir/lib/FxpMathOps/Transforms/LowerUniformRealMath.cpp
@@ -15,17 +15,17 @@
 // limitations under the License.
 // =============================================================================
 
+#include "UniformKernelUtils.h"
+
 #include "mlir/FxpMathOps/FxpMathOps.h"
 #include "mlir/FxpMathOps/Passes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
-#include "mlir/Quantization/QuantOps.h"
-#include "mlir/Quantization/UniformSupport.h"
-
-#include <cmath>
+#include "mlir/StandardOps/Ops.h"
 
 using namespace mlir;
 using namespace mlir::fxpmath;
+using namespace mlir::fxpmath::detail;
 using namespace mlir::quant;
 
 namespace {
@@ -35,186 +35,176 @@ struct LowerUniformRealMathPass
   void runOnFunction() override;
 };
 
-UniformQuantizedType getUniformElementType(Type t) {
-  return QuantizedType::getQuantizedElementType(t)
-      .dyn_cast_or_null<UniformQuantizedType>();
-}
-
-/// Computes the log2(x), rounded to an integral value. Returns whether 'x' can
-/// be considered an exact integral value.
-template <typename F> bool integralLog2(F x, int &log2Result) {
-  const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
-  const F xLog2Rounded = std::round(xLog2);
-  const F xLog2Frac = xLog2 - xLog2Rounded;
-  log2Result = static_cast<int>(xLog2Rounded);
-  // Allow small comparison slop below the level that would make a difference
-  // for 2^16 levels.
-  return std::abs(xLog2Frac) < 1e-6;
-}
+} // end anonymous namespace
 
-/// Helper class for operating on binary operations where all operands
-/// and the result are a UniformQuantizedType.
-struct RealBinaryOpInfo {
-  RealBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
-                   Optional<APFloat> clampMin, Optional<APFloat> clampMax)
-      : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
-        lhsType(getUniformElementType(lhs->getType())),
-        rhsType(getUniformElementType(rhs->getType())),
-        resultType(getUniformElementType(*op->result_type_begin())),
-        lhsStorageType(QuantizedType::castToStorageType(lhs->getType())),
-        rhsStorageType(QuantizedType::castToStorageType(rhs->getType())),
-        resultStorageType(
-            QuantizedType::castToStorageType(*op->result_type_begin())) {}
-
-  /// Returns whether this info is valid (all types defined, etc).
-  bool isValid() const {
-    return lhsType && rhsType && resultType && lhsStorageType &&
-           rhsStorageType && resultStorageType;
-  }
+//===----------------------------------------------------------------------===//
+// Elementwise add
+//===----------------------------------------------------------------------===//
 
-  /// Returns whether the storage type of all operands is identical.
-  bool isSameStorageType() const {
-    return lhsType.getStorageType() == rhsType.getStorageType() &&
-           lhsType.getStorageType() == resultType.getStorageType();
+static LogicalResult
+tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
+                                      PatternRewriter &rewriter) {
+  if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
+      info.rhsType != info.resultType) {
+    return failure();
   }
 
-  /// Returns whether all operands and result are considered fixedpoint power
-  /// of two, setting the lhs, rhs, and result log2 scale references.
-  bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
-                       int &resultLog2Scale) const {
-    if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
-        !resultType.isFixedPoint()) {
-      return false;
-    }
+  // Choose a byte aligned intermediate width big enough to perform the
+  // calculation without overflow.
+  // TODO: This should probably be made just big enough to avoid overflow and
+  // leave the downstream tooling to decide how to align that to machine
+  // word sizes.
+  unsigned intermediateWidth =
+      info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
+  IntegerType intermediateElementType =
+      IntegerType::get(intermediateWidth, rewriter.getContext());
+  Type intermediateType =
+      castElementType(info.resultStorageType, intermediateElementType);
 
-    if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
-        !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
-        !integralLog2(resultType.getScale(), resultLog2Scale)) {
-      return false;
-    }
+  // Cast operands to storage type.
+  Value *lhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.lhsStorageType, info.lhs)
+                        .getResult();
+  Value *rhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.rhsStorageType, info.rhs)
+                        .getResult();
+
+  // Cast to the intermediate sized type.
+  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          lhsValue);
+  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          rhsValue);
 
-    return true;
+  // Add.
+  Value *resultValue =
+      rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+  // Zero point offset adjustment.
+  // result = (lhs - zp) + (rhs - zp) + zp
+  // zpOffset = -zp
+  int zpOffset = -1 * info.resultType.getZeroPoint();
+  if (zpOffset != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(),
+        broadcastScalarConstIntValue(intermediateType, zpOffset));
+    resultValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
   }
 
-  /// Gets the result integer clamp range given the result quantized type
-  // and any explicit clamp provided as attributes.
-  std::pair<IntegerAttr, IntegerAttr> getClampMinMax() const {
-    int64_t typeMin = resultType.getStorageTypeMin();
-    int64_t typeMax = resultType.getStorageTypeMax();
-
-    if (clampMin || clampMax) {
-      UniformQuantizedValueConverter conv(resultType);
-      if (clampMin) {
-        typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
-      }
-      if (clampMax) {
-        typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
-      }
-    }
+  // Clamp.
+  auto clampMinMax = info.getClampMinMax(intermediateElementType);
+  resultValue = rewriter.create<ClampISOp>(
+      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
 
-    // The quantized, integral ops expect clamps as 32bit ints.
-    return {
-        IntegerAttr::get(IntegerType::get(32, resultType.getContext()),
-                         typeMin),
-        IntegerAttr::get(IntegerType::get(32, resultType.getContext()),
-                         typeMax),
-    };
-  }
+  // Convert back to original type.
+  resultValue = rewriter.create<ConvertISOp>(
+      info.op->getLoc(), info.resultStorageType, resultValue);
 
-  Operation *op;
-  Value *lhs;
-  Value *rhs;
-  Optional<APFloat> clampMin;
-  Optional<APFloat> clampMax;
-
-  // Element UniformQuantizedType for operands/result.
-  UniformQuantizedType lhsType;
-  UniformQuantizedType rhsType;
-  UniformQuantizedType resultType;
-
-  // Full storage-based types.
-  Type lhsStorageType;
-  Type rhsStorageType;
-  Type resultStorageType;
-};
+  // Cast back for new result.
+  rewriter.replaceOpWithNewOp<StorageCastOp>(
+      info.op, info.getQuantizedResultType(), resultValue);
 
-} // end anonymous namespace
+  return success();
+}
 
 //===----------------------------------------------------------------------===//
-// Elementwise add
+// Elementwise mul
 //===----------------------------------------------------------------------===//
 
-/// Attempts to rewrite a fixed point power-of-two addition of two integers.
-/// This supports a limited number of cases, but when supported, represents
-/// the simplest computation.
-static LogicalResult tryRewriteFixedPOTAddEw(const RealBinaryOpInfo &constInfo,
-                                             PatternRewriter &rewriter) {
-  if (!constInfo.isSameStorageType()) {
-    return failure();
-  }
-  int lhsLog2Scale;
-  int rhsLog2Scale;
-  int resultLog2Scale;
-  if (!constInfo.isFixedPointPOT(lhsLog2Scale, rhsLog2Scale, resultLog2Scale)) {
+static LogicalResult
+tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
+                            PatternRewriter &rewriter) {
+  if (!info.resultType.isSigned()) {
     return failure();
   }
 
-  // Adjust shifts to be relative to the output.
-  // Left shift of one input scale is supported. The other must match the result
-  // scale.
-  int lhsScaleShift = lhsLog2Scale - resultLog2Scale;
-  int rhsScaleShift = rhsLog2Scale - resultLog2Scale;
-  if (lhsScaleShift != 0 && rhsScaleShift != 0) {
+  double outputMultiplierReal = info.lhsType.getScale() *
+                                info.rhsType.getScale() /
+                                info.resultType.getScale();
+  if (outputMultiplierReal > 1.0) {
+    info.op->emitWarning("unimplemented: cannot multiply with multiplier > 1.0");
    return failure();
  }
-  if (lhsScaleShift > 0 || rhsScaleShift > 0) {
-    return failure();
+
+  // TODO: Choose an appropriate intermediate width for muls > 8 bits to
+  // avoid overflow.
+  unsigned intermediateWidth = 32;
+  IntegerType intermediateElementType =
+      IntegerType::get(intermediateWidth, rewriter.getContext());
+  Type intermediateType =
+      castElementType(info.resultStorageType, intermediateElementType);
+
+  // Cast operands to storage type.
+  Value *lhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.lhsStorageType, info.lhs)
+                        .getResult();
+  Value *rhsValue = rewriter
+                        .create<StorageCastOp>(info.op->getLoc(),
+                                               info.rhsStorageType, info.rhs)
+                        .getResult();
+
+  // Cast to the intermediate sized type.
+  lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          lhsValue);
+  rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
+                                          rhsValue);
+
+  // Apply argument zeroPoints.
+  if (info.lhsType.getZeroPoint() != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(), broadcastScalarConstIntValue(
+                               intermediateType, -info.lhsType.getZeroPoint()));
+    lhsValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
   }
 
-  // State accessed by the closure.
-  Operation *mathOp = constInfo.op;
-  const auto clampMinMax = constInfo.getClampMinMax();
-  Value *lhs = constInfo.lhs;
-  Value *rhs = constInfo.rhs;
-  Type lhsStorageType = constInfo.lhsStorageType;
-  Type rhsStorageType = constInfo.rhsStorageType;
-
-  // If the lhs operand is the one requiring a shift, swap it so that the shift
-  // happens the rhs operand.
-  if (lhsScaleShift != 0) {
-    std::swap(lhs, rhs);
-    std::swap(lhsStorageType, rhsStorageType);
-    std::swap(lhsScaleShift, rhsScaleShift);
+  if (info.rhsType.getZeroPoint() != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(), broadcastScalarConstIntValue(
+                               intermediateType, -info.rhsType.getZeroPoint()));
+    rhsValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
   }
-  int rhsRightShift = -rhsScaleShift;
 
-  // Cast operands to storage type.
-  Value *lhsStorageValue =
-      rewriter.create<StorageCastOp>(mathOp->getLoc(), lhsStorageType, lhs)
-          .getResult();
-  Value *rhsStorageValue =
-      rewriter.create<StorageCastOp>(mathOp->getLoc(), rhsStorageType, rhs)
-          .getResult();
-
-  // Rescale the rhs operand if needed.
-  if (rhsRightShift != 0) {
-    rhsStorageValue =
-        rewriter
-            .create<RoundingDivideByPotFxpOp>(
-                mathOp->getLoc(), rhsStorageValue,
-                IntegerAttr::get(IntegerType::get(32, rewriter.getContext()),
-                                 rhsRightShift))
-            .getResult();
+  // Mul.
+  Value *resultValue =
+      rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
+
+  // Scale output.
+  QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
+  resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
+      info.op->getLoc(), resultValue,
+      IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
+  resultValue = rewriter.create<RoundingDivideByPotISOp>(
+      info.op->getLoc(), resultValue,
+      IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
+
+  // Zero point offset adjustment.
+  if (info.resultType.getZeroPoint() != 0) {
+    Value *zpOffsetConst = rewriter.create<ConstantOp>(
+        info.op->getLoc(),
+        broadcastScalarConstIntValue(intermediateType,
+                                     info.resultType.getZeroPoint()));
+    resultValue =
+        rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
   }
 
-  // Add.
-  Value *sumValue = rewriter.create<SaturatingAddFxpOp>(
-      mathOp->getLoc(), lhsStorageValue, rhsStorageValue, clampMinMax.first,
-      clampMinMax.second);
+  // Clamp.
+  auto clampMinMax = info.getClampMinMax(intermediateElementType);
+  resultValue = rewriter.create<ClampISOp>(
+      info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
+
+  // Convert back to original type.
+  resultValue = rewriter.create<ConvertISOp>(
+      info.op->getLoc(), info.resultStorageType, resultValue);
 
   // Cast back for new result.
   rewriter.replaceOpWithNewOp<StorageCastOp>(
-      mathOp, *mathOp->result_type_begin(), sumValue);
+      info.op, info.getQuantizedResultType(), resultValue);
+
   return success();
 }
 
@@ -227,14 +217,36 @@ struct UniformRealAddEwPattern : public RewritePattern {
   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const {
     auto addOp = op->cast<RealAddEwOp>();
-    const RealBinaryOpInfo info(op, addOp.x(), addOp.y(), addOp.clamp_min(),
-                                addOp.clamp_max());
+    const UniformBinaryOpInfo info(op, addOp.x(), addOp.y(), addOp.clamp_min(),
+                                   addOp.clamp_max());
+    if (!info.isValid()) {
+      return matchFailure();
+    }
+
+    // Try all of the permutations we support.
+    if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
+      return matchSuccess();
+    }
+
+    return matchFailure();
+  }
+};
+
+struct UniformRealMulEwPattern : public RewritePattern {
+  UniformRealMulEwPattern(MLIRContext *context)
+      : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const {
+    auto mulOp = op->cast<RealMulEwOp>();
+    const UniformBinaryOpInfo info(op, mulOp.x(), mulOp.y(), mulOp.clamp_min(),
+                                   mulOp.clamp_max());
     if (!info.isValid()) {
      return matchFailure();
    }

    // Try all of the permutations we support.
-    if (succeeded(tryRewriteFixedPOTAddEw(info, rewriter))) {
+    if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
       return matchSuccess();
     }
 
@@ -249,6 +261,7 @@ void LowerUniformRealMathPass::runOnFunction() {
   OwningRewritePatternList patterns;
   auto *context = &getContext();
   patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
+  patterns.push_back(llvm::make_unique<UniformRealMulEwPattern>(context));
   applyPatternsGreedily(fn, std::move(patterns));
 }
 
diff --git a/mlir/lib/FxpMathOps/Transforms/UniformKernelUtils.h b/mlir/lib/FxpMathOps/Transforms/UniformKernelUtils.h
new file mode 100644
index 0000000..53aa86e
--- /dev/null
+++ b/mlir/lib/FxpMathOps/Transforms/UniformKernelUtils.h
@@ -0,0 +1,203 @@
+//===- UniformKernelUtils.h - Utilities for lowering uniform math - C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
+#define MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Quantization/QuantOps.h"
+#include "mlir/Quantization/QuantTypes.h"
+#include "mlir/Quantization/UniformSupport.h"
+
+#include <cmath>
+
+namespace mlir {
+namespace fxpmath {
+namespace detail {
+
+inline quant::UniformQuantizedType getUniformElementType(Type t) {
+  return quant::QuantizedType::getQuantizedElementType(t)
+      .dyn_cast_or_null<quant::UniformQuantizedType>();
+}
+
+inline bool hasStorageBitWidth(quant::QuantizedType t,
+                               llvm::ArrayRef<unsigned> checkWidths) {
+  unsigned w = t.getStorageType().getIntOrFloatBitWidth();
+  for (unsigned checkWidth : checkWidths) {
+    if (w == checkWidth)
+      return true;
+  }
+  return false;
+}
+
+/// Computes the log2(x), rounded to an integral value. Returns whether 'x' can
+/// be considered an exact integral value.
+template <typename F> bool integralLog2(F x, int &log2Result) {
+  const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
+  const F xLog2Rounded = std::round(xLog2);
+  const F xLog2Frac = xLog2 - xLog2Rounded;
+  log2Result = static_cast<int>(xLog2Rounded);
+  // Allow small comparison slop below the level that would make a difference
+  // for 2^16 levels.
+  return std::abs(xLog2Frac) < 1e-6;
+}
+
+/// Helper class for operating on binary operations where all operands
+/// and the result are a UniformQuantizedType.
+struct UniformBinaryOpInfo {
+  UniformBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
+                      Optional<APFloat> clampMin, Optional<APFloat> clampMax)
+      : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
+        lhsType(getUniformElementType(lhs->getType())),
+        rhsType(getUniformElementType(rhs->getType())),
+        resultType(getUniformElementType(*op->result_type_begin())),
+        lhsStorageType(quant::QuantizedType::castToStorageType(lhs->getType())),
+        rhsStorageType(quant::QuantizedType::castToStorageType(rhs->getType())),
+        resultStorageType(
+            quant::QuantizedType::castToStorageType(*op->result_type_begin())) {
+  }
+
+  /// Returns whether this info is valid (all types defined, etc).
+  bool isValid() const {
+    return lhsType && rhsType && resultType && lhsStorageType &&
+           rhsStorageType && resultStorageType;
+  }
+
+  /// Gets the final quantized result type of the result.
+  Type getQuantizedResultType() const { return *op->result_type_begin(); }
+
+  /// Returns whether the storage type of all operands is identical.
+  bool isSameStorageType() const {
+    return lhsType.getStorageType() == rhsType.getStorageType() &&
+           lhsType.getStorageType() == resultType.getStorageType();
+  }
+
+  /// Returns whether all operands and result are considered fixedpoint power
+  /// of two, setting the lhs, rhs, and result log2 scale references.
+  bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
+                       int &resultLog2Scale) const {
+    if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
+        !resultType.isFixedPoint()) {
+      return false;
+    }
+
+    if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
+        !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
+        !integralLog2(resultType.getScale(), resultLog2Scale)) {
+      return false;
+    }
+
+    return true;
+  }
+
+  /// Gets the result integer clamp range given the result quantized type
+  /// and any explicit clamp provided as attributes. The clamps are returned
+  /// as attributes of the given (intermediate) integer element type.
+  std::pair<IntegerAttr, IntegerAttr> getClampMinMax(IntegerType ty) const {
+    int64_t typeMin = resultType.getStorageTypeMin();
+    int64_t typeMax = resultType.getStorageTypeMax();
+
+    if (clampMin || clampMax) {
+      quant::UniformQuantizedValueConverter conv(resultType);
+      if (clampMin) {
+        typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
+      }
+      if (clampMax) {
+        typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
+      }
+    }
+
+    return {
+        IntegerAttr::get(ty, typeMin),
+        IntegerAttr::get(ty, typeMax),
+    };
+  }
+
+  Operation *op;
+  Value *lhs;
+  Value *rhs;
+  Optional<APFloat> clampMin;
+  Optional<APFloat> clampMax;
+
+  // Element UniformQuantizedType for operands/result.
+  quant::UniformQuantizedType lhsType;
+  quant::UniformQuantizedType rhsType;
+  quant::UniformQuantizedType resultType;
+
+  // Full storage-based types.
+  Type lhsStorageType;
+  Type rhsStorageType;
+  Type resultStorageType;
+};
+
+/// Derives a quantized multiplier and shift from a real valued multiplier
+/// less than 1.
+struct QuantizedMultiplierSmallerThanOneExp {
+  QuantizedMultiplierSmallerThanOneExp(double realMultiplier) {
+    assert(realMultiplier < 1.0);
+    assert(realMultiplier > 0.0);
+
+    const double q = std::frexp(realMultiplier, &exponent);
+    auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
+    assert(qFixed <= (1ll << 31));
+    if (qFixed == (1ll << 31)) {
+      qFixed /= 2;
+      ++exponent;
+    }
+    assert(qFixed <= std::numeric_limits<int32_t>::max());
+    multiplier = static_cast<int32_t>(qFixed);
+  }
+
+  int32_t multiplier;
+  int exponent;
+};
+
+/// Casts an integer or floating point based type to a new element type.
+inline Type castElementType(Type t, Type newElementType) {
+  if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
+    switch (vt.getKind()) {
+    case StandardTypes::Kind::Vector:
+      return VectorType::get(vt.getShape(), newElementType);
+    case StandardTypes::Kind::RankedTensor:
+      return RankedTensorType::get(vt.getShape(), newElementType);
+    case StandardTypes::Kind::UnrankedTensor:
+      return UnrankedTensorType::get(newElementType);
+    }
+  }
+  assert(t.isIntOrFloat());
+  return newElementType;
+}
+
+/// Creates a scalar IntegerAttr, or a SplatElementsAttr matching the shape of
+/// 't' (which can be a primitive/vector/tensor), holding 'value'.
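+/// For example (illustrative; using the splat attribute syntax that appears
+/// in the tests below), broadcasting 5 over tensor<4xi16> would produce the
+/// attribute splat<tensor<4xi16>, 5>, which is how the zero-point offset
+/// constants in the lowered kernels are materialized.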
+inline Attribute broadcastScalarConstIntValue(Type t, int64_t value) {
+  if (auto vt = t.dyn_cast<VectorOrTensorType>()) {
+    assert(vt.getElementType().isa<IntegerType>());
+    return SplatElementsAttr::get(vt,
+                                  IntegerAttr::get(vt.getElementType(), value));
+  }
+
+  auto integerType = t.cast<IntegerType>();
+  assert(t.isa<IntegerType>() && "integer broadcast must be of integer type");
+  return IntegerAttr::get(integerType, value);
+}
+
+} // namespace detail
+} // namespace fxpmath
+} // namespace mlir
+
+#endif // MLIR_FXPMATH_UNIFORM_KERNEL_UTILS_H_
diff --git a/mlir/lib/Quantization/IR/QuantOps.cpp b/mlir/lib/Quantization/IR/QuantOps.cpp
index 2e498cd..a183a24 100644
--- a/mlir/lib/Quantization/IR/QuantOps.cpp
+++ b/mlir/lib/Quantization/IR/QuantOps.cpp
@@ -18,6 +18,8 @@
 #include "mlir/Quantization/QuantOps.h"
 #include "TypeDetail.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Quantization/QuantTypes.h"
 #include "llvm/ADT/StringRef.h"
@@ -31,6 +33,41 @@ using namespace mlir::quant::detail;
 #define GET_OP_CLASSES
 #include "mlir/Quantization/QuantOps.cpp.inc"
 
+namespace {
+
+/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
+/// value of x if the casts invert each other.
+class RemoveRedundantStorageCastsRewrite : public RewritePattern {
+public:
+  RemoveRedundantStorageCastsRewrite(MLIRContext *context)
+      : RewritePattern(StorageCastOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult match(Operation *op) const override {
+    auto scastOp = op->cast<StorageCastOp>();
+    if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) {
+      auto srcScastOp = scastOp.arg()->getDefiningOp()->cast<StorageCastOp>();
+      if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) {
+        return matchSuccess();
+      }
+    }
+    return matchFailure();
+  }
+
+  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+    auto scastOp = op->cast<StorageCastOp>();
+    auto srcScastOp = scastOp.arg()->getDefiningOp()->cast<StorageCastOp>();
+    rewriter.replaceOp(op, srcScastOp.arg());
+  }
+};
+
+} // end anonymous namespace
+
+void StorageCastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.push_back(
+      llvm::make_unique<RemoveRedundantStorageCastsRewrite>(context));
+}
+
 QuantizationDialect::QuantizationDialect(MLIRContext *context)
     : Dialect(/*name=*/"quant", context) {
   addTypes();
diff --git a/mlir/test/FxpMathOps/lower-uniform-real-math-addew.mlir b/mlir/test/FxpMathOps/lower-uniform-real-math-addew.mlir
index 29783f85..4258cf7 100644
--- a/mlir/test/FxpMathOps/lower-uniform-real-math-addew.mlir
+++ b/mlir/test/FxpMathOps/lower-uniform-real-math-addew.mlir
@@ -1,51 +1,44 @@
-// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math | FileCheck %s --dump-input=fail
+// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math -canonicalize | FileCheck %s --dump-input=always
 
 // -----
-// Verify lowering when operands and result have the same fixedpoint pot scale.
-// CHECK-LABEL: real_addew_fixedpoint_same_scale -// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> -// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> -// CHECK-NEXT: %2 = "fxpmath.saturating_addi"(%0, %1) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> -// CHECK-NEXT: %3 = "quant.scast"(%2) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> -// CHECK-NEXT: return %3 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> +// Verify lowering when operands and result have the same fixedpoint scale. +// CHECK-LABEL: real_addew_fixedpoint_isomorphic !type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> !type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> !type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> -func @real_addew_fixedpoint_same_scale(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { +func @real_addew_fixedpoint_isomorphic(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { + // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> + // CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> + // CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16> + // CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16> + // CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16> + // CHECK-NEXT: %5 = "fxpmath.clampis"(%4) {clamp_max: 127 : i16, clamp_min: -128 : i16} : (tensor<4xi16>) -> tensor<4xi16> + // CHECK-NEXT: %6 = "fxpmath.convertis"(%5) : (tensor<4xi16>) -> tensor<4xi8> + // CHECK-NEXT: %7 = "quant.scast"(%6) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> + // CHECK-NEXT: return %7 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) return %0 : !type_result } // ----- -// Verify lowering when the rhs is a shifted pot scale compared to lhs and result. -// CHECK-LABEL: real_addew_fixedpoint_rhs_shift -// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8> -// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8> -// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8> -// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8> -// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> -// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">> -!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> -!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">> -!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">> -func @real_addew_fixedpoint_rhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result { - %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result) - return %0 : !type_result -} - -// ----- -// Verify lowering when the lhs is a shifted pot scale compared to lhs and result. 
-// CHECK-LABEL: real_addew_fixedpoint_lhs_shift
-// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
-// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
-// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
-// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
-// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
-// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
-!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
-!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
-!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
-func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+// Verify lowering when operands and result have the same fixedpoint scale
+// and non-zero zero points.
+// CHECK-LABEL: real_addew_affine_isomorphic
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
+func @real_addew_affine_isomorphic(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  // CHECK-NEXT: %cst = constant splat<tensor<4xi16>, 5> : tensor<4xi16>
+  // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>) -> tensor<4xi8>
+  // CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>) -> tensor<4xi8>
+  // CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
+  // CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
+  // CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
+  // CHECK-NEXT: %5 = addi %4, %cst : tensor<4xi16>
+  // CHECK-NEXT: %6 = "fxpmath.clampis"(%5) {clamp_max: 127 : i16, clamp_min: -128 : i16} : (tensor<4xi16>) -> tensor<4xi16>
+  // CHECK-NEXT: %7 = "fxpmath.convertis"(%6) : (tensor<4xi16>) -> tensor<4xi8>
+  // CHECK-NEXT: %8 = "quant.scast"(%7) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>
+  // CHECK-NEXT: return %8 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02:-5}">>
   %0 = "fxpmath.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
   return %0 : !type_result
 }
@@ -54,16 +47,19 @@ func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !t
 // The RHS quant parameters prescribe a range of [-8..8) so an explicit clamp
 // of [-4..4] should result in an integral clamp range of [-64..64].
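 // Derivation: with scale 6.25e-2, the explicit clamp of +/-4.0 quantizes to
 // 4.0 / 6.25e-2 = 64, so the storage range [-128..127] narrows to [-64..64]
 // (applied below in the i16 intermediate type).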
 // CHECK-LABEL: real_addew_fixedpoint_clamp
-// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
-// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
-// CHECK-NEXT: %2 = "fxpmath.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
-// CHECK-NEXT: %3 = "fxpmath.saturating_addi"(%0, %2) {clamp_max: 64 : i32, clamp_min: -64 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
-// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
-// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
-!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
 !type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
 !type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
 func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  // CHECK-NEXT: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+  // CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+  // CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi16>
+  // CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi16>
+  // CHECK-NEXT: %4 = addi %2, %3 : tensor<4xi16>
+  // CHECK-NEXT: %5 = "fxpmath.clampis"(%4) {clamp_max: 64 : i16, clamp_min: -64 : i16} : (tensor<4xi16>) -> tensor<4xi16>
+  // CHECK-NEXT: %6 = "fxpmath.convertis"(%5) : (tensor<4xi16>) -> tensor<4xi8>
+  // CHECK-NEXT: %7 = "quant.scast"(%6) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+  // CHECK-NEXT: return %7 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
   %0 = "fxpmath.real_add_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 } : (!type_lhs, !type_rhs) -> (!type_result)
   return %0 : !type_result
 }
diff --git a/mlir/test/FxpMathOps/lower-uniform-real-math-mulew.mlir b/mlir/test/FxpMathOps/lower-uniform-real-math-mulew.mlir
new file mode 100644
index 0000000..8106612
--- /dev/null
+++ b/mlir/test/FxpMathOps/lower-uniform-real-math-mulew.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt %s -split-input-file -fxpmath-lower-uniform-real-math -canonicalize -verify | FileCheck %s --dump-input=always
+
+// -----
+// Verify lowering of a mul where operands and result have fixedpoint
+// (zero-point-free) scales; see the multiplier derivation below.
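+// Multiplier derivation (what the constants in the CHECK lines encode):
+//   multiplier = lhs_scale * rhs_scale / result_scale
+//              = 6.25e-2 * 3.875e-2 / 1.065e-1 ~= 0.0227408 ~= 0.72770 * 2^-5
+// so b = round(0.72770 * 2^31) = 1562722842 and the rounding divide uses
+// exponent 5.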
+// CHECK-LABEL: real_mulew_fixedpoint
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{3.875e-2}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{1.065e-1}">>
+func @real_mulew_fixedpoint(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  // CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+  // CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{3.875000e-02}">>) -> tensor<4xi8>
+  // CHECK-NEXT: %2 = "fxpmath.convertis"(%0) : (tensor<4xi8>) -> tensor<4xi32>
+  // CHECK-NEXT: %3 = "fxpmath.convertis"(%1) : (tensor<4xi8>) -> tensor<4xi32>
+  // CHECK-NEXT: %4 = muli %2, %3 : tensor<4xi32>
+  // CHECK-NEXT: %5 = "fxpmath.vs_saturating_rounding_doubling_high_mulis"(%4) {b: 1562722842 : i32} : (tensor<4xi32>) -> tensor<4xi32>
+  // CHECK-NEXT: %6 = "fxpmath.rounding_divide_by_potis"(%5) {exponent: 5 : i32} : (tensor<4xi32>) -> tensor<4xi32>
+  // CHECK-NEXT: %7 = "fxpmath.clampis"(%6) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi32>) -> tensor<4xi32>
+  // CHECK-NEXT: %8 = "fxpmath.convertis"(%7) : (tensor<4xi32>) -> tensor<4xi8>
+  // CHECK-NEXT: %9 = "quant.scast"(%8) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{1.065000e-01}">>
+  // CHECK-NEXT: return %9 : tensor<4x!quant<"uniform[i8:f32]{1.065000e-01}">>
+  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
+
+// -----
+// Verify lowering when operands and result have the same fixedpoint scale
+// and non-zero zero points.
+// CHECK-LABEL: real_mulew_affine_clamp
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-3}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-5}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2:-9}">>
+func @real_mulew_affine_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  // Just verify that the affine adds/constants and clamps are present.
+  // CHECK: %cst = constant splat<tensor<4xi32>, 3> : tensor<4xi32>
+  // CHECK: %cst_0 = constant splat<tensor<4xi32>, 5> : tensor<4xi32>
+  // CHECK: %cst_1 = constant splat<tensor<4xi32>, -9> : tensor<4xi32>
+  // CHECK: addi %2, %cst : tensor<4xi32>
+  // CHECK: addi %3, %cst_0 : tensor<4xi32>
+  // CHECK: muli %4, %5 : tensor<4xi32>
+  // CHECK: addi %8, %cst_1 : tensor<4xi32>
+  // CHECK: {clamp_max: 55 : i32, clamp_min: -73 : i32}
+  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 } : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: real_mulew_unquantized_lhs
+// Verifies that lowering leaves the op as-is for an unquantized lhs.
+!type_lhs = type tensor<4xf32>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_mulew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  // CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
+  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: real_mulew_unquantized_rhs
+// Verifies that lowering leaves the op as-is for an unquantized rhs.
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4xf32>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_mulew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  // CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
+  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: real_mulew_unquantized_result
+// Verifies that lowering leaves the op as-is for an unquantized result.
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4xf32>
+func @real_mulew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  // CHECK: %0 = "fxpmath.real_mul_ew"(%arg0, %arg1)
+  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
+
+// -----
+// Verify that lowering is rejected when the multiplier exceeds 1.0.
+// Note that the multiplier = lhs_scale * rhs_scale / result_scale
+//                          = 22.740610328638496
+// CHECK-LABEL: real_mulew_multiplier_gt_1
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{3.875e-2}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{1.065e-4}">>
+func @real_mulew_multiplier_gt_1(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  // expected-warning@+1 {{unimplemented: cannot multiply with multiplier > 1.0}}
+  %0 = "fxpmath.real_mul_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
diff --git a/mlir/test/Quantization/canonicalize.mlir b/mlir/test/Quantization/canonicalize.mlir
new file mode 100644
index 0000000..5cfd59a
--- /dev/null
+++ b/mlir/test/Quantization/canonicalize.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s --dump-input=fail
+
+// -----
+// CHECK-LABEL: redundant_scast
+func @redundant_scast() -> tensor<4xi8> {
+  // CHECK-NEXT: constant splat<tensor<4xi8>, 10>
+  // CHECK-NEXT: return
+  %cst = constant splat<tensor<4xi8>, 5> : tensor<4xi8>
+  %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
+  %2 = "quant.scast"(%1) : (tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) -> tensor<4xi8>
+  %3 = addi %2, %2 : tensor<4xi8>
+  return %3 : tensor<4xi8>
+}
+
+// -----
+// CHECK-LABEL: non_redundant_scast
+func @non_redundant_scast() -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">> {
+  // CHECK-NEXT: constant splat<tensor<4xi8>, 5>
+  // CHECK-NEXT: scast
+  // CHECK-NEXT: return
+  %cst = constant splat<tensor<4xi8>, 5> : tensor<4xi8>
+  %1 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
+  return %1 : tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
+}
-- 
2.7.4
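
To sanity-check the quantized-multiplier constants asserted in the mulew test
above, the following standalone C++ sketch (not part of the patch) mirrors the
QuantizedMultiplierSmallerThanOneExp logic from UniformKernelUtils.h:

#include <cassert>
#include <cmath>
#include <cstdint>

// Mirrors QuantizedMultiplierSmallerThanOneExp: decomposes a real multiplier
// in (0, 1) into a Q31 fixed-point multiplier and a power-of-two exponent.
struct MultiplierExp {
  explicit MultiplierExp(double realMultiplier) {
    assert(realMultiplier > 0.0 && realMultiplier < 1.0);
    // frexp: realMultiplier == q * 2^exponent with q in [0.5, 1).
    const double q = std::frexp(realMultiplier, &exponent);
    auto qFixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
    if (qFixed == (1ll << 31)) { // q rounded up to 1.0: fold into exponent.
      qFixed /= 2;
      ++exponent;
    }
    multiplier = static_cast<int32_t>(qFixed);
  }
  int32_t multiplier;
  int exponent;
};

int main() {
  // Scales from the real_mulew_fixedpoint test above.
  MultiplierExp m(6.25e-2 * 3.875e-2 / 1.065e-1);
  assert(m.multiplier == 1562722842); // the 'b' attribute in the CHECK lines
  assert(m.exponent == -5);           // negated to {exponent: 5 : i32}
  return 0;
}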