--- /dev/null
+# MLIR Quantization
+
+This document outlines the design of the MLIR quantization system. While the
+term "quantization" is highly overloaded, in this case, it refers to a fairly
+narrow scope of techniques in use to enable conversion of floating-point
+computations to corresponding and plausible variants expressed in integer math
+for inference, as has historically been supported by low-bit depth inference
+engines such as TFLite, various accelerator hardware, and many DSPs.
+
+Much of this is inspired by the approach taken
+[in this paper](https://arxiv.org/abs/1712.05877) with many extensions and
+adaptations folded in. It specifically documents the positions that MLIR has
+taken on the topic, and is not a general reference.
+
+[TOC]
+
+## Uniform quantization
+
+The primary quantization mechanism supported by MLIR is a scheme which can
+express fixed point and affine transformations via uniformly spaced point on the
+Real number line.
+
+Further, the scheme can be applied:
+
+* *per-layer* : Applying to every value within the target type.
+* *per-axis* (also called *per-channel*) : Applying individually to each index
+ along a specific axis of a tensor type.
+
+### Fixed point values {#fixed-point}
+
+[Fixed point](https://en.wikipedia.org/wiki/Fixed-point_arithmetic) values are a
+[Real](https://en.wikipedia.org/wiki/Real_number) number divided by a *scale*.
+We will call the result of the divided Real the *scaled value*.
+
+$$ real\_value = scaled\_value * scale $$
+
+The scale can be interpreted as the distance, in Real units, between neighboring
+scaled values. For example, if the scale is $$ \pi $$, then fixed point values
+with this scale can only represent multiples of $$ \pi $$, and nothing in
+between. The maximum rounding error to convert an arbitrary Real to a fixed
+point value with a given $$ scale $$ is $$ \frac{scale}{2} $$. Continuing the
+previous example, when $$ scale = \pi $$, the maximum rounding error will be $$
+\frac{\pi}{2} $$.
+
+Multiplication can be performed on scaled values with different scales, using
+the same algorithm as multiplication of Real values (note that product scaled
+value has $$ scale_{product} = scale_{left \mbox{ } operand} * scale_{right
+\mbox{ } operand} $$). Addition can be performed on scaled values, as long as
+they have the same scale, using the same algorithm as addition of Real values.
+This makes it convenient to represent scaled values on a computer as signed
+integers, and perform arithmetic on those signed integers, because the results
+will be correct scaled values.
+
+### Affine values {#affine}
+
+Mathematically speaking, affine values are the result of
+[adding a Real-valued *zero point*, to a scaled value](https://en.wikipedia.org/wiki/Affine_transformation#Representation).
+Or equivalently, subtracting a zero point from an affine value results in a
+scaled value:
+
+$$ real\_value = scaled\_value * scale = (affine\_value - zero\_point) * scale $$
+
+Essentially, affine values are a shifting of the scaled values by some constant
+amount. Arithmetic (i.e., addition, subtraction, multiplication, division)
+cannot, in general, be directly performed on affine values; you must first
+[convert](#affine-to-fixed-point) them to the equivalent scaled values.
+
+As alluded to above, the motivation for using affine values is to more
+efficiently represent the Real values that will actually be encountered during
+computation. Frequently, the Real values that will be encountered are not
+symmetric around the Real zero. We also make the assumption that the Real zero
+is encountered during computation, and should thus be represented.
+
+In this case, it's inefficient to store scaled values represented by signed
+integers, as some of the signed integers will never be used. The bit patterns
+corresponding to those signed integers are going to waste.
+
+In order to exactly represent the Real zero with an integral-valued affine
+value, the zero point must be an integer between the minimum and maximum affine
+value (inclusive). For example, given an affine value represented by an 8 bit
+unsigned integer, we have: $$ 0 \leq zero\_point \leq 255$$. This is important,
+because in deep neural networks's convolution-like operations, we frequently
+need to zero-pad inputs and outputs, so zero must be exactly representable, or
+the result will be biased.
+
+### Relation
+
+Real values, fixed point values, and affine values relate through the following
+equation, which demonstrates how to convert one type of number to another:
+
+$$ real\_value = scaled\_value * scale = (affine\_value - zero\_point) * scale $$
+
+Note that computers generally store mathematical values using a finite number of
+bits. Thus, while the above conversions are exact, to store the result in a
+finite number of bits, we must, in general, round the result of the conversion
+(this applies to both cases: storing using floating point and storing using
+fixed point). Note that a full discussion of rounding behavior is outside the
+scope of this document, and it is safe to assume unless otherwise stated that
+rounding should be according to the IEEE754 default of RNE (where hardware
+permits).
+
+### Converting between Real and fixed point or affine {#converting-between}
+
+To convert a Real value to a fixed point value, you must know the scale. To
+convert a Real value to an affine value, you must know the scale and zero point.
+
+#### Real to affine
+
+To convert an input tensor of Real-valued elements (usually represented by a
+floating point format, frequently
+[Single precision](https://en.wikipedia.org/wiki/Single-precision_floating-point_format))
+to a tensor of affine elements represented by an integral type (e.g. 8-bit
+unsigned integer), the following conversion can be performed (note that it is
+not required that all representable values of the integral type are used):
+
+$$
+\begin{align*}
+af&fine\_value_{uint8 \, or \, uint16} \\
+ &= clampToTargetSize(roundToNearestInteger( \frac{real\_value_{Single}}{scale_{Single}})_{sint32} + zero\_point_{uint8 \, or \, uint16})
+\end{align*}
+$$
+
+In the above, we assume that $$real\_value$$ is a Single, $$scale$$ is a Single,
+$$roundToNearestInteger$$ returns a signed 32 bit integer, and $$zero\_point$$
+is an unsigned 8 or 16 bit integer. Note that bit depth and number of fixed
+point values is indicative of common types on typical hardware but is not
+constrained to particular bit depths or a requirement that the entire range of
+an N-bit integer is used.
+
+#### Affine to Real {#affine-to-real}
+
+To convert an output tensor of affine elements represented by uint8
+or uint16 to a tensor of Real-valued elements (usually represented with a
+floating point format, frequently Single precision), the following conversion
+can be performed:
+
+$$
+\begin{align*}
+re&al\_value_{Single} \\
+ &= roundToNearestFloat((affine\_value_{uint8 \, or \, uint16} - zero\_point_{uint8 \, or \, uint16})_{sint32})_{Single} * scale_{Single}
+\end{align*}
+$$
+
+In the above, we assume that the result of subtraction is in 32-bit signed
+integer format, and that $$roundToNearestFloat$$ returns a Single.
+
+#### Affine to fixed point {#affine-to-fixed-point}
+
+When the affine and fixed point scales are the same, subtract the zero point
+from the affine value to get the equivalent fixed point value.
+
+$$
+scaled\_value = affine\_value_{non\mbox{-}negative} - zero\_point_{non\mbox{-}negative}
+$$
+
+#### Fixed point to affine {#fixed-point-to-affine}
+
+When the affine and fixed point scales are the same, add the zero point to the
+fixed point value to get the equivalent affine value.
+
+$$
+affine\_value_{non\mbox{-}negative} = scaled\_value + zero\_point_{non\mbox{-}negative}
+$$
+
+## Usage within MLIR {#usage-within-mlir}
+
+There are several components to the quantization system within MLIR:
+
+* *Quantization* dialect containing:
+
+ * A family of [QuantizedTypes](#quantized-type) which represent the
+ mapping between *expressed* values (typically of a floating point
+ computer type) and *storage* values (typically of an integral computer
+ type).
+ * [Type conversion ops](#quantized-type-conversion-ops) for converting
+ between types based on a QuantizedType and its *expressed* and *storage*
+ sub-types.
+ * [Instrumentation ops](#instrumentation-ops) for assigning
+ instrumentation points within the computation where runtime statistics
+ may help guide the quantization process.
+
+* *QuantizedMath* dialect containing:
+
+ * [Real math ops](#real-math-ops) representing common combinations of
+ arithmetic operations that closely match corresponding fixed-point math
+ concepts (as opposed to being spread across multiple ops as is typical
+ in source dialects).
+ * [Fixed-point math ops](#fixed-point-math-ops) that for carrying out
+ computations on integers, as are typically needed by uniform
+ quantization schemes.
+ * Passes to lower from real math ops to fixed-point math ops.
+
+* [Solver tools](#solver-tools) which can generically operate on computations
+ expressed in the *QuantizedMath* dialect in order to convert from floating
+ point types to appropriate *QuantizedTypes*, allowing the computation to be
+ further lowered to integral math ops.
+
+Not every application of quantization will use all facilities. Specifically, the
+TensorFlow to TensorFlow Lite conversion uses the QuantizedTypes but has its own
+ops for type conversion and expression of the backing math.
+
+## Interactions with simulated quantization at training time {#training-time}
+
+TensorFlow has historically used the
+[tf.quantization.fake_quant_\*](https://www.tensorflow.org/api_docs/python/tf/quantization/fake_quant_with_min_max_args)
+family of operations to simulate the effect of quantization at training time.
+
+As originally implemented, TensorFlow Lite was the primary user of such
+operations at inference time. When quantized inference was enabled, if every
+eligible tensor passed through an appropriate fake_quant node (the rules of
+which tensors can have fake_quant applied are somewhat involved), then
+TensorFlow Lite would use the attributes of the fake_quant ops to make a
+judgment about how to convert to use kernels from its quantized ops subset.
+
+In MLIR-based quantization, fake_quant_\* ops are handled by converting them to
+a sequence of *qcast* (quantize) followed by *dcast* (dequantize) with an
+appropriate *UniformQuantizedType* as the target of the qbarrier operation.
+
+This allows subsequent compiler passes to preserve the knowledge that
+quantization was simulated in a certain way while giving the compiler
+flexibility to move the barriers as it simplifies the computation and converts
+it to a form based on integral arithmetic.
+
+This scheme also naturally allows computations that are *partially quantized*
+where the parts which could not be reduced to integral ops are still carried out
+in floating point with appropriate conversions at the boundaries.
+
+## Quantization Dialect
+
+### Quantized type {#quantized-type}
+
+TODO : Flesh this section out.
+
+* QuantizedType base class
+* UniformQuantizedType
+
+### Quantized type conversion ops {#quantized-type-conversion-ops}
+
+* qcast : Convert from an expressed type to QuantizedType
+* dcast : Convert from a QuantizedType to its expressed type
+* scast : Convert between a QuantizedType and its storage type
+
+### Instrumentation and constraint ops {#instrumentation-ops}
+
+TODO : These ops are not defined yet
+
+* instrument_stats : Assigns a unique id and signals that statistics should be
+ collected by the runtime when execution passes through this op.
+* constrain_uniform : Constrains that for uniform quantization, the solver
+ should choose a type with certain characteristics such as the number of
+ fixed-point values, underlying storage type, or whether to constrain to
+ power of two scales.
+
+## QuantizedMath Dialect
+
+### Real math ops {#real-math-ops}
+
+Note that these all support explicit clamps, which allows for simple fusions and
+representation of some common sequences quantization-compatible math. Of
+addition, some support explicit biases, which are often represented as separate
+adds in source dialects.
+
+TODO: This op set is still evolving and needs to be completed.
+
+* RealBinaryOp
+ * RealAddEwOp
+ * RealSubEwOp
+ * RealMulEwOp
+ * RealDivEwOp
+* RealUnaryOp
+ * IDENTITY
+ * TANH
+ * SIGMOID
+ * EXP
+ * LOG
+ * NEG
+ * RSQRT
+ * SIN
+ * SQUARE
+ * SQRT
+ * CMPZ
+ * CMPNZ
+ * CMPLZ
+ * CMPGZ
+
+### Fixed-point math ops {#fixed-point-math-ops}
+
+TODO: This op set only has enough ops to lower a simple power-of-two
+RealAddEwOp.
+
+* RoundingDivideByPotFxpOp
+* SaturatingAddFxpOp
+
+## Solver tools {#solver-tools}
+
+Solver tools exist to analyze an MLIR-computation, expressed in either a
+supported source dialect or in the *real math ops* set and solve for appropriate
+QuantizedTypes that allow the computation to be lowered to integral math.
+
+These tools are an active area of work and may be expanded in the future to
+adjacent areas such as solving for transformations to other kinds of lower
+precision types (i.e. bfloat16 or fp16).
+
+Solver tools are expected to operate in several modes, depending on the
+computation and the manner in which it was trained:
+
+* *Transform* : With all available information in the MLIR computation, infer
+ boundaries where the computation can be carried out with integral math and
+ change types accordingly to appropriate QuantizedTypes:
+
+ * For passthrough ops which do not perform active math, change them to
+ operate directly on the storage type, converting in and out at the edges
+ via scast ops.
+ * For ops that have the *Quantizable* trait, the type can be set directly.
+ This includes ops from the [real math ops set]{#real-math-ops}.
+ * For others, encase them in appropriate dcast/qcast ops, presuming that
+ some follow-on pass will know what to do with them.
+
+* *Instrument* : Most of the time, there are not sufficient implied
+ constraints within a computation to perform many transformations. For this
+ reason, the solver can insert instrumentation ops at points where additional
+ runtime statistics may yield solutions. It is expected that such
+ computations will be lowered as-is for execution, run over an appropriate
+ eval set, and statistics at each instrumentation point made available for a
+ future invocation of the solver.
+
+* *Simplify* : A variety of passes and simplifications are applied once
+ QuantizedTypes are added in order to arrive at a computation that is
+ expressed in as much integral math, with the fewest number of casts as
+ possible.
--- /dev/null
+//===- FakeQuantSupport.h - Support utilities for FakeQuant ops -*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines support utilities for interoperating with FakeQuant* based
+// QAT (Quantized Aware Training) computations, as implemented by TFLite. Note
+// that FakeQuant* operators mix multiple concerns specific to how TFLite
+// originally implemented quantization. As such, utilities here enforce
+// opinions taken by that codebase (vs providing any amount of genericity).
+//
+// Specifically, it combines the following concerns, each of which would be
+// independent variables in a more generic setup:
+// - num_bits implies storage data type (quint8, int16)
+// - num_bits < 8 is promoted to quint8
+// - "narrow_range" narrows the lower bound of the storage type's range by
+// 1
+// - the specified min/max values are "nudged" so that the result has a zero
+// that can be exactly expressed
+// - min=max=0 implies scale=0 and zero_point=0
+//
+// With the above assumptions applied, every conforming specified FakeQuant op
+// can be represented by a UniformQuantizedType. This scheme is not expected to
+// be generalized further in the future and should be considered to be a
+// legacy set of rules.
+//
+// As canonically used in TensorFlow graphs, the presence of a FakeQuant node
+// is a hint that the specific math represented here has been simulated at
+// training time. As such, it is usually not advised to arbitrarily change
+// quantization parameters derived from FakeQuant.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZATION_FAKEQUANTSUPPORT_H_
+#define MLIR_QUANTIZATION_FAKEQUANTSUPPORT_H_
+
+#include "mlir/Quantization/QuantOps.h"
+
+namespace mlir {
+namespace quant {
+
+/// Converts per-layer FakeQuant attributes to the corresponding type.
+/// In the event that the parameters cannot be converted, returns a nullptr
+/// convertible Type and issues an appropriate error.
+/// Note that there are multiple variants of a per-layer FakeQuant op, so
+/// this function takes the attributes discretely vs taking a reference to the
+/// originating op.
+UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
+ double rmin, double rmax,
+ bool narrowRange, Type expressedType);
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_QUANTIZATION_FAKEQUANTSUPPORT_H_
--- /dev/null
+//===- Passes.h - Quantization Passes ------ --------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines all of the passes owned by the quantization dialect. As
+// things mature, it is expected that passes specific to certain frontend or
+// backend dialects will move to those dialects directly. For now, they are
+// incubated here.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZATION_PASSES_H
+#define MLIR_QUANTIZATION_PASSES_H
+
+namespace mlir {
+class FunctionPassBase;
+
+namespace quant {
+
+/// Creates a pass that lowers quantization related TensorFlow ops into
+/// the quantization dialect so that express and implied constraints expressed
+/// at the TensorFlow source level can be represented to the quantization
+/// system. This will specially handle any TensorFlow op that is useful for
+/// guiding quantization.
+///
+/// Note that if your intent is to compile a TensorFlow graph for floating
+/// point inference, you should probably not use this pass.
+FunctionPassBase *createLowerTFPass();
+
+/// Creates a pass that converts constants followed by a qbarrier to a
+/// constant whose value is quantized. This is typically one of the last
+/// passes done when lowering to express actual quantized arithmetic in a
+/// low level representation. Because it modifies the constant, it is
+/// destructive and cannot be undone.
+FunctionPassBase *createConvertConstPass();
+
+/// Creates a pass that lowers uniform-quantized real math ops to integer
+/// arithmetic. This will leave unrecognized real math ops as-is and is
+/// typically followed by a pass that lowers any unrecognized ops to a pure
+/// floating point form.
+FunctionPassBase *createLowerUniformRealMathPass();
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_QUANTIZATION_PASSES_H
--- /dev/null
+//===- Quantization/QuantOps.h - Quantization Ops and Types -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_QUANTIZATION_QUANTOPS_H_
+#define MLIR_QUANTIZATION_QUANTOPS_H_
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace quant {
+
+class QuantizedIntegerType;
+
+namespace detail {
+
+struct QuantizedTypeStorage;
+struct UniformQuantizedTypeStorage;
+struct UniformQuantizedPerAxisTypeStorage;
+
+} // namespace detail
+
+namespace QuantizationTypes {
+enum Kind {
+ UniformQuantized = Type::FIRST_QUANTIZATION_TYPE,
+ UniformQuantizedPerAxis,
+ LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
+};
+} // namespace QuantizationTypes
+
+/// Enumeration of bit-mapped flags related to quantized types.
+namespace QuantizationFlags {
+enum FlagValue {
+ // Indicates that the storage type should be interpreted as a signed
+ // integer. The default is to interpret it as an unsigned value.
+ Signed = 1,
+};
+} // namespace QuantizationFlags
+
+/// Base class for all quantized types known to this dialect.
+/// All quantized types have:
+/// - storageType: The (narrower) numeric type that is being used to
+/// approximate some expressed type.
+/// - expressedType: The type that is being approximated.
+///
+/// The base class provides generic support for manipulating the types based
+/// on these fields.
+class QuantizedType : public Type {
+public:
+ using ImplType = detail::QuantizedTypeStorage;
+ using Type::Type;
+
+ /// The maximum number of bits supported for storage types.
+ static constexpr unsigned MaxStorageBits = 32;
+
+ static LogicalResult
+ verifyConstructionInvariants(llvm::Optional<Location> loc,
+ MLIRContext *context, unsigned flags,
+ Type storageType, Type expressedType,
+ int64_t storageTypeMin, int64_t storageTypeMax);
+
+ /// Support method to enable LLVM-style type casting.
+ static bool kindof(unsigned kind) {
+ return kind == QuantizationTypes::UniformQuantized;
+ }
+
+ /// Gets the minimum possible stored by a storageType. storageTypeMin must
+ /// be greater than or equal to this value.
+ static int64_t getDefaultMininumForInteger(bool isSigned,
+ unsigned integralWidth) {
+ if (isSigned) {
+ return llvm::minIntN(integralWidth);
+ }
+ return 0;
+ }
+
+ /// Gets the maximum possible stored by a storageType. storageTypeMax must
+ /// be less than or equal to this value.
+ static int64_t getDefaultMaxinumForInteger(bool isSigned,
+ unsigned integralWidth) {
+ if (isSigned) {
+ return llvm::maxIntN(integralWidth);
+ }
+ return llvm::maxUIntN(integralWidth);
+ }
+
+ /// Gets the original expressed type that this quantized type approximates.
+ /// Note that this presumes that the quantized type was always derived from
+ /// a floating point type, which in the broadest definition, is not true (i.e.
+ /// it could be some form of integral, fixed type or affine type in its own
+ /// right); however, at the high level, no examples of such usage are
+ /// presently known and the restriction serves some useful purposes (such as
+ /// always being able to reverse a transformation or measure error). In most
+ /// cases, this will be f32.
+ Type getExpressedType() const;
+
+ /// Gets the flags associated with this type. Typically a more specific
+ /// accessor is appropriate.
+ unsigned getFlags() const;
+
+ // Convenience helpers.
+ /// Whether the storage type should be interpreted as a signed quantity
+ /// (true) or an unsigned value (false).
+ bool isSigned() const {
+ return (getFlags() & QuantizationFlags::Signed) ==
+ QuantizationFlags::Signed;
+ }
+
+ /// Gets the underlying type used for to store values. Note that this may
+ /// be signed or unsigned. Use the isSigned() accessor to differentiate.
+ Type getStorageType() const;
+
+ /// The minimum value that storageType can take.
+ int64_t getStorageTypeMin() const;
+
+ /// The maximum value that storageType can take.
+ int64_t getStorageTypeMax() const;
+
+ /// Gets the integral bit width that the underlying storage type can exactly
+ /// represent. For integral storage types, this will just be their width.
+ unsigned getStorageTypeIntegralWidth() const;
+
+ /// Returns whether the candidateExpressedType is a match for this
+ /// QuantizedType. This will be true if the candidate type is either a
+ /// primitive type or a container type whose element type equals this
+ /// QuantizedType's expressed type.
+ /// Examples of compatible candidateExpressedType:
+ /// !quant<"uniform[i8:f32]{1.0}"> =~ f32
+ /// !quant<"uniform[i8:f32]{1.0}"> =~ tensor<4xf32>
+ bool isCompatibleExpressedType(Type candidateExpressedType);
+
+ /// Returns the element type as a QuantizedType or nullptr if it is not
+ /// a quantized type. If the type is primitive, returns that. If it is a
+ /// container (vector/tensor), return the element type.
+ /// Examples:
+ /// !quant<"uniform[i8:f32]{1.0}"> -> !quant<"uniform[i8:f32]{1.0}">
+ /// tensor<4x!quant<"uniform[i8:f32]{1.0}"> -> quant<"uniform[i8:f32]{1.0}">
+ static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
+
+ /// Casts from a type based on the storageType to a corresponding type based
+ /// on this type (returns nullptr if the cast is not valid).
+ /// Examples:
+ /// i8 -> !quant<"uniform[i8:f32]{1.0}">
+ /// tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+ /// vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
+ Type castFromStorageType(Type candidateType);
+
+ /// Casts from a type based on a QuantizedType to a corresponding type based
+ /// on the storageType (returns nullptr if the cast is not valid).
+ /// This is the inverse of castFromStorageType().
+ static Type castToStorageType(Type quantizedType);
+
+ /// Casts from a type based on the expressedType to a corresponding type based
+ /// on this type (returns nullptr if the cast is not valid).
+ /// Examples:
+ /// f32 -> !quant<"uniform[i8:f32]{1.0}">
+ /// tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+ /// vector<4xf32> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
+ Type castFromExpressedType(Type candidateType);
+
+ /// Casts from a type based on QuantizedType to a corresponding type based
+ /// on the expressedType (returns nullptr if the cast is not valid).
+ /// This is the inverse of castFromExpressedType.
+ static Type castToExpressedType(Type quantizedType);
+
+ /// Casts from a type based on the expressedType to the equivalent type
+ /// based on storageType by way of this QuantizedType. Equivalent to:
+ /// QuantizedType::castToStorageType(castFromExpressedType(candidateType))
+ /// (but with validity checks).
+ /// Example (for this = !quant<"uniform[i8:f32]{1.0}">):
+ /// tensor<4xf32> -> tensor<4xi8>
+ Type castExpressedToStorageType(Type candidateType);
+};
+
+/// Represents a family of uniform, quantized types.
+///
+/// Each instance of this type expresses a mapping between real values (most
+/// often expressed in floating point f32) and quantized values (either fixed
+/// point or affine).
+///
+/// The relationship is:
+/// real_value = scale * (quantized_value - zero_point)
+///
+/// It is used as part of high level graph transformations that have the goal
+/// of re-expressing parts of a computation in terms of this common form for
+/// more efficient execution at runtime. In addition, it is designed to be
+/// expressive enough to facilitate lowering to precise types and operations
+/// in target hardware.
+///
+/// As a high-level type, focused on intermediate passes, this type holds
+/// opinions consistent with high-level usage. If lowering math kernels below
+/// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
+/// instruction sets), it is expected that the information expressed here
+/// will be used to drive low level codegen and target specific type selection,
+/// but this type will likely be erased in the process.
+///
+/// Syntax synopsis:
+/// Per-layer, all parameters expressed:
+/// !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
+/// Per-layer, optional parameters omitted:
+/// !quant<uniform[StorageType]{Scale}>
+///
+/// StorageType: 'i'|'u' NumBits
+/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+/// Scale: A legal double value
+/// ZeroPoint: An integer value
+class UniformQuantizedType
+ : public Type::TypeBase<UniformQuantizedType, QuantizedType,
+ detail::UniformQuantizedTypeStorage> {
+public:
+ using Base::Base;
+
+ /// Gets an instance of the type with all parameters specified but not
+ /// checked.
+ static UniformQuantizedType get(unsigned flags, Type storageType,
+ Type expressedType, double scale,
+ int64_t zeroPoint, int64_t storageTypeMin,
+ int64_t storageTypeMax);
+
+ /// Gets an instance of the type with all specified parameters checked.
+ /// Returns a nullptr convertible type on failure.
+ static UniformQuantizedType
+ getChecked(unsigned flags, Type storageType, Type expressedType, double scale,
+ int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax,
+ Location location);
+
+ /// Verifies construction invariants and issues errors/warnings.
+ static LogicalResult verifyConstructionInvariants(
+ llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+ Type storageType, Type expressedType, double scale, int64_t zeroPoint,
+ int64_t storageTypeMin, int64_t storageTypeMax);
+
+ /// Support method to enable LLVM-style type casting.
+ static bool kindof(unsigned kind) {
+ return kind == QuantizationTypes::UniformQuantized;
+ }
+
+ /// Gets the scale term. The scale designates the difference between the real
+ /// values corresponding to consecutive quantized values differing by 1.
+ double getScale() const;
+
+ /// Gets the storage value corresponding to the real value 0 in the affine
+ /// equation.
+ int64_t getZeroPoint() const;
+
+ // Fixed point values are real numbers divided by a scale.
+ // Currently, only signed storage types are treated as fixed point.
+ // A fixed point value can be obtained from an affine value by subtracting
+ // the zeroPoint.
+ // In the future, this may be explicit versus implied by type and zeroPoint.
+ bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
+};
+
+/// Represents per-axis (also known as per-channel quantization).
+///
+/// Syntax synopsis:
+/// Per-axis, all parameters expressed:
+/// !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
+/// Per-axis, optional parameters omitted:
+/// !quant<uniform[StorageType]{Scale}>
+///
+/// StorageType: 'i'|'u' NumBits
+/// ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+/// QuantizedDim: An integer value
+/// QuantParams: (Scale ':' ZeroPoint)+
+/// Scale: A legal double value
+/// ZeroPoint: An integer value
+class UniformQuantizedPerAxisType
+ : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
+ detail::UniformQuantizedPerAxisTypeStorage> {
+public:
+ using Base::Base;
+
+ /// Gets an instance of the type with all parameters specified but not
+ /// checked.
+ static UniformQuantizedPerAxisType
+ get(unsigned flags, Type storageType, Type expressedType,
+ ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+ int32_t quantizedDimension, int64_t storageTypeMin,
+ int64_t storageTypeMax);
+
+ /// Gets an instance of the type with all specified parameters checked.
+ /// Returns a nullptr convertible type on failure.
+ static UniformQuantizedPerAxisType
+ getChecked(unsigned flags, Type storageType, Type expressedType,
+ ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+ int32_t quantizedDimension, int64_t storageTypeMin,
+ int64_t storageTypeMax, Location location);
+
+ /// Verifies construction invariants and issues errors/warnings.
+ static LogicalResult verifyConstructionInvariants(
+ llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+ Type storageType, Type expressedType, ArrayRef<double> scales,
+ ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
+ int64_t storageTypeMin, int64_t storageTypeMax);
+
+ /// Support method to enable LLVM-style type casting.
+ static bool kindof(unsigned kind) {
+ return kind == QuantizationTypes::UniformQuantizedPerAxis;
+ }
+
+ /// Gets the quantization scales. The scales designate the difference between
+ /// the real values corresponding to consecutive quantized values differing
+ /// by 1. The ith scale corresponds to the ith slice in the
+ /// quantized_dimension.
+ ArrayRef<double> getScales() const;
+
+ /// Gets the storage values corresponding to the real value 0 in the affine
+ /// equation. The ith zero point corresponds to the ith slice in the
+ /// quantized_dimension.
+ ArrayRef<int64_t> getZeroPoints() const;
+
+ /// Specifies the dimension of the Tensor's shape that the scales and
+ /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
+ /// with quantization params:
+ /// scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
+ /// will be quantized across the second dimension of t.
+ /// t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
+ /// t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2
+ /// t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3
+ int32_t getQuantizedDimension() const;
+
+ /// Fixed point values are real numbers divided by a scale.
+ /// Currently, only signed storage types are treated as fixed point.
+ /// A fixed point value can be obtained from an affine value by subtracting
+ /// the zeroPoint.
+ /// In the future, this may be explicit versus implied by type and zeroPoint.
+ bool isFixedPoint() const {
+ if (!isSigned())
+ return false;
+ return llvm::all_of(getZeroPoints(),
+ [](int64_t zeroPoint) { return zeroPoint != 0; });
+ }
+};
+
+/// Defines the 'Quantization' dialect
+class QuantizationDialect : public Dialect {
+public:
+ QuantizationDialect(MLIRContext *context);
+
+ /// Parse a type registered to this dialect.
+ Type parseType(StringRef spec, Location loc) const override;
+
+ /// Print a type registered to this dialect.
+ void printType(Type type, raw_ostream &os) const override;
+};
+
+#define GET_OP_CLASSES
+#include "mlir/Quantization/QuantOps.h.inc"
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_QUANTIZATION_QUANTOPS_H_
--- /dev/null
+//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the operation definition file for Quantization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef QUANTIZATION_OPS
+#else
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+//===----------------------------------------------------------------------===//
+// Quantization type definitions
+//===----------------------------------------------------------------------===//
+
+class quant_TypedPrimitiveOrContainer<Type etype> :
+ Type<AnyOf<[etype.predicate,
+ TypedTensor<etype>.predicate,
+ TypedVector<etype>.predicate]>,
+ "primitive/tensor/vector of " # etype.description>;
+
+// An implementation of QuantizedType.
+def quant_QuantizedType :
+ Type<CPred<"{0}.isa<QuantizedType>()">, "QuantizedType">;
+
+// A primitive type that can represent a real value. This is either a
+// floating point value or a quantized type.
+def quant_RealPrimitiveType :
+ Type<AnyOf<[Float.predicate, quant_QuantizedType.predicate]>,
+ "real valued primitive (float or quantized type)">;
+
+// A primitive type that can represent a storage value. This is either an
+// integer or quantized type.
+def quant_StoragePrimitiveType :
+ Type<AnyOf<[Integer.predicate, quant_QuantizedType.predicate]>,
+ "quantized storage primitive (integer or quantized type)">;
+
+// A primitive or container of RealPrimitiveType.
+def quant_RealValueType :
+ quant_TypedPrimitiveOrContainer<quant_RealPrimitiveType>;
+
+// A primitive or container of StoragePrimitiveType.
+def quant_StorageValueType :
+ quant_TypedPrimitiveOrContainer<quant_StoragePrimitiveType>;
+
+// Either a real valued or storage primitive or container type.
+def quant_RealOrStorageValueType :
+ Type<AnyOf<[quant_RealValueType.predicate,
+ quant_StorageValueType.predicate]>>;
+
+// An implementation of UniformQuantizedType.
+def quant_UniformQuantizedType :
+ Type<CPred<"{0}.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
+
+// Predicate for detecting a container or primitive of UniformQuantizedType.
+def quant_UniformQuantizedValueType :
+ quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
+
+//===----------------------------------------------------------------------===//
+// Attributes
+//===----------------------------------------------------------------------===//
+
+// Real value for an (inclusive) min/max clamp limit.
+def quant_ClampValueAttr : OptionalAttr<F64Attr>;
+
+// Element-wise activation function to apply.
+// Note that RELU activations are not here: they are expressed as clamps.
+def quant_EwUnaryFnAttr :
+ StringBasedAttr<CPred<"true">, "element-wise unary function"> {
+ let returnType = [{ StringRef }];
+ let defaultValue = "IDENTITY";
+}
+
+class quant_ConstEwUnaryFn<string val> : ConstantAttr<quant_EwUnaryFnAttr, val>;
+def quant_EwUnaryFn_Identity: quant_ConstEwUnaryFn<"IDENTITY">;
+def quant_EwUnaryFn_Tanh : quant_ConstEwUnaryFn<"TANH">;
+def quant_EwUnaryFn_Sigmoid : quant_ConstEwUnaryFn<"SIGMOID">;
+def quant_EwUnaryFn_Exp : quant_ConstEwUnaryFn<"EXP">;
+def quant_EwUnaryFn_Log : quant_ConstEwUnaryFn<"LOG">;
+def quant_EwUnaryFn_Neg : quant_ConstEwUnaryFn<"NEG">;
+def quant_EwUnaryFn_Rsqrt : quant_ConstEwUnaryFn<"RSQRT">;
+def quant_EwUnaryFn_Sin : quant_ConstEwUnaryFn<"SIN">;
+def quant_EwUnaryFn_Square : quant_ConstEwUnaryFn<"SQUARE">;
+def quant_EwUnaryFn_Sqrt : quant_ConstEwUnaryFn<"SQRT">;
+def quant_EwUnaryFn_CmpZ : quant_ConstEwUnaryFn<"CMPZ">;
+def quant_EwUnaryFn_CmpNZ : quant_ConstEwUnaryFn<"CMPNZ">;
+def quant_EwUnaryFn_CmpLZ : quant_ConstEwUnaryFn<"CMPLZ">;
+def quant_EwUnaryFn_CmpGZ : quant_ConstEwUnaryFn<"CMPGZ">;
+
+//===----------------------------------------------------------------------===//
+// Base classes
+//===----------------------------------------------------------------------===//
+
+class quant_Op<string mnemonic, list<OpTrait> traits> :
+ Op<!strconcat("quant.", mnemonic), traits>;
+
+//===----------------------------------------------------------------------===//
+// Quantization barriers
+//===----------------------------------------------------------------------===//
+class quant_BarrierOp<string mnemonic, list<OpTrait> traits> :
+ quant_Op<mnemonic, traits>, Arguments<(ins quant_RealValueType:$arg)>,
+ Results<(outs quant_RealValueType)>;
+
+// A QuantizeBarrier (qbarrier) represents a potential type shift from a
+// quantizable type to a quantized type.
+//
+// At runtime, a qbarrier will apply the transformation expressed by its
+// operand and result type. For flexibility during transformation, it is also
+// possible to have a qbarrier that performs no transformation (both its
+// operand and result type are quantizable).
+//
+// A qbarrier will typically originate from either:
+// a) An expressed or implied constraint in the source dialect which signals
+// that a certain level of quantization is possible or required.
+// b) An inference made by a quantization algorithm indicating that a
+// quantized representation may be acceptable.
+//
+// Especially early in transformation, it is common to have pairs of
+// qbarrier/dbarrier at points where a transition to a quantized type is
+// required. In addition, it is also common to have an identity qbarrier
+// (where the operand and result type are not quantized) at all points where
+// it is legal to use a quantized representation (but is not known to be
+// acceptable).
+def quant_QuantizeBarrierOp : quant_BarrierOp<"qbarrier", [NoSideEffect]>;
+
+// A DequantizeBarrier (dbarrier) represents the inverse of a qbarrier,
+// converting back from a quantized to quantizable (expressed) type.
+//
+// Like qbarriers, a dbarrier is allowed to have both its operand and result
+// as non quantized types. This facilitates transformations and marks edges
+// where the computation must be carried out in the expressed type.
+//
+// Especially early in transformation, it is common to have dbarriers on
+// all operands to ops that must operate with the expressed type (typically
+// math ops prior to lowering to target-specific, quantized kernels).
+def quant_DequantizeBarrierOp : quant_BarrierOp<"dbarrier", [NoSideEffect]>;
+
+// A StorageCast (scast) represents a cast from or to a type based on the
+// storage type and a type based on a corresponding quantized type.
+//
+// This op exists to ensure type coherency for between parts of the computation
+// which are operating directly on an underlying storage type and those which
+// operate on quantized values.
+//
+// Examples from storage to quantized type:
+// i8 -> !quant<"uniform[i8:f32]{1.0}">
+// tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+// vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
+def quant_StorageCastOp :
+ quant_Op<"scast", [NoSideEffect]>,
+ Arguments<(ins quant_RealOrStorageValueType:$arg)>,
+ Results<(outs quant_RealOrStorageValueType)>;
+
+//===----------------------------------------------------------------------===//
+// Integral arithmetic ops used by kernels.
+//===----------------------------------------------------------------------===//
+
+def quant_RoundingDivideByPotIOp :
+ quant_Op<"rounding_divide_by_poti", [NoSideEffect, SameValueType]>,
+ Arguments<(ins quant_StorageValueType:$x, I32Attr:$exponent)>,
+ Results<(outs quant_StorageValueType:$y)> {
+ let description = [{
+ Computes integer division by a power-of-two, correctly rounded-to-nearest.
+ Also known as a rounding arithmetic right shift. See
+ gemmlowp::RoundingDivideByPOT for a reference implementation.
+ }];
+
+ let verifier = [{
+ auto verifyExponent = exponent().getSExtValue();
+ if (verifyExponent < 0 || verifyExponent > 31) {
+ return emitOpError("exponent must be in range [0..31]");
+ }
+ return success();
+ }];
+}
+
+def quant_SaturatingAddIOp :
+ quant_Op<"saturating_addi", [NoSideEffect, SameValueType]>,
+ Arguments<(ins quant_StorageValueType:$x,
+ quant_StorageValueType:$y,
+ I32Attr:$clamp_min,
+ I32Attr:$clamp_max)>,
+ Results<(outs quant_StorageValueType:$sum)> {
+ let description = [{
+ Computes saturating addition of two operands, saturating to the given min
+ and max value. The implementation is responsible for choosing an
+ intermediate register size appropriate to carry out the operation without
+ overflow. See gemmlowp::SaturatingAdd for a reference implementation.
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Real math ops.
+//
+// Math ops on real numbers which may have a representation in quantized
+// arithmetic. It is expected that eligible ops are lowered from a source
+// dialect to this set of ops prior to the process of converting a compuation
+// to a quantized form. It is a non-goal of these ops to preserve enough
+// information to convert back to the higher level, source dialect.
+//
+// These ops support either real/floating point or QuantizedTypes as operands
+// and results. Since not all transformations are supported (globally or
+// sometimes for specific targets), a computation may end up with
+// untransformable RealMathOps, in which case they need to be lowered as is
+// (using floating point math).
+//
+// This op set takes advantage of the fact that it is typically trivial to
+// combine a math function with a compatible bias addition and real-valued
+// clamp (which can be done at a higher accumulation bit depth).
+//
+// In addition, all element-wise unary functions are collapsed into a single
+// quant_RealUnaryEwOp and selected via an enum-like attribute. Especially at
+// low bit depths, this makes matching simpler and allows the construction of
+// generic LUT-based implementations. It also allows specific lowering rules
+// to consolidate runs of chained unary ops and fuse them to preceding math
+// ops, potentially allowing them to operate directly on higher precision
+// intermediates without resorting to lots of custom kernels for common
+// formulas that can suffer from insufficient precision at low bit depths.
+//
+// Comparison operators are modeled as element-wise unary functions (i.e.
+// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1bit
+// quantized value. It is expected that lowering rules can fuse them with
+// the preceding sub.
+//===----------------------------------------------------------------------===//
+
+class quant_RealMathOp<string mnemonic, list<OpTrait> traits = [], dag args> :
+ quant_Op<mnemonic, traits>,
+ Arguments<!con(args, (ins
+ quant_ClampValueAttr:$clamp_min, quant_ClampValueAttr:$clamp_max))>;
+
+//===----------------------------------------------------------------------===//
+// Element wise binary real math ops.
+//===----------------------------------------------------------------------===//
+
+class quant_RealBinaryOp<string mnemonic, list<OpTrait> traits = []> :
+ quant_RealMathOp<mnemonic, traits,
+ (ins quant_RealValueType:$x, quant_RealValueType:$y)>,
+ Results<(outs quant_RealValueType:$r)>;
+
+class quant_RealBinaryBiasOp<string mnemonic, list<OpTrait> traits = []> :
+ quant_RealMathOp<mnemonic, traits,
+ (ins quant_RealValueType:$x, quant_RealValueType:$y,
+ quant_RealValueType:$bias)>,
+ Results<(outs quant_RealValueType:$r)>;
+
+def quant_RealAddEwOp :
+ quant_RealBinaryOp<"real_add_ew", [NoSideEffect]>;
+
+def quant_RealSubEwOp :
+ quant_RealBinaryOp<"real_sub_ew", [NoSideEffect]>;
+
+def quant_RealMulEwOp :
+ quant_RealBinaryOp<"real_mul_ew", [NoSideEffect]>;
+
+def quant_RealDivEwOp :
+ quant_RealBinaryOp<"real_div_ew", [NoSideEffect]>;
+
+//===----------------------------------------------------------------------===//
+// Element wise unary real math op.
+//===----------------------------------------------------------------------===//
+
+def quant_RealUnaryEwOp :
+ quant_RealMathOp<"real_unary_ew", [NoSideEffect],
+ (ins quant_RealValueType:$x, quant_EwUnaryFnAttr:$fn)>,
+ Results<(outs quant_RealValueType:$r)>;
+
+#endif // QUANTIZATION_OPS
--- /dev/null
+//===- QuantizeUtils.h - Support utilities for quantization -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_QUANTIZATION_QUANTIZEUTILS_H_
+#define MLIR_QUANTIZATION_QUANTIZEUTILS_H_
+
+namespace mlir {
+class Attribute;
+class Type;
+
+namespace quant {
+class QuantizedType;
+class UniformQuantizedType;
+class UniformQuantizedValueConverter;
+
+/// Converts an attribute from a type based on
+/// quantizedElementType.getExpressedType() to one based on
+/// quantizedElementType.getStorageType(), where quantizedElementType is as from
+/// QuantizedType::getQuantizedElementType().
+/// Returns nullptr if the conversion is not supported. On success, stores the
+/// converted type in outConvertedType.
+///
+/// Examples:
+/// 1. realValue is a primitive value attribute:
+/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32])
+/// -> (IntegerAttr, outConvertedType: i8)
+/// 2. realValue is an elements attribute:
+/// (realValue: DenseElementsAttr[tensor<2x2xf32>],
+/// quantizedElementType: UniformQuantizedType[i8:f32])
+/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>)
+Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
+ Type &outConvertedType);
+
+/// Converts an attribute from a type based on
+/// quantizedElementType.getExpressedType() to one based on
+/// quantizedElementType.getStorageType(), where quantizedElementType is as from
+/// QuantizedType::getQuantizedElementType() and casted to an
+/// UniformQuantizedType. Returns nullptr if the conversion is not supported. On
+/// success, stores the converted type in outConvertedType.
+///
+/// Examples:
+/// 1. realValue is a primitive value attribute:
+/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32])
+/// -> (IntegerAttr, outConvertedType: i8)
+/// 2. realValue is an elements attribute:
+/// (realValue: DenseElementsAttr[tensor<2x2xf32>],
+/// quantizedElementType: UniformQuantizedType[i8:f32])
+/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>)
+Attribute quantizeAttrUniform(Attribute realValue,
+ UniformQuantizedType quantizedElementType,
+ const UniformQuantizedValueConverter &converter,
+ Type &outConvertedType);
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_QUANTIZATION_QUANTIZEUTILS_H_
--- /dev/null
+//===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_QUANTIZATION_UNIFORMSUPPORT_H_
+#define MLIR_QUANTIZATION_UNIFORMSUPPORT_H_
+
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Quantization/QuantOps.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/APSInt.h"
+
+namespace mlir {
+namespace quant {
+
+/// Performs type conversion from an arbitrary input type to a type
+/// that is expressed by a UniformQuantizedType.
+///
+/// This handles cases where the inputType is a supported primitive type
+/// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported
+/// elemental type.
+///
+/// Since conversion often involves introspecting some attributes of the
+/// input type in order to determine how to represent it, this is a two step
+/// process.
+struct ExpressedToUniformQuantizedConverter {
+ /// Creates a converter for the given input type.
+ static const ExpressedToUniformQuantizedConverter
+ forInputType(Type inputType);
+
+ /// Converts the inputType to be based on the given elemental type,
+ /// returning the new type (or nullptr and emit an error on failure).
+ Type convert(UniformQuantizedType elementalType) const;
+
+ /// Whether the conversion is legal.
+ explicit operator bool() const { return (bool)expressedType; }
+
+ /// The input type that is being converted from.
+ /// This may be an elemental or composite type.
+ const Type inputType;
+
+ /// Supported, elemental expressed type (i.e. f32).
+ /// Will be nullptr if conversion is not supported.
+ const Type expressedType;
+};
+
+/// Reference implementation of converting between real numbers and values
+/// represented by a UniformQuantizedType.
+/// Note that this is not expected to be speedy and may be superceded eventually
+/// by a more optimal implementation.
+/// Also, the interface assumes that quantization is done per-layer and will
+/// need to be wider for various per-channel schemes. As such, this is a
+/// placeholder.
+class UniformQuantizedValueConverter {
+public:
+ UniformQuantizedValueConverter(UniformQuantizedType uniformType)
+ : scale(uniformType.getScale()),
+ zeroPoint(static_cast<double>(uniformType.getZeroPoint())),
+ clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
+ clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
+ storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
+ isSigned(uniformType.isSigned()) {
+ assert(uniformType.getExpressedType().isa<FloatType>());
+ assert(uniformType.getStorageType().isa<IntegerType>());
+ }
+
+ virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
+ bool lossy;
+ expressedValue.convert(scale.getSemantics(), APFloat::rmNearestTiesToEven,
+ &lossy);
+ // fixedpoint = clamp(clampMin, clampMax, (
+ // roundHalfToEven(expressed / scale) + zeroPoint))
+ APFloat scaled = (expressedValue / scale);
+ scaled.roundToIntegral(APFloat::rmNearestTiesToEven);
+ scaled.add(zeroPoint, APFloat::rmNearestTiesToEven);
+ APFloat fixedpoint = llvm::minimum(scaled, clampMax);
+ fixedpoint = llvm::maximum(fixedpoint, clampMin);
+
+ llvm::APSInt result(storageBitWidth, !isSigned);
+ fixedpoint.convertToInteger(result, APFloat::rmNearestTiesToEven, &lossy);
+
+ return result;
+ }
+
+ int64_t quantizeFloatToInt64(APFloat expressedValue) const {
+ APInt qValue = quantizeFloatToInt(expressedValue);
+ return isSigned ? qValue.getSExtValue() : qValue.getZExtValue();
+ }
+
+ virtual ~UniformQuantizedValueConverter() {}
+
+private:
+ const APFloat scale;
+ const APFloat zeroPoint;
+ const APFloat clampMin;
+ const APFloat clampMax;
+ const uint32_t storageBitWidth;
+ const bool isSigned;
+};
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_QUANTIZATION_UNIFORMSUPPORT_H_
--- /dev/null
+//===- DialectRegistration.cpp - Register Quantization dialect ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantization/QuantOps.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+// Static initialization for Quantization dialect registration.
+static mlir::DialectRegistration<QuantizationDialect> QuantizationOps;
--- /dev/null
+#include "mlir/Quantization/FakeQuantSupport.h"
+#include "mlir/Quantization/QuantOps.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+UniformQuantizedType mlir::quant::fakeQuantAttrsToType(Location loc,
+ unsigned numBits,
+ double rmin, double rmax,
+ bool narrowRange,
+ Type expressedType) {
+ MLIRContext *ctx = expressedType.getContext();
+ Type storageType;
+ unsigned flags;
+ int64_t qmin;
+ int64_t qmax;
+
+ // Hard-coded type mapping from TFLite.
+ if (numBits <= 8) {
+ storageType = IntegerType::get(8, ctx);
+ flags = 0;
+ qmin = 0;
+ qmax = 255;
+ } else if (numBits <= 16) {
+ storageType = IntegerType::get(16, ctx);
+ flags = QuantizationFlags::Signed;
+ qmin = -32768;
+ qmax = 32767;
+ } else {
+ ctx->emitError(loc,
+ "unsupported FakeQuant number of bits: " + Twine(numBits));
+ return nullptr;
+ }
+
+ // Handle narrowRange.
+ if (narrowRange) {
+ qmin += 1;
+ }
+
+ // Range must straddle zero.
+ if (rmin > 0.0 || rmax < 0.0) {
+ return (ctx->emitError(loc, "FakeQuant range must straddle zero: [" +
+ Twine(std::to_string(rmin)) + "," +
+ Twine(std::to_string(rmax)) + "]"),
+ nullptr);
+ }
+
+ // Special case where min/max is a point. Must be 0.
+ if (rmin == rmax) {
+ return UniformQuantizedType::getChecked(flags, storageType, expressedType,
+ 0.0, 0, qmin, qmax, loc);
+ }
+
+ // Determine the scale.
+ const double qminDouble = qmin;
+ const double qmaxDouble = qmax;
+ const double scale = (rmax - rmin) / (qmaxDouble - qminDouble);
+
+ // Zero point computation.
+ // In float, solve the affine equation for any known pair
+ // (real value, corresponding quantized value), of which, two such pairs
+ // are known: (rmin, qmin), (rmax, qmax).
+ // The arithmetic error on the zero point computed from either pair will be
+ // roughly machine_epsilon * (sum of absolute values of terms).
+ // Use the variant that adds the smaller error.
+ const double zeroPointFromMin = qminDouble - rmin / scale;
+ const double zeroPointFromMinError =
+ std::abs(qminDouble) + std::abs(rmin / scale);
+ const double zeroPointFromMax = qmaxDouble - rmax / scale;
+ const double zeroPointFromMaxError =
+ std::abs(qmaxDouble) + std::abs(rmax / scale);
+
+ const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
+ ? zeroPointFromMin
+ : zeroPointFromMax;
+
+ // Now nudge the zero point to be an integer.
+ int64_t nudgedZeroPoint = 0;
+ if (zeroPointDouble < qminDouble) {
+ nudgedZeroPoint = qmin;
+ } else if (zeroPointDouble > qmaxDouble) {
+ nudgedZeroPoint = qmax;
+ } else {
+ nudgedZeroPoint = round(zeroPointDouble);
+ }
+
+ // By construction, the nudged zero point should always be in range.
+ assert(nudgedZeroPoint >= qmin);
+ assert(nudgedZeroPoint <= qmax);
+
+ return UniformQuantizedType::getChecked(flags, storageType, expressedType,
+ scale, nudgedZeroPoint, qmin, qmax,
+ loc);
+}
--- /dev/null
+//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantization/QuantOps.h"
+#include "TypeDetail.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+using namespace mlir::quant::detail;
+
+unsigned QuantizedType::getFlags() const {
+ return static_cast<ImplType *>(type)->flags;
+}
+
+LogicalResult QuantizedType::verifyConstructionInvariants(
+ llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+ Type storageType, Type expressedType, int64_t storageTypeMin,
+ int64_t storageTypeMax) {
+ // Verify that the expressed type is floating point.
+ // If this restriction is ever eliminated, the parser/printer must be
+ // extended.
+ if (!expressedType.isa<FloatType>()) {
+ if (loc) {
+ context->emitError(*loc, "expressed type must be floating point");
+ }
+ return failure();
+ }
+
+ // Verify that the storage type is integral.
+ // This restriction may be lifted at some point in favor of using bf16
+ // or f16 as exact representations on hardware where that is advantageous.
+ auto intStorageType = storageType.dyn_cast<IntegerType>();
+ if (!intStorageType) {
+ if (loc) {
+ context->emitError(*loc, "storage type must be integral");
+ }
+ return failure();
+ }
+ unsigned integralWidth = intStorageType.getWidth();
+
+ // Verify storage width.
+ if (integralWidth == 0 || integralWidth > MaxStorageBits) {
+ if (loc) {
+ context->emitError(*loc,
+ "illegal storage type size: " + Twine(integralWidth));
+ }
+ return failure();
+ }
+
+ // Verify storageTypeMin and storageTypeMax.
+ bool isSigned =
+ (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
+ int64_t defaultIntegerMin =
+ getDefaultMininumForInteger(isSigned, integralWidth);
+ int64_t defaultIntegerMax =
+ getDefaultMaxinumForInteger(isSigned, integralWidth);
+ if (storageTypeMax - storageTypeMin <= 0 ||
+ storageTypeMin < defaultIntegerMin ||
+ storageTypeMax > defaultIntegerMax) {
+ if (loc) {
+ context->emitError(*loc, "illegal storage min and storage max: (" +
+ Twine(storageTypeMin) + ":" +
+ Twine(storageTypeMax) + ")");
+ }
+ return failure();
+ }
+ return success();
+}
+
+Type QuantizedType::getStorageType() const {
+ return static_cast<ImplType *>(type)->storageType;
+}
+
+int64_t QuantizedType::getStorageTypeMin() const {
+ return static_cast<ImplType *>(type)->storageTypeMin;
+}
+
+int64_t QuantizedType::getStorageTypeMax() const {
+ return static_cast<ImplType *>(type)->storageTypeMax;
+}
+
+unsigned QuantizedType::getStorageTypeIntegralWidth() const {
+ // NOTE: If ever supporting non-integral storage types, some other scheme
+ // for determining the width will be needed.
+ return static_cast<ImplType *>(type)->storageType.getIntOrFloatBitWidth();
+}
+
+Type QuantizedType::getExpressedType() const {
+ return static_cast<ImplType *>(type)->expressedType;
+}
+
+bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
+ if (candidateExpressedType.isa<VectorOrTensorType>()) {
+ return candidateExpressedType.cast<VectorOrTensorType>().getElementType() ==
+ getExpressedType();
+ }
+ return candidateExpressedType == getExpressedType();
+}
+
+QuantizedType
+QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
+ if (primitiveOrContainerType.isa<VectorOrTensorType>()) {
+ Type elementType =
+ primitiveOrContainerType.cast<VectorOrTensorType>().getElementType();
+ return elementType.dyn_cast<QuantizedType>();
+ }
+ return primitiveOrContainerType.dyn_cast<QuantizedType>();
+}
+
+Type QuantizedType::castFromStorageType(Type candidateType) {
+ if (candidateType == getStorageType()) {
+ // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
+ return *this;
+ } else if (candidateType.isa<RankedTensorType>()) {
+ // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+ return RankedTensorType::get(
+ candidateType.cast<RankedTensorType>().getShape(), getStorageType());
+ } else if (candidateType.isa<UnrankedTensorType>()) {
+ // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
+ return UnrankedTensorType::get(getStorageType());
+ } else if (candidateType.isa<VectorType>()) {
+ // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+ return VectorType::get(candidateType.cast<VectorType>().getShape(),
+ getStorageType());
+ }
+
+ return nullptr;
+}
+
+Type QuantizedType::castToStorageType(Type quantizedType) {
+ if (quantizedType.isa<QuantizedType>()) {
+ // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
+ return quantizedType.cast<QuantizedType>().getStorageType();
+ } else if (quantizedType.isa<VectorOrTensorType>()) {
+ // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+ VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>();
+ if (!vtType.getElementType().isa<QuantizedType>()) {
+ return nullptr;
+ }
+ Type storageType =
+ vtType.getElementType().cast<QuantizedType>().getStorageType();
+ if (quantizedType.isa<RankedTensorType>()) {
+ return RankedTensorType::get(vtType.getShape(), storageType);
+ } else if (quantizedType.isa<UnrankedTensorType>()) {
+ return UnrankedTensorType::get(storageType);
+ } else if (quantizedType.isa<VectorType>()) {
+ return VectorType::get(vtType.getShape(), storageType);
+ }
+ }
+
+ return nullptr;
+}
+
+Type QuantizedType::castFromExpressedType(Type candidateType) {
+ if (candidateType == getExpressedType()) {
+ // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
+ return *this;
+ } else if (candidateType.isa<VectorOrTensorType>()) {
+ VectorOrTensorType candidateVtType =
+ candidateType.cast<VectorOrTensorType>();
+ if (candidateVtType.getElementType() != getExpressedType()) {
+ return nullptr;
+ }
+
+ if (candidateType.isa<RankedTensorType>()) {
+ // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+ return RankedTensorType::get(candidateVtType.getShape(), *this);
+ } else if (candidateType.isa<UnrankedTensorType>()) {
+ // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
+ return UnrankedTensorType::get(*this);
+ } else if (candidateType.isa<VectorType>()) {
+ // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+ return VectorType::get(candidateVtType.getShape(), *this);
+ }
+ }
+
+ return nullptr;
+}
+
+Type QuantizedType::castToExpressedType(Type quantizedType) {
+ if (quantizedType.isa<QuantizedType>()) {
+ // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
+ return quantizedType.cast<QuantizedType>().getExpressedType();
+ } else if (quantizedType.isa<VectorOrTensorType>()) {
+ // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+ VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>();
+ if (!vtType.getElementType().isa<QuantizedType>()) {
+ return nullptr;
+ }
+ Type expressedType =
+ vtType.getElementType().cast<QuantizedType>().getExpressedType();
+ if (quantizedType.isa<RankedTensorType>()) {
+ return RankedTensorType::get(vtType.getShape(), expressedType);
+ } else if (quantizedType.isa<UnrankedTensorType>()) {
+ return UnrankedTensorType::get(expressedType);
+ } else if (quantizedType.isa<VectorType>()) {
+ return VectorType::get(vtType.getShape(), expressedType);
+ }
+ }
+
+ return nullptr;
+}
+
+Type QuantizedType::castExpressedToStorageType(Type candidateType) {
+ Type expressedQuantizedType = castFromExpressedType(candidateType);
+ if (!expressedQuantizedType) {
+ return nullptr;
+ }
+ return QuantizedType::castToStorageType(expressedQuantizedType);
+}
+
+UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
+ Type expressedType, double scale,
+ int64_t zeroPoint,
+ int64_t storageTypeMin,
+ int64_t storageTypeMax) {
+ return Base::get(storageType.getContext(),
+ QuantizationTypes::UniformQuantized, flags, storageType,
+ expressedType, scale, zeroPoint, storageTypeMin,
+ storageTypeMax);
+}
+
+UniformQuantizedType
+UniformQuantizedType::getChecked(unsigned flags, Type storageType,
+ Type expressedType, double scale,
+ int64_t zeroPoint, int64_t storageTypeMin,
+ int64_t storageTypeMax, Location location) {
+ return Base::getChecked(location, storageType.getContext(),
+ QuantizationTypes::UniformQuantized, flags,
+ storageType, expressedType, scale, zeroPoint,
+ storageTypeMin, storageTypeMax);
+}
+
+LogicalResult UniformQuantizedType::verifyConstructionInvariants(
+ llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+ Type storageType, Type expressedType, double scale, int64_t zeroPoint,
+ int64_t storageTypeMin, int64_t storageTypeMax) {
+ if (failed(QuantizedType::verifyConstructionInvariants(
+ loc, context, flags, storageType, expressedType, storageTypeMin,
+ storageTypeMax))) {
+ return failure();
+ }
+
+ // Verify scale.
+ if (scale <= 0.0 || isinf(scale) || isnan(scale)) {
+ if (loc) {
+ context->emitError(*loc,
+ "illegal scale: " + Twine(std::to_string(scale)));
+ }
+ return failure();
+ }
+
+ return success();
+}
+
+double UniformQuantizedType::getScale() const { return getImpl()->scale; }
+
+int64_t UniformQuantizedType::getZeroPoint() const {
+ return getImpl()->zeroPoint;
+}
+
+UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
+ unsigned flags, Type storageType, Type expressedType,
+ ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+ int32_t quantizedDimension, int64_t storageTypeMin,
+ int64_t storageTypeMax) {
+ return Base::get(storageType.getContext(),
+ QuantizationTypes::UniformQuantizedPerAxis, flags,
+ storageType, expressedType, scales, zeroPoints,
+ quantizedDimension, storageTypeMin, storageTypeMax);
+}
+
+UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
+ unsigned flags, Type storageType, Type expressedType,
+ ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+ int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
+ Location location) {
+ return Base::getChecked(location, storageType.getContext(),
+ QuantizationTypes::UniformQuantizedPerAxis, flags,
+ storageType, expressedType, scales, zeroPoints,
+ quantizedDimension, storageTypeMin, storageTypeMax);
+}
+
+LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
+ llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+ Type storageType, Type expressedType, ArrayRef<double> scales,
+ ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
+ int64_t storageTypeMin, int64_t storageTypeMax) {
+ if (failed(QuantizedType::verifyConstructionInvariants(
+ loc, context, flags, storageType, expressedType, storageTypeMin,
+ storageTypeMax))) {
+ return failure();
+ }
+
+ // Ensure that the number of scales and zeroPoints match.
+ if (scales.size() != zeroPoints.size()) {
+ if (loc) {
+ context->emitError(*loc, "illegal number of scales and zeroPoints: " +
+ Twine(scales.size()) + ", " +
+ Twine(zeroPoints.size()));
+ }
+ return failure();
+ }
+
+ // Verify scale.
+ for (double scale : scales) {
+ if (scale <= 0.0 || isinf(scale) || isnan(scale)) {
+ if (loc) {
+ context->emitError(*loc,
+ "illegal scale: " + Twine(std::to_string(scale)));
+ }
+ return failure();
+ }
+ }
+
+ return success();
+}
+
+ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
+ return getImpl()->getScales();
+}
+
+ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
+ return getImpl()->getZeroPoints();
+}
+
+int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
+ return getImpl()->quantizedDimension;
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Quantization/QuantOps.cpp.inc"
+
+QuantizationDialect::QuantizationDialect(MLIRContext *context)
+ : Dialect(/*name=*/"quant", context) {
+ addTypes<UniformQuantizedType, UniformQuantizedPerAxisType>();
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Quantization/QuantOps.cpp.inc"
+ >();
+}
--- /dev/null
+//===- Quantization/IR/TypeDetail.h - Type detail ---------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef TYPE_DETAIL_H_
+#define TYPE_DETAIL_H_
+
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/bit.h"
+
+namespace mlir {
+namespace quant {
+namespace detail {
+
+struct QuantizedTypeStorage : public mlir::TypeStorage {
+ QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType,
+ int64_t storageTypeMin, int64_t storageTypeMax)
+ : flags(flags), storageType(storageType), expressedType(expressedType),
+ storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
+
+ /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
+ unsigned flags;
+
+ // Integral type for the storage point representation.
+ Type storageType;
+
+ // Floating point type that the quantized type approximates.
+ Type expressedType;
+
+ // The minimum value storageType can take.
+ int64_t storageTypeMin;
+
+ // The maximum value storageType can take.
+ int64_t storageTypeMax;
+};
+
+struct UniformQuantizedTypeStorage : public QuantizedTypeStorage {
+ struct KeyTy {
+ KeyTy(unsigned flags, Type storageType, Type expressedType, double scale,
+ int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
+ : flags(flags), storageType(storageType), expressedType(expressedType),
+ scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin),
+ storageTypeMax(storageTypeMax) {}
+ /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
+ unsigned flags;
+
+ // Integral type for the storage point representation.
+ Type storageType;
+
+ // Floating point type that the quantized type approximates.
+ Type expressedType;
+
+ double scale;
+ int64_t zeroPoint;
+ int64_t storageTypeMin;
+ int64_t storageTypeMax;
+
+ // Check for equality of two structures that share KeyTy data members
+ // (by name).
+ template <typename T, typename U>
+ static bool genericIsEqual(const T &lhs, const U &rhs) {
+ return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
+ lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale &&
+ lhs.zeroPoint == rhs.zeroPoint &&
+ lhs.storageTypeMin == rhs.storageTypeMin &&
+ lhs.storageTypeMax == rhs.storageTypeMax;
+ }
+
+ bool operator==(const KeyTy &other) const {
+ return genericIsEqual(*this, other);
+ }
+
+ unsigned getHashValue() const {
+ int64_t scaleBits = llvm::bit_cast<int64_t>(scale);
+ return llvm::hash_combine(flags, storageType, expressedType, scaleBits,
+ zeroPoint, storageTypeMin, storageTypeMax);
+ }
+ };
+
+ UniformQuantizedTypeStorage(const KeyTy &key)
+ : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
+ key.storageTypeMin, key.storageTypeMax),
+ scale(key.scale), zeroPoint(key.zeroPoint) {}
+
+ bool operator==(const KeyTy &key) const {
+ return KeyTy::genericIsEqual(*this, key);
+ }
+
+ /// Construction.
+ static UniformQuantizedTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ return new (allocator.allocate<UniformQuantizedTypeStorage>())
+ UniformQuantizedTypeStorage(key);
+ }
+
+ static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
+
+ double scale;
+ int64_t zeroPoint;
+};
+
+struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage {
+ struct KeyTy {
+ KeyTy(unsigned flags, Type storageType, Type expressedType,
+ ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+ int32_t quantizedDimension, int64_t storageTypeMin,
+ int64_t storageTypeMax)
+ : flags(flags), storageType(storageType), expressedType(expressedType),
+ scales(scales), zeroPoints(zeroPoints),
+ quantizedDimension(quantizedDimension),
+ storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
+ /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
+ unsigned flags;
+
+ // Integral type for the storage point representation.
+ Type storageType;
+
+ // Floating point type that the quantized type approximates.
+ Type expressedType;
+
+ ArrayRef<double> scales;
+ ArrayRef<int64_t> zeroPoints;
+ int32_t quantizedDimension;
+ int64_t storageTypeMin;
+ int64_t storageTypeMax;
+
+ ArrayRef<double> getScales() const { return scales; }
+
+ ArrayRef<int64_t> getZeroPoints() const { return zeroPoints; }
+
+ // Check for equality of two structures that share KeyTy data members
+ // (by name).
+ template <typename T, typename U>
+ static bool genericIsEqual(const T &lhs, const U &rhs) {
+ return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
+ lhs.expressedType == rhs.expressedType &&
+ lhs.getScales() == rhs.getScales() &&
+ lhs.getZeroPoints() == rhs.getZeroPoints() &&
+ lhs.quantizedDimension == rhs.quantizedDimension &&
+ lhs.storageTypeMin == rhs.storageTypeMin &&
+ lhs.storageTypeMax == rhs.storageTypeMax;
+ }
+
+ bool operator==(const KeyTy &other) const {
+ return genericIsEqual(*this, other);
+ }
+
+ unsigned getHashValue() const {
+ int64_t *scalesCast = llvm::bit_cast<int64_t *>(scales.data());
+ ArrayRef<int64_t> scalesBits(scalesCast, scales.size());
+ return llvm::hash_combine(
+ flags, storageType, expressedType,
+ llvm::hash_combine_range(scalesBits.begin(), scalesBits.end()),
+ llvm::hash_combine_range(zeroPoints.begin(), zeroPoints.end()),
+ storageTypeMin, storageTypeMax);
+ }
+ };
+
+ // We pass scales and zeroPoints in directly rather than relying on KeyTy
+ // because we have to create new reallocated versions in `constrcut` below.
+ UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef<double> scales,
+ ArrayRef<int64_t> zeroPoints)
+ : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
+ key.storageTypeMin, key.storageTypeMax),
+ scaleElements(scales.data()), zeroPointElements(zeroPoints.data()),
+ quantParamsSize(scales.size()),
+ quantizedDimension(key.quantizedDimension) {}
+
+ bool operator==(const KeyTy &key) const {
+ return KeyTy::genericIsEqual(*this, key);
+ }
+
+ /// Construction.
+ static UniformQuantizedPerAxisTypeStorage *
+ construct(TypeStorageAllocator &allocator, const KeyTy &key) {
+ ArrayRef<double> scales = allocator.copyInto(key.scales);
+ ArrayRef<int64_t> zeroPoints = allocator.copyInto(key.zeroPoints);
+ return new (allocator.allocate<UniformQuantizedPerAxisTypeStorage>())
+ UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints);
+ }
+
+ static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
+
+ ArrayRef<double> getScales() const {
+ return ArrayRef<double>(scaleElements, quantParamsSize);
+ }
+
+ ArrayRef<int64_t> getZeroPoints() const {
+ return ArrayRef<int64_t>(zeroPointElements, quantParamsSize);
+ }
+
+ const double *scaleElements;
+ const int64_t *zeroPointElements;
+ unsigned quantParamsSize;
+ int32_t quantizedDimension;
+};
+
+} // namespace detail
+} // namespace quant
+} // namespace mlir
+
+#endif // TYPE_DETAIL_H_
--- /dev/null
+//===- Quantization/IR/TypeParser.h - Quantization Type Parser --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Location.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Quantization/QuantOps.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace quant {
+
+/// Print a floating point value in a way that the parser will be able to
+/// round-trip losslessly.
+static void printStabilizedFloat(const APFloat &apValue, raw_ostream &os) {
+ // We would like to output the FP constant value in exponential notation,
+ // but we cannot do this if doing so will lose precision. Check here to
+ // make sure that we only output it in exponential format if we can parse
+ // the value back and get the same value.
+ bool isInf = apValue.isInfinity();
+ bool isNaN = apValue.isNaN();
+ if (!isInf && !isNaN) {
+ SmallString<128> strValue;
+ apValue.toString(strValue, 6, 0, false);
+
+ // Check to make sure that the stringized number is not some string like
+ // "Inf" or NaN, that atof will accept, but the lexer will not. Check
+ // that the string matches the "[-+]?[0-9]" regex.
+ assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
+ ((strValue[0] == '-' || strValue[0] == '+') &&
+ (strValue[1] >= '0' && strValue[1] <= '9'))) &&
+ "[-+]?[0-9] regex does not match!");
+ // Reparse stringized version!
+ if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
+ os << strValue;
+ return;
+ }
+ }
+
+ SmallVector<char, 16> str;
+ apValue.toString(str);
+ os << str;
+}
+
+namespace {
+
+enum class TokenKind {
+ error,
+ eof,
+ l_bracket,
+ r_bracket,
+ l_brace,
+ r_brace,
+ l_paren,
+ r_paren,
+ colon,
+ comma,
+ alpha_ident,
+ integer_literal,
+ float_literal,
+};
+
+struct Token {
+ TokenKind kind;
+ StringRef spelling;
+};
+
+class Lexer {
+public:
+ Lexer(StringRef source) : curBuffer(source), curPtr(curBuffer.begin()) {}
+
+ Token lexToken();
+
+private:
+ Token formToken(TokenKind kind, const char *tokStart) {
+ return Token{kind, StringRef(tokStart, curPtr - tokStart)};
+ }
+
+ Token emitError(const char *loc, const Twine &message) {
+ return formToken(TokenKind::error, loc);
+ }
+
+ bool isEnd() const { return curPtr == curBuffer.end(); }
+
+ // Lexer implementation methods
+ Token lexalpha_ident(const char *tokStart);
+ Token lexNumber(const char *tokStart);
+
+ StringRef curBuffer;
+ const char *curPtr;
+};
+
+} // namespace
+
+Token Lexer::lexToken() {
+ // Ignore whitespace.
+ while (!isEnd()) {
+ switch (*curPtr) {
+ case ' ':
+ case '\t':
+ case '\n':
+ case '\r':
+ ++curPtr;
+ continue;
+ default:
+ break;
+ }
+ break;
+ }
+
+ if (isEnd()) {
+ return Token{TokenKind::eof, ""};
+ }
+
+ const char *tokStart = curPtr;
+ switch (*curPtr++) {
+ default:
+ if (isalpha(*tokStart)) {
+ return lexalpha_ident(tokStart);
+ }
+ if (isdigit(*tokStart)) {
+ return lexNumber(tokStart);
+ }
+
+ return emitError(tokStart, "unexpected character");
+
+ case '[':
+ return formToken(TokenKind::l_bracket, tokStart);
+ case ']':
+ return formToken(TokenKind::r_bracket, tokStart);
+ case '{':
+ return formToken(TokenKind::l_brace, tokStart);
+ case '}':
+ return formToken(TokenKind::r_brace, tokStart);
+ case '(':
+ return formToken(TokenKind::l_paren, tokStart);
+ case ')':
+ return formToken(TokenKind::r_paren, tokStart);
+ case ':':
+ return formToken(TokenKind::colon, tokStart);
+ case ',':
+ return formToken(TokenKind::comma, tokStart);
+ case '-':
+ return lexNumber(tokStart);
+ case '+':
+ return lexNumber(tokStart);
+ }
+}
+
+/// Lex a bare alpha identifier. Since this DSL often contains identifiers with
+/// trailing numeric components, this only matches alphas. It is up to the
+/// parser to handle identifiers that can be mixed alphanum.
+///
+/// alpha-ident ::= (letter)(letter)*
+Token Lexer::lexalpha_ident(const char *tokStart) {
+ while (!isEnd() && isalpha(*curPtr)) {
+ ++curPtr;
+ }
+ return formToken(TokenKind::alpha_ident, tokStart);
+}
+
+/// Lex a number.
+///
+/// integer-literal ::= [-+]?digit+
+/// float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
+Token Lexer::lexNumber(const char *tokStart) {
+ // Leading '+', '-' or digit has already been consumed.
+ while (!isEnd() && isdigit(*curPtr)) {
+ ++curPtr;
+ }
+ // If not a decimal point, treat as integer.
+ if (isEnd() || *curPtr != '.') {
+ return formToken(TokenKind::integer_literal, tokStart);
+ }
+ ++curPtr;
+
+ // Skip over [0-9]*([eE][-+]?[0-9]+)?
+ // Leading digits.
+ while (!isEnd() && isdigit(*curPtr)) {
+ ++curPtr;
+ }
+
+ // [eE][-+]?[0-9]+
+ if (!isEnd() && (*curPtr == 'e' || *curPtr == 'E')) {
+ auto remaining = curBuffer.end() - curPtr;
+ if (remaining > 2 && isdigit(curPtr[1])) {
+ // Lookahead 2 for digit.
+ curPtr += 2;
+ while (!isEnd() && isdigit(*curPtr)) {
+ ++curPtr;
+ }
+ } else if (remaining > 3 && (curPtr[1] == '-' || curPtr[1] == '+') &&
+ isdigit(curPtr[2])) {
+ // Lookahead 3 for [+-] digit.
+ curPtr += 3;
+ while (!isEnd() && isdigit(*curPtr)) {
+ ++curPtr;
+ }
+ }
+ }
+ return formToken(TokenKind::float_literal, tokStart);
+} // end namespace
+
+// --- TypeParser ---
+namespace {
+
+class TypeParser {
+public:
+ TypeParser(StringRef source, MLIRContext *context, Location location)
+ : context(context), location(location), lexer(source),
+ curToken(lexer.lexToken()) {}
+
+ /// Attempts to parse the source as a type, returning the unknown
+ /// type on error.
+ Type parseType();
+
+private:
+ /// Unconditionally consumes the current token.
+ void consumeToken() {
+ assert(curToken.kind != TokenKind::eof &&
+ "should not advance past EOF or errors");
+ curToken = lexer.lexToken();
+ }
+
+ /// Unconditionally consumes the current token, asserting that it is of the
+ /// specified kind.
+ void consumeToken(TokenKind kind) {
+ assert(curToken.kind == kind && "consumed an unexpected token");
+ consumeToken();
+ }
+
+ /// Conditionally consumes a token if of the specified kind.
+ /// Returns true if consumed.
+ bool consumeIf(TokenKind kind) {
+ if (curToken.kind == kind) {
+ consumeToken();
+ return true;
+ }
+ return false;
+ }
+
+ /// Emits an error at the current location with a message.
+ void emitError(const Twine &message) {
+ // TODO: All errors show up at the beginning of the extended type location.
+ // Figure out how to make this location relative to where the error occurred
+ // in this instance.
+ context->emitError(location, message);
+ }
+
+ // Parsers.
+ Type parseUniformType();
+ IntegerType parseStorageType(bool &isSigned);
+ FloatType parseExpressedType();
+ bool parseQuantParams(double &scale, int64_t &zeroPoint);
+
+ MLIRContext *context;
+ Location location;
+ Lexer lexer;
+
+ // The next token that has not yet been consumed.
+ Token curToken;
+};
+
+} // namespace
+
+Type TypeParser::parseType() {
+ // All types start with an identifier that we switch on.
+ if (curToken.kind == TokenKind::alpha_ident) {
+ StringRef typeNameSpelling = curToken.spelling;
+ consumeToken();
+
+ Type result;
+ if (typeNameSpelling == "uniform") {
+ result = parseUniformType();
+ if (!result) {
+ return nullptr;
+ }
+ } else {
+ return (emitError("unknown quantized type " + typeNameSpelling), nullptr);
+ }
+
+ // Make sure the entire input was consumed.
+ if (curToken.kind != TokenKind::eof) {
+ return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ }
+
+ return result;
+ } else {
+ return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ }
+}
+
+/// Parses a UniformQuantizedType.
+///
+/// uniform_type ::= `uniform` type_spec quant_param_spec
+///
+/// type_spec ::= `[` storage-spec `:` expressed-type (quant-dim)? `]`
+/// quant-dim ::= `:` integer-literal
+/// storage-spec ::= storage-type (`(` storage-range `)`)?
+/// storage-range ::= integer-literal `:` integer-literal
+/// storage-type ::= (`i` | `u`) integer-literal
+/// expressed-type ::= (`f16` | `f32` | `f64` | `bf16`)
+///
+/// quant_param_spec ::= `{` scale-zero (`,` scale-zero )* `}`
+/// scale-zero ::= float-literal `:` integer-literal
+Type TypeParser::parseUniformType() {
+ IntegerType storageType;
+ FloatType expressedType;
+ unsigned typeFlags = 0;
+ int64_t storageTypeMin;
+ int64_t storageTypeMax;
+ bool isPerAxis = false;
+ int32_t quantizedDimension;
+ SmallVector<double, 1> scales;
+ SmallVector<int64_t, 1> zeroPoints;
+
+ // Type specification.
+ if (!consumeIf(TokenKind::l_bracket)) {
+ return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ }
+
+ // Storage type.
+ bool isSigned = false;
+ storageType = parseStorageType(isSigned);
+ if (!storageType) {
+ return nullptr;
+ }
+ if (isSigned) {
+ typeFlags |= QuantizationFlags::Signed;
+ }
+
+ // Storage type range.
+ int64_t defaultIntegerMin = QuantizedType::getDefaultMininumForInteger(
+ isSigned, storageType.getWidth());
+ int64_t defaultIntegerMax = QuantizedType::getDefaultMaxinumForInteger(
+ isSigned, storageType.getWidth());
+ if (consumeIf(TokenKind::l_paren)) {
+ // Explicit storage min and storage max.
+ if (curToken.kind != TokenKind::integer_literal) {
+ return (emitError("expected storage type minimum"), nullptr);
+ }
+ if (curToken.spelling.getAsInteger(10, storageTypeMin) ||
+ storageTypeMin < defaultIntegerMin) {
+ return (emitError("illegal storage type minimum: " + curToken.spelling),
+ nullptr);
+ }
+ consumeToken(TokenKind::integer_literal);
+
+ if (!consumeIf(TokenKind::colon)) {
+ return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ }
+
+ if (curToken.kind != TokenKind::integer_literal) {
+ return (emitError("expected storage type maximum"), nullptr);
+ }
+ if (curToken.spelling.getAsInteger(10, storageTypeMax) ||
+ storageTypeMax > defaultIntegerMax) {
+ return (emitError("illegal storage type maximum: " + curToken.spelling),
+ nullptr);
+ }
+ consumeToken(TokenKind::integer_literal);
+
+ if (!consumeIf(TokenKind::r_paren)) {
+ return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ }
+ } else {
+ storageTypeMin = defaultIntegerMin;
+ storageTypeMax = defaultIntegerMax;
+ }
+
+ // Repr type.
+ if (!consumeIf(TokenKind::colon)) {
+ return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ }
+ expressedType = parseExpressedType();
+ if (!expressedType) {
+ return nullptr;
+ }
+
+ // Optionally parse quantized dimension for per-axis quantization.
+ if (consumeIf(TokenKind::colon)) {
+ if (curToken.kind != TokenKind::integer_literal) {
+ return (emitError("expected quantized dimension"), nullptr);
+ }
+ if (curToken.spelling.getAsInteger(10, quantizedDimension)) {
+ return (emitError("illegal quantized dimension: " + curToken.spelling),
+ nullptr);
+ }
+ consumeToken(TokenKind::integer_literal);
+ isPerAxis = true;
+ }
+
+ if (!consumeIf(TokenKind::r_bracket)) {
+ return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ }
+
+ // Parameter specification.
+ if (!consumeIf(TokenKind::l_brace)) {
+ return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ }
+
+ // Parse scales/zeroPoints.
+ do {
+ scales.resize(scales.size() + 1);
+ zeroPoints.resize(zeroPoints.size() + 1);
+ if (parseQuantParams(scales.back(), zeroPoints.back())) {
+ return nullptr;
+ }
+ } while (consumeIf(TokenKind::comma));
+
+ if (!consumeIf(TokenKind::r_brace)) {
+ return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+ }
+
+ if (!isPerAxis && scales.size() > 1) {
+ return (emitError("multiple scales/zeroPoints provided, but "
+ "quantizedDimension wasn't specified"),
+ nullptr);
+ }
+
+ if (isPerAxis) {
+ ArrayRef<double> scalesRef(scales.begin(), scales.end());
+ ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
+ return UniformQuantizedPerAxisType::getChecked(
+ typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
+ quantizedDimension, storageTypeMin, storageTypeMax, location);
+ }
+
+ return UniformQuantizedType::getChecked(
+ typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
+ storageTypeMin, storageTypeMax, location);
+}
+
+IntegerType TypeParser::parseStorageType(bool &isSigned) {
+ // Parse storage type (alpha_ident, integer_literal).
+ StringRef storageTypePrefix = curToken.spelling;
+ unsigned storageTypeWidth;
+ if (curToken.kind != TokenKind::alpha_ident) {
+ return (emitError("expected storage type prefix"), nullptr);
+ }
+ consumeToken();
+ if (curToken.kind != TokenKind::integer_literal) {
+ return (emitError("expected storage type width"), nullptr);
+ }
+ if (curToken.spelling.getAsInteger(10, storageTypeWidth) ||
+ storageTypeWidth == 0 ||
+ storageTypeWidth > QuantizedType::MaxStorageBits) {
+ return (emitError("illegal storage type size: " + Twine(curToken.spelling)),
+ nullptr);
+ }
+ consumeToken();
+
+ if (storageTypePrefix == "i") {
+ isSigned = true;
+ return IntegerType::get(storageTypeWidth, context);
+ } else if (storageTypePrefix == "u") {
+ isSigned = false;
+ return IntegerType::get(storageTypeWidth, context);
+ } else {
+ return (
+ emitError("illegal storage type prefix: " + Twine(storageTypePrefix)),
+ nullptr);
+ }
+}
+
+FloatType TypeParser::parseExpressedType() {
+ // Expect an alpha_ident followed by integer literal that we concat back
+ // together.
+ StringRef prefix = curToken.spelling;
+ if (!consumeIf(TokenKind::alpha_ident)) {
+ return (emitError("expected expressed type"), nullptr);
+ }
+ StringRef suffix = curToken.spelling;
+ if (!consumeIf(TokenKind::integer_literal)) {
+ return (emitError("expected expressed type"), nullptr);
+ }
+
+ SmallVector<char, 4> holder;
+ StringRef typeName = (Twine(prefix) + Twine(suffix)).toStringRef(holder);
+ if (typeName == "f32")
+ return FloatType::getF32(context);
+ if (typeName == "f16")
+ return FloatType::getF16(context);
+ if (typeName == "bf16")
+ return FloatType::getBF16(context);
+ if (typeName == "f64")
+ return FloatType::getF64(context);
+
+ return (emitError("unrecognized expressed type: " + typeName), nullptr);
+}
+
+bool TypeParser::parseQuantParams(double &scale, int64_t &zeroPoint) {
+ // scale[:zeroPoint]?
+ // scale.
+ StringRef scaleSpelling = curToken.spelling;
+ if (!consumeIf(TokenKind::float_literal) ||
+ scaleSpelling.getAsDouble(scale)) {
+ return (
+ emitError("expected valid uniform scale. got: " + Twine(scaleSpelling)),
+ true);
+ }
+
+ // zero point.
+ zeroPoint = 0;
+ if (!consumeIf(TokenKind::colon)) {
+ // Default zero point.
+ return false;
+ }
+ StringRef zeroPointSpelling = curToken.spelling;
+ if (!consumeIf(TokenKind::integer_literal) ||
+ zeroPointSpelling.getAsInteger(10, zeroPoint)) {
+ return (emitError("expected integer uniform zero point. got: " +
+ Twine(zeroPointSpelling)),
+ true);
+ }
+
+ return false;
+}
+
+/// Parse a type registered to this dialect.
+Type QuantizationDialect::parseType(StringRef spec, Location loc) const {
+ TypeParser parser(spec, getContext(), loc);
+ Type parsedType = parser.parseType();
+ if (parsedType == nullptr) {
+ // Error.
+ // TODO(laurenzo): Do something?
+ return parsedType;
+ }
+
+ return parsedType;
+}
+
+static void printStorageType(QuantizedType type, raw_ostream &out) {
+ // storage type
+ unsigned storageWidth = type.getStorageTypeIntegralWidth();
+ bool isSigned = type.isSigned();
+ if (isSigned) {
+ out << "i" << storageWidth;
+ } else {
+ out << "u" << storageWidth;
+ }
+
+ // storageTypeMin and storageTypeMax if not default.
+ int64_t defaultIntegerMin =
+ QuantizedType::getDefaultMininumForInteger(isSigned, storageWidth);
+ int64_t defaultIntegerMax =
+ QuantizedType::getDefaultMaxinumForInteger(isSigned, storageWidth);
+ if (defaultIntegerMin != type.getStorageTypeMin() ||
+ defaultIntegerMax != type.getStorageTypeMax()) {
+ out << "(" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
+ << ")";
+ }
+}
+
+static void printExpressedType(QuantizedType type, raw_ostream &out) {
+ // repr type
+ Type expressedType = type.getExpressedType();
+ if (expressedType.isF32()) {
+ out << "f32";
+ } else if (expressedType.isF64()) {
+ out << "f64";
+ } else if (expressedType.isF16()) {
+ out << "f16";
+ } else if (expressedType.isBF16()) {
+ out << "bf16";
+ } else {
+ out << "unknown";
+ }
+}
+
+static void printQuantParams(double scale, int64_t zeroPoint,
+ raw_ostream &out) {
+ printStabilizedFloat(APFloat(scale), out);
+ if (zeroPoint != 0) {
+ out << ":" << zeroPoint;
+ }
+}
+
+/// Helper that prints a UniformQuantizedType.
+static void printUniformQuantizedType(UniformQuantizedType type,
+ raw_ostream &out) {
+ out << "uniform[";
+ printStorageType(type, out);
+ out << ":";
+ printExpressedType(type, out);
+ out << "]";
+
+ // scheme specific parameters
+ out << "{";
+ printQuantParams(type.getScale(), type.getZeroPoint(), out);
+ out << "}";
+}
+
+/// Helper that prints a UniformQuantizedPerAxisType.
+static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
+ raw_ostream &out) {
+ out << "uniform[";
+ printStorageType(type, out);
+ out << ":";
+ printExpressedType(type, out);
+ out << ":";
+ out << type.getQuantizedDimension();
+ out << "]";
+
+ // scheme specific parameters
+ ArrayRef<double> scales = type.getScales();
+ ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
+ out << "{";
+ for (unsigned i = 0; i < scales.size(); ++i) {
+ printQuantParams(scales[i], zeroPoints[i], out);
+ if (i != scales.size() - 1) {
+ out << ",";
+ }
+ }
+ out << "}";
+}
+
+/// Print a type registered to this dialect.
+void QuantizationDialect::printType(Type type, raw_ostream &os) const {
+ switch (type.getKind()) {
+ default:
+ llvm_unreachable("Unhandled quantized type");
+ case QuantizationTypes::UniformQuantized:
+ printUniformQuantizedType(type.cast<UniformQuantizedType>(), os);
+ break;
+ case QuantizationTypes::UniformQuantizedPerAxis:
+ printUniformQuantizedPerAxisType(type.cast<UniformQuantizedPerAxisType>(),
+ os);
+ break;
+ }
+}
+
+} // namespace quant
+} // namespace mlir
--- /dev/null
+//===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantization/UniformSupport.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+static bool isQuantizablePrimitiveType(Type inputType) {
+ return inputType.isa<FloatType>();
+}
+
+const ExpressedToUniformQuantizedConverter
+ExpressedToUniformQuantizedConverter::forInputType(Type inputType) {
+ switch (inputType.getKind()) {
+ default:
+ if (isQuantizablePrimitiveType(inputType)) {
+ // Supported primitive type (which just is the expressed type).
+ return ExpressedToUniformQuantizedConverter{inputType, inputType};
+ }
+ // Unsupported.
+ return ExpressedToUniformQuantizedConverter{inputType, nullptr};
+ case StandardTypes::RankedTensor:
+ case StandardTypes::UnrankedTensor:
+ case StandardTypes::Vector: {
+ Type elementType = inputType.cast<VectorOrTensorType>().getElementType();
+ if (!isQuantizablePrimitiveType(elementType)) {
+ // Unsupported.
+ return ExpressedToUniformQuantizedConverter{inputType, nullptr};
+ }
+ return ExpressedToUniformQuantizedConverter{
+ inputType, inputType.cast<VectorOrTensorType>().getElementType()};
+ }
+ }
+}
+
+Type ExpressedToUniformQuantizedConverter::convert(
+ UniformQuantizedType elementalType) const {
+ assert(expressedType && "convert() on unsupported conversion");
+
+ switch (inputType.getKind()) {
+ default:
+ if (isQuantizablePrimitiveType(elementalType)) {
+ // For primitives, just use the new elemental type.
+ return elementalType;
+ }
+ // Unsupported.
+ return nullptr;
+ case StandardTypes::RankedTensor:
+ return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
+ elementalType);
+ case StandardTypes::UnrankedTensor:
+ return UnrankedTensorType::get(elementalType);
+ case StandardTypes::Vector:
+ return VectorType::get(inputType.cast<VectorType>().getShape(),
+ elementalType);
+ }
+}
--- /dev/null
+//===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Quantization/Passes.h"
+#include "mlir/Quantization/QuantOps.h"
+#include "mlir/Quantization/QuantizeUtils.h"
+#include "mlir/Quantization/UniformSupport.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+class ConvertConstPass : public FunctionPass<ConvertConstPass> {
+public:
+ void runOnFunction() override;
+};
+
+class QuantizedConstRewrite : public RewritePattern {
+public:
+ struct State : PatternState {
+ QuantizedType quantizedElementType;
+ Attribute value;
+ };
+
+ QuantizedConstRewrite(MLIRContext *context)
+ : RewritePattern(QuantizeBarrierOp::getOperationName(), 1, context) {}
+
+ PatternMatchResult match(Operation *op) const override;
+ void rewrite(Operation *op, std::unique_ptr<PatternState> baseState,
+ PatternRewriter &rewriter) const override;
+};
+
+} // end anonymous namespace
+
+/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
+/// quantized and the operand type is quantizable.
+PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
+ State state;
+
+ // Is the operand a constant?
+ auto qbarrier = op->cast<QuantizeBarrierOp>();
+ if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
+ return matchFailure();
+ }
+ // Does the qbarrier convert to a quantized type. This will not be true
+ // if a quantized type has not yet been chosen or if the cast to an equivalent
+ // storage type is not supported.
+ Type qbarrierResultType = qbarrier.getResult()->getType();
+ state.quantizedElementType =
+ QuantizedType::getQuantizedElementType(qbarrierResultType);
+ if (!state.quantizedElementType) {
+ return matchFailure();
+ }
+ if (!QuantizedType::castToStorageType(qbarrierResultType)) {
+ return matchFailure();
+ }
+
+ // Is the operand type compatible with the expressed type of the quantized
+ // type? This will not be true if the qbarrier is superfluous (converts
+ // from and to a quantized type).
+ if (!state.quantizedElementType.isCompatibleExpressedType(
+ qbarrier.arg()->getType())) {
+ return matchFailure();
+ }
+
+ // Is the constant value a type expressed in a way that we support?
+ if (!state.value.isa<FloatAttr>() && !state.value.isa<SplatElementsAttr>() &&
+ !state.value.isa<DenseElementsAttr>() &&
+ !state.value.isa<SparseElementsAttr>()) {
+ return matchFailure();
+ }
+
+ return matchSuccess(llvm::make_unique<State>(std::move(state)));
+}
+
+void QuantizedConstRewrite::rewrite(Operation *op,
+ std::unique_ptr<PatternState> baseState,
+ PatternRewriter &rewriter) const {
+ auto state = static_cast<State *>(baseState.get());
+
+ Type newConstValueType;
+ Attribute newConstValue = quantizeAttr(
+ state->value, state->quantizedElementType, newConstValueType);
+ if (!newConstValue) {
+ return;
+ }
+
+ auto *origConstOp = op->getOperand(0);
+ // When creating the new const op, use a fused location that combines the
+ // original const and the qbarrier that led to the quantization.
+ auto fusedLoc =
+ FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()},
+ rewriter.getContext());
+ auto newConstOp =
+ rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
+ rewriter.replaceOpWithNewOp<StorageCastOp>(
+ op, {origConstOp}, *op->result_type_begin(), newConstOp);
+}
+
+void ConvertConstPass::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto &func = getFunction();
+ auto *context = &getContext();
+ patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
+ applyPatternsGreedily(func, std::move(patterns));
+}
+
+FunctionPassBase *createConvertConstPass() { return new ConvertConstPass(); }
+
+static PassRegistration<ConvertConstPass>
+ pass("quant-convert-const",
+ "Converts constants followed by qbarrier to actual quantized values");
--- /dev/null
+//===- LowerTF.cpp - Passes for lowering from TensorFlow ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Quantization/FakeQuantSupport.h"
+#include "mlir/Quantization/Passes.h"
+#include "mlir/Quantization/QuantOps.h"
+#include "mlir/Quantization/UniformSupport.h"
+#include "mlir/TensorFlow/TFOps.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+class LowerTFPass : public FunctionPass<LowerTFPass> {
+public:
+ void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+/// Rewrites TensorFlow FakeQuantWithMinMaxArgs into a qbarrier/dbarrier pair.
+class FakeQuantWithMinMaxArgsRewrite : public RewritePattern {
+public:
+ bool *hadFailure;
+
+ FakeQuantWithMinMaxArgsRewrite(MLIRContext *context, bool *hadFailure)
+ : RewritePattern(TF::FakeQuantWithMinMaxArgsOp::getOperationName(), 1,
+ context),
+ hadFailure(hadFailure) {}
+
+ PatternMatchResult match(Operation *op) const override {
+ return matchSuccess();
+ }
+
+ void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+ // TODO: If this pattern comes up more frequently, consider adding core
+ // support for failable rewrites.
+ if (failableRewrite(op, rewriter)) {
+ *hadFailure = true;
+ }
+ }
+
+ bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
+ auto fqOp = op->template cast<TF::FakeQuantWithMinMaxArgsOp>();
+
+ auto converter =
+ ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
+ if (!converter) {
+ return (op->emitError("unsupported quantized type conversion"), true);
+ }
+
+ UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
+ fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+ fqOp.min().convertToDouble(), fqOp.max().convertToDouble(),
+ fqOp.narrow_range(), converter.expressedType);
+
+ if (!uniformElementType) {
+ // Note that the fakeQuantAttrsToType will have emitted the error.
+ return true;
+ }
+
+ Type quantizedType = converter.convert(uniformElementType);
+ assert(quantizedType &&
+ "Converter accepted a type that it did not convert");
+
+ // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
+ // this is a forced/hard-coded constraint.
+ auto qbarrier = rewriter.create<QuantizeBarrierOp>(
+ op->getLoc(), quantizedType, fqOp.inputs());
+ rewriter.replaceOpWithNewOp<DequantizeBarrierOp>(op, converter.inputType,
+ qbarrier.getResult());
+
+ return false;
+ }
+};
+
+void LowerTFPass::runOnFunction() {
+ bool hadFailure = false;
+ OwningRewritePatternList patterns;
+ auto &func = getFunction();
+ auto *context = &getContext();
+ patterns.push_back(
+ llvm::make_unique<FakeQuantWithMinMaxArgsRewrite>(context, &hadFailure));
+ applyPatternsGreedily(func, std::move(patterns));
+ if (hadFailure)
+ signalPassFailure();
+}
+
+FunctionPassBase *createLowerTFPass() { return new LowerTFPass(); }
+
+static PassRegistration<LowerTFPass>
+ pass("quant-lower-tf",
+ "Lowers TensorFlow constraint ops to the quantization dialect");
--- /dev/null
+//===- LowerUniformRealMath.cpp ------------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Quantization/Passes.h"
+#include "mlir/Quantization/QuantOps.h"
+#include "mlir/Quantization/UniformSupport.h"
+
+#include <functional>
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+struct LowerUniformRealMathPass
+ : public FunctionPass<LowerUniformRealMathPass> {
+ void runOnFunction() override;
+};
+
+UniformQuantizedType getUniformElementType(Type t) {
+ return QuantizedType::getQuantizedElementType(t)
+ .dyn_cast_or_null<UniformQuantizedType>();
+}
+
+/// Computes the log2(x), rounded to an integral value. Returns whether 'x' can
+/// be considered an exact integral value.
+template <typename F> bool integralLog2(F x, int &log2Result) {
+ const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
+ const F xLog2Rounded = std::round(xLog2);
+ const F xLog2Frac = xLog2 - xLog2Rounded;
+ log2Result = static_cast<int>(xLog2Rounded);
+ // Allow small comparison slop below the level that would make a difference
+ // for 2^16 levels.
+ return std::abs(xLog2Frac) < 1e-6;
+}
+
+/// Helper class for operating on binary operations where all operands
+/// and the result are a UniformQuantizedType.
+struct RealBinaryOpInfo {
+ RealBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
+ Optional<APFloat> clampMin, Optional<APFloat> clampMax)
+ : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
+ lhsType(getUniformElementType(lhs->getType())),
+ rhsType(getUniformElementType(rhs->getType())),
+ resultType(getUniformElementType(*op->result_type_begin())),
+ lhsStorageType(QuantizedType::castToStorageType(lhs->getType())),
+ rhsStorageType(QuantizedType::castToStorageType(rhs->getType())),
+ resultStorageType(
+ QuantizedType::castToStorageType(*op->result_type_begin())) {}
+
+ /// Returns whether this info is valid (all types defined, etc).
+ bool isValid() const {
+ return lhsType && rhsType && resultType && lhsStorageType &&
+ rhsStorageType && resultStorageType;
+ }
+
+ /// Returns whether the storage type of all operands is identical.
+ bool isSameStorageType() const {
+ return lhsType.getStorageType() == rhsType.getStorageType() &&
+ lhsType.getStorageType() == resultType.getStorageType();
+ }
+
+ /// Returns whether all operands and result are considered fixedpoint power
+ /// of two, setting the lhs, rhs, and result log2 scale references.
+ bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
+ int &resultLog2Scale) const {
+ if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
+ !resultType.isFixedPoint()) {
+ return false;
+ }
+
+ if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
+ !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
+ !integralLog2(resultType.getScale(), resultLog2Scale)) {
+ return false;
+ }
+
+ return true;
+ }
+
+ /// Gets the result integer clamp range given the result quantized type
+ // and any explicit clamp provided as attributes.
+ std::pair<IntegerAttr, IntegerAttr> getClampMinMax() const {
+ int64_t typeMin = resultType.getStorageTypeMin();
+ int64_t typeMax = resultType.getStorageTypeMax();
+
+ if (clampMin || clampMax) {
+ UniformQuantizedValueConverter conv(resultType);
+ if (clampMin) {
+ typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
+ }
+ if (clampMax) {
+ typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
+ }
+ }
+
+ // The quantized, integral ops expect clamps as 32bit ints.
+ return {
+ IntegerAttr::get(IntegerType::get(32, resultType.getContext()),
+ typeMin),
+ IntegerAttr::get(IntegerType::get(32, resultType.getContext()),
+ typeMax),
+ };
+ }
+
+ Operation *op;
+ Value *lhs;
+ Value *rhs;
+ Optional<APFloat> clampMin;
+ Optional<APFloat> clampMax;
+
+ // Element UniformQuantizedType for operands/result.
+ UniformQuantizedType lhsType;
+ UniformQuantizedType rhsType;
+ UniformQuantizedType resultType;
+
+ // Full storage-based types.
+ Type lhsStorageType;
+ Type rhsStorageType;
+ Type resultStorageType;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Elementwise add
+//===----------------------------------------------------------------------===//
+/// Attempts to rewrite a fixed point power-of-two addition of two integers.
+/// This supports a limited number of cases, but when supported, represents
+/// the simplest computation.
+static LogicalResult tryRewriteFixedPOTAddEw(const RealBinaryOpInfo &constInfo,
+ PatternRewriter &rewriter) {
+ if (!constInfo.isSameStorageType()) {
+ return failure();
+ }
+
+ int lhsLog2Scale;
+ int rhsLog2Scale;
+ int resultLog2Scale;
+ if (!constInfo.isFixedPointPOT(lhsLog2Scale, rhsLog2Scale, resultLog2Scale)) {
+ return failure();
+ }
+
+ // Adjust shifts to be relative to the output.
+ // Left shift of one input scale is supported. The other must match the result
+ // scale.
+ int lhsScaleShift = lhsLog2Scale - resultLog2Scale;
+ int rhsScaleShift = rhsLog2Scale - resultLog2Scale;
+ if (lhsScaleShift != 0 && rhsScaleShift != 0) {
+ return failure();
+ }
+ if (lhsScaleShift > 0 || rhsScaleShift > 0) {
+ return failure();
+ }
+
+ // State accessed by the closure.
+ Operation *mathOp = constInfo.op;
+ const auto clampMinMax = constInfo.getClampMinMax();
+ Value *lhs = constInfo.lhs;
+ Value *rhs = constInfo.rhs;
+ Type lhsStorageType = constInfo.lhsStorageType;
+ Type rhsStorageType = constInfo.rhsStorageType;
+
+ // If the lhs operand is the one requiring a shift, swap it so that the shift
+ // happens the rhs operand.
+ if (lhsScaleShift != 0) {
+ std::swap(lhs, rhs);
+ std::swap(lhsStorageType, rhsStorageType);
+ std::swap(lhsScaleShift, rhsScaleShift);
+ }
+ int rhsRightShift = -rhsScaleShift;
+
+ // Cast operands to storage type.
+ Value *lhsStorageValue =
+ rewriter.create<StorageCastOp>(mathOp->getLoc(), lhsStorageType, lhs)
+ .getResult();
+ Value *rhsStorageValue =
+ rewriter.create<StorageCastOp>(mathOp->getLoc(), rhsStorageType, rhs)
+ .getResult();
+
+ // Rescale the rhs operand if needed.
+ if (rhsRightShift != 0) {
+ rhsStorageValue =
+ rewriter
+ .create<RoundingDivideByPotIOp>(
+ mathOp->getLoc(), rhsStorageValue,
+ IntegerAttr::get(IntegerType::get(32, rewriter.getContext()),
+ rhsRightShift))
+ .getResult();
+ }
+
+ // Add.
+ Value *sumValue = rewriter.create<SaturatingAddIOp>(
+ mathOp->getLoc(), lhsStorageValue, rhsStorageValue, clampMinMax.first,
+ clampMinMax.second);
+
+ // Cast back for new result.
+ rewriter.replaceOpWithNewOp<StorageCastOp>(
+ mathOp, *mathOp->result_type_begin(), sumValue);
+ return success();
+}
+
+namespace {
+
+struct UniformRealAddEwPattern : public RewritePattern {
+ UniformRealAddEwPattern(MLIRContext *context)
+ : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const {
+ auto addOp = op->cast<RealAddEwOp>();
+ const RealBinaryOpInfo info(op, addOp.x(), addOp.y(), addOp.clamp_min(),
+ addOp.clamp_max());
+ if (!info.isValid()) {
+ return matchFailure();
+ }
+
+ // Try all of the permutations we support.
+ if (succeeded(tryRewriteFixedPOTAddEw(info, rewriter))) {
+ return matchSuccess();
+ }
+
+ return matchFailure();
+ }
+};
+
+} // end anonymous namespace
+
+void LowerUniformRealMathPass::runOnFunction() {
+ auto &fn = getFunction();
+ OwningRewritePatternList patterns;
+ auto *context = &getContext();
+ patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
+ applyPatternsGreedily(fn, std::move(patterns));
+}
+
+FunctionPassBase *createLowerUniformRealMathPass() {
+ return new LowerUniformRealMathPass();
+}
+
+static PassRegistration<LowerUniformRealMathPass>
+ pass("quant-lower-uniform-real-math",
+ "Lowers uniform-quantized real math ops to integer arithmetic.");
--- /dev/null
+//===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantization/QuantizeUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Quantization/UniformSupport.h"
+
+namespace mlir {
+namespace quant {
+/// Converts a possible primitive, real expressed value attribute to a
+/// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
+/// quantizedElementType is the QuantizedType that describes the expressed
+/// origValue.
+/// Returns a converter Attribute or nullptr if conversion is not possible.
+static Attribute convertPrimitiveValueAttr(
+ Attribute origRealValue, QuantizedType quantizedElementType,
+ const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
+ if (origRealValue.isa<FloatAttr>()) {
+ FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
+ outConvertedType = quantizedElementType.getStorageType();
+ return IntegerAttr::get(quantizedElementType.getStorageType(),
+ converter.quantizeFloatToInt(floatAttr.getValue()));
+ }
+
+ return nullptr;
+}
+
+/// Converts a real expressed DenseFPElementsAttr to a corresponding
+/// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
+/// storage values assuming the given quantizedElementType and converter.
+static DenseElementsAttr
+convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
+ QuantizedType quantizedElementType,
+ const UniformQuantizedValueConverter &converter) {
+ // Read real expressed values.
+ SmallVector<APFloat, 8> realValues;
+ realValues.reserve(realFPElementsAttr.getType().getNumElements());
+ realFPElementsAttr.getValues(realValues);
+
+ // Convert to corresponding quantized value attributes.
+ SmallVector<APInt, 8> quantValues(realValues.size());
+ for (size_t i = 0, e = realValues.size(); i < e; ++i) {
+ quantValues[i] = converter.quantizeFloatToInt(realValues[i]);
+ }
+
+ // Cast from an expressed-type-based type to storage-type-based type,
+ // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+ VectorOrTensorType newDenseType =
+ quantizedElementType
+ .castExpressedToStorageType(realFPElementsAttr.getType())
+ .dyn_cast_or_null<VectorOrTensorType>();
+ if (!newDenseType) {
+ return nullptr;
+ }
+ return DenseIntElementsAttr::get(newDenseType, quantValues);
+}
+
+/// Converts a real expressed SplatElementsAttr to a corresponding
+/// SplatElementsAttr containing quantized storage values assuming the given
+/// quantizedElementType and converter.
+static SplatElementsAttr
+convertSplatElementsAttr(SplatElementsAttr realSplatAttr,
+ QuantizedType quantizedElementType,
+ const UniformQuantizedValueConverter &converter) {
+ // Since the splat just references a single primitive value, use the
+ // function for converting primitives.
+ // NOTE: When implementing per-channel, we will need to promote the
+ // splat to a dense and handle channels individually.
+ Type unusedPrimitiveType;
+ auto elementAttr =
+ convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType,
+ converter, unusedPrimitiveType);
+ if (!elementAttr) {
+ return nullptr;
+ }
+
+ // Cast from an expressed-type-based type to storage-type-based type,
+ // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+ VectorOrTensorType newSplatType =
+ quantizedElementType.castExpressedToStorageType(realSplatAttr.getType())
+ .dyn_cast_or_null<VectorOrTensorType>();
+ if (!newSplatType) {
+ return nullptr;
+ }
+ return SplatElementsAttr::get(newSplatType, elementAttr);
+}
+
+/// Converts a real expressed SplatElementsAttr to a corresponding
+/// SplatElementsAttr containing quantized storage values assuming the given
+/// quantizedElementType and converter.
+static SparseElementsAttr
+convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
+ QuantizedType quantizedElementType,
+ const UniformQuantizedValueConverter &converter) {
+ DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
+ if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
+ return nullptr;
+ }
+ DenseElementsAttr quantDenseAttr =
+ convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
+ quantizedElementType, converter);
+ if (!quantDenseAttr) {
+ return nullptr;
+ }
+
+ // Cast from an expressed-type-based type to storage-type-based type,
+ // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+ VectorOrTensorType newSparseType =
+ quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
+ .dyn_cast_or_null<VectorOrTensorType>();
+ if (!newSparseType) {
+ return nullptr;
+ }
+ return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
+ quantDenseAttr);
+}
+
+/// Converts a real expressed Attribute to a corresponding Attribute containing
+/// quantized storage values assuming the given uniform quantizedElementType and
+/// converter.
+Attribute quantizeAttrUniform(Attribute realValue,
+ UniformQuantizedType quantizedElementType,
+ const UniformQuantizedValueConverter &converter,
+ Type &outConvertedType) {
+ // Fork to handle different variants of constants supported.
+ if (realValue.isa<SplatElementsAttr>()) {
+ // Splatted tensor or vector constant.
+ auto converted = convertSplatElementsAttr(
+ realValue.cast<SplatElementsAttr>(), quantizedElementType, converter);
+ outConvertedType = converted.getType();
+ return converted;
+ } else if (realValue.isa<DenseFPElementsAttr>()) {
+ // Dense tensor or vector constant.
+ auto converted = convertDenseFPElementsAttr(
+ realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
+ outConvertedType = converted.getType();
+ return converted;
+ } else if (realValue.isa<SparseElementsAttr>()) {
+ // Sparse tensor or vector constant.
+ auto converted = convertSparseElementsAttr(
+ realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
+ outConvertedType = converted.getType();
+ return converted;
+ } else {
+ // Nothing else matched: try to convert a primitive.
+ return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
+ outConvertedType);
+ }
+}
+
+/// Convert an attribute from a type based on
+/// quantizedElementType.getExpressedType() to one based on
+/// quantizedElementType.getStorageType().
+/// Returns nullptr if the conversion is not supported.
+/// On success, stores the converted type in outConvertedType.
+Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
+ Type &outConvertedType) {
+ // Hard-coded to just support UniformQuantizedType. This will need to
+ // be generalized when there is more than one.
+ auto uniformQuantizedType =
+ quantizedElementType.dyn_cast<UniformQuantizedType>();
+ if (!uniformQuantizedType) {
+ return nullptr;
+ }
+ UniformQuantizedValueConverter converter(uniformQuantizedType);
+ return quantizeAttrUniform(realValue, uniformQuantizedType, converter,
+ outConvertedType);
+}
+
+} // namespace quant
+} // namespace mlir
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -quant-convert-const | FileCheck %s --dump-input=fail
+
+// Magic numbers:
+// 7.8125e-03 = 1/128 = 2/256 : real range = [-1.0, 0.9921875] (for 8bit, zeroPoint=128)
+// 1.250000e-01 = 1/8 = 2/16 : real range = [-1.0, 0.875] (for 4bit, zeroPoint=8)
+
+// -----
+// Verifies u8 affine quantization on a splat tensor.
+// Note that MLIR prints int attributes as signed, so the constant, when
+// quantized, is the signed printed version of an unsigned quantity
+// (-64 signed == 192 unsigned).
+// CHECK-LABEL: constant_splat_tensor_u8_affine
+func @constant_splat_tensor_u8_affine() -> tensor<4xf32> {
+ // CHECK: %cst = constant splat<tensor<4xi8>, -64> : tensor<4xi8>
+ // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
+ %cst = constant splat<tensor<4xf32>, 0.5> : tensor<4xf32>
+ %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
+ %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) -> (tensor<4xf32>)
+ return %2 : tensor<4xf32>
+}
+
+// -----
+// Verifies i8 affine quantization on a splat tensor.
+// CHECK-LABEL: constant_splat_tensor_i8_affine
+func @constant_splat_tensor_i8_affine() -> tensor<4xf32> {
+ // CHECK: %cst = constant splat<tensor<4xi8>, 63> : tensor<4xi8>
+ // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>
+ %cst = constant splat<tensor<4xf32>, 0.5> : tensor<4xf32>
+ %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>
+ %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>) -> (tensor<4xf32>)
+ return %2 : tensor<4xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a splat tensor.
+// CHECK-LABEL: const_splat_tensor_i8_fixedpoint
+func @const_splat_tensor_i8_fixedpoint() -> tensor<4xf32> {
+ // CHECK: %cst = constant splat<tensor<4xi8>, 64> : tensor<4xi8>
+ // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>
+ %cst = constant splat<tensor<4xf32>, 0.5> : tensor<4xf32>
+ %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>
+ %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>)
+ return %2 : tensor<4xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a splat tensor resulting in a negative storage value.
+// CHECK-LABEL: const_splat_tensor_i8_fixedpoint_neg
+func @const_splat_tensor_i8_fixedpoint_neg() -> tensor<4xf32> {
+ // CHECK: %cst = constant splat<tensor<4xi8>, -64> : tensor<4xi8>
+ %cst = constant splat<tensor<4xf32>, -0.5> : tensor<4xf32>
+ %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>
+ %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>)
+ return %2 : tensor<4xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a dense tensor, sweeping values.
+// CHECK-LABEL: const_dense_tensor_i8_fixedpoint
+func @const_dense_tensor_i8_fixedpoint() -> tensor<7xf32> {
+ // CHECK: %cst = constant dense<tensor<7xi8>, [-128, -128, -64, 0, 64, 127, 127]> : tensor<7xi8>
+ %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
+ %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">>
+ %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7xf32>)
+ return %2 : tensor<7xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a sparse tensor, sweeping values.
+// CHECK-LABEL: const_sparse_tensor_i8_fixedpoint
+func @const_sparse_tensor_i8_fixedpoint() -> tensor<7x2xf32> {
+ // NOTE: Ugly regex match pattern for opening "[[" of indices tensor.
+ // CHECK: %cst = constant sparse<tensor<7x2xi8>, {{\[}}[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], [-128, -128, -64, 0, 64, 127, 127]> : tensor<7x2xi8>
+ %cst = constant sparse<tensor<7x2xf32>,
+ [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]],
+ [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7x2xf32>
+ %1 = "quant.qbarrier"(%cst) : (tensor<7x2xf32>) -> tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">>
+ %2 = "quant.dbarrier"(%1) : (tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7x2xf32>)
+ return %2 : tensor<7x2xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a primitive const.
+// CHECK-LABEL: const_primitive_float_i8_fixedpoint
+func @const_primitive_float_i8_fixedpoint() -> f32 {
+ // CHECK: %c64_i8 = constant 64 : i8
+ // CHECK-NEXT: %0 = "quant.scast"(%c64_i8) : (i8) -> !quant<"uniform[i8:f32]{7.812500e-03}">
+ %cst = constant 0.5 : f32
+ %1 = "quant.qbarrier"(%cst) : (f32) -> !quant<"uniform[i8:f32]{7.812500e-03}">
+ %2 = "quant.dbarrier"(%1) : (!quant<"uniform[i8:f32]{7.812500e-03}">) -> (f32)
+ return %2 : f32
+}
+
+// -----
+// Verifies u4 affine quantization on a dense tensor, sweeping values.
+// CHECK-LABEL: const_dense_tensor_u4_affine
+func @const_dense_tensor_u4_affine() -> tensor<7xf32> {
+ // NOTE: Unsigned quantities printed by MLIR as signed.
+ // CHECK: %cst = constant dense<tensor<7xi4>, [0, 0, 4, -8, -4, -1, -1]> : tensor<7xi4>
+ %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
+ %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">>
+ %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">>) -> (tensor<7xf32>)
+ return %2 : tensor<7xf32>
+}
+
+// -----
+// Verifies i4 affine quantization on a dense tensor, sweeping values.
+// CHECK-LABEL: const_dense_tensor_i4_affine
+func @const_dense_tensor_i4_affine() -> tensor<7xf32> {
+ // NOTE: Unsigned quantities printed by MLIR as signed.
+ // CHECK: %cst = constant dense<tensor<7xi4>, [-8, -8, -5, -1, 3, 7, 7]> : tensor<7xi4>
+ %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
+ %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">>
+ %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">>) -> (tensor<7xf32>)
+ return %2 : tensor<7xf32>
+}
+
+// -----
+// Verifies i4 fixed point quantization on a dense tensor, sweeping values.
+// CHECK-LABEL: const_dense_tensor_i4_fixedpoint
+func @const_dense_tensor_i4_fixedpoint() -> tensor<7xf32> {
+ // CHECK: %cst = constant dense<tensor<7xi4>, [-8, -8, -4, 0, 4, 7, 7]> : tensor<7xi4>
+ %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
+ %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">>
+ %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">>) -> (tensor<7xf32>)
+ return %2 : tensor<7xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a dense tensor, sweeping values, and
+// custom storage range. (the -128 should be clamped to -100, and the 127 should
+// be clamped to 100).
+// CHECK-LABEL: const_custom_storage_range_i8_fixedpoint
+func @const_custom_storage_range_i8_fixedpoint() -> tensor<7xf32> {
+ // CHECK: %cst = constant dense<tensor<7xi8>, [-100, -100, -64, 0, 64, 100, 100]> : tensor<7xi8>
+ %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
+ %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">>
+ %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">>) -> (tensor<7xf32>)
+ return %2 : tensor<7xf32>
+}
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -quant-lower-uniform-real-math | FileCheck %s --dump-input=fail
+
+// -----
+// Verify lowering when operands and result have the same fixedpoint pot scale.
+// CHECK-LABEL: real_addew_fixedpoint_same_scale
+// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+// CHECK-NEXT: %2 = "quant.saturating_addi"(%0, %1) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %3 = "quant.scast"(%2) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+// CHECK-NEXT: return %3 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_fixedpoint_same_scale(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+ %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+ return %0 : !type_result
+}
+
+// -----
+// Verify lowering when the rhs is a shifted pot scale compared to lhs and result.
+// CHECK-LABEL: real_addew_fixedpoint_rhs_shift
+// CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
+// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_fixedpoint_rhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+ %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+ return %0 : !type_result
+}
+
+// -----
+// Verify lowering when the lhs is a shifted pot scale compared to lhs and result.
+// CHECK-LABEL: real_addew_fixedpoint_lhs_shift
+// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
+// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+ %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+ return %0 : !type_result
+}
+
+// -----
+// The RHS quant parameters proscribe a range of [-8..8) so an explicit clamp
+// of [-4..4] should result in an integral clamp range of [-64..64].
+// CHECK-LABEL: real_addew_fixedpoint_clamp
+// CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
+// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 64 : i32, clamp_min: -64 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+ %0 = "quant.real_add_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 }
+ : (!type_lhs, !type_rhs) -> (!type_result)
+ return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: real_addew_unquantized_lhs
+// Verifies that leaves as-is for unquantized lhs.
+!type_lhs = type tensor<4xf32>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+ // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1)
+ %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+ return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: real_addew_unquantized_rhs
+// Verifies that leaves as-is for unquantized rhs.
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4xf32>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+ // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1)
+ %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+ return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: real_addew_unquantized_result
+// Verifies that leaves as-is for unquantized result.
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4xf32>
+func @real_addew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+ // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1)
+ %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+ return %0 : !type_result
+}
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -verify
+
+// -----
+// Unknown type.
+// expected-error@+1 {{unknown quantized type foobar}}
+!qalias = type !quant<"foobar">
+
+// -----
+// Unrecognized token: illegal token
+// expected-error@+1 {{unrecognized token: %}}
+!qalias = type !quant<"%%">
+
+// -----
+// Unrecognized token: trailing
+// expected-error@+1 {{unrecognized token: 23}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{0.99872:127} 23">
+
+// -----
+// Unrecognized token: type open
+// expected-error@+1 {{unrecognized token: (}}
+!qalias = type !quant<"uniform(i8(-4:3):f32){0.99872:127}">
+
+// -----
+// Unrecognized token: missing storage type maximum
+// expected-error@+1 {{expected storage type maximum}}
+!qalias = type !quant<"uniform[i8(16:f32]{0.99872:127}">
+
+// -----
+// Unrecognized token: missing closing paren
+// expected-error@+1 {{unrecognized token: :}}
+!qalias = type !quant<"uniform[i8(-4:3:f32]{0.99872:127}">
+
+// -----
+// Unrecognized token: missing type colon
+// expected-error@+1 {{unrecognized token: f}}
+!qalias = type !quant<"uniform[i8(-4:3)f32]{0.99872:127}">
+
+// -----
+// Unrecognized token: missing closing bracket
+// expected-error@+1 {{unrecognized token: {}}
+!qalias = type !quant<"uniform[i8(-4:3):f32{0.99872:127}">
+
+// -----
+// Unrecognized token: wrong opening brace
+// expected-error@+1 {{unrecognized token: (}}
+!qalias = type !quant<"uniform[i8(-4:3):f32](0.99872:127}">
+
+// -----
+// Unrecognized storage type: illegal prefix
+// expected-error@+1 {{illegal storage type prefix: int}}
+!qalias = type !quant<"uniform[int8(-4:3):f32]{0.99872:127}">
+
+// -----
+// Unrecognized storage type: no width
+// expected-error@+1 {{expected storage type width}}
+!qalias = type !quant<"uniform[i(-4:3):f32]{0.99872:127}">
+
+// -----
+// Unrecognized storage type: storage size > 32
+// expected-error@+1 {{illegal storage type size: 33}}
+!qalias = type !quant<"uniform[i33:f32]{0.99872:127}">
+
+// -----
+// Unrecognized storage type: storage size < 0
+// expected-error@+1 {{illegal storage type size: -1}}
+!qalias = type !quant<"uniform[i-1(-4:3):f32]{0.99872:127}">
+
+// -----
+// Unrecognized storage type: storage size == 0
+// expected-error@+1 {{illegal storage type size: 0}}
+!qalias = type !quant<"uniform[i0(-4:3):f32]{0.99872:127}">
+
+// -----
+// Illegal storage min/max: max - min < 0
+// expected-error@+1 {{illegal storage min and storage max: (2:1)}}
+!qalias = type !quant<"uniform[i8(2:1):f32]{0.99872:127}">
+
+// -----
+// Illegal storage min/max: max - min == 0
+// expected-error@+1 {{illegal storage min and storage max: (1:1)}}
+!qalias = type !quant<"uniform[i8(1:1):f32]{0.99872:127}">
+
+// -----
+// Illegal storage min/max: max > defaultMax
+// expected-error@+1 {{illegal storage type maximum: 9}}
+!qalias = type !quant<"uniform[i4(-1:9):f32]{0.99872:127}">
+
+// -----
+// Illegal storage min/max: min < defaultMin
+// expected-error@+1 {{illegal storage type minimum: -9}}
+!qalias = type !quant<"uniform[i4(-9:1):f32]{0.99872:127}">
+
+// -----
+// Illegal uniform params: invalid scale
+// expected-error@+1 {{expected valid uniform scale. got: abc}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{abc:127}">
+
+// -----
+// Illegal uniform params: invalid zero point separator
+// expected-error@+1 {{unrecognized token: abc}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1abc}">
+
+// -----
+// Illegal uniform params: missing zero point
+// expected-error@+1 {{expected integer uniform zero point. got: }}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1:}">
+
+// -----
+// Illegal uniform params: invalid zero point
+// expected-error@+1 {{expected integer uniform zero point. got: abc}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1:abc}">
+
+// -----
+// Illegal uniform params: missing closing brace
+// expected-error@+1 {{unrecognized token: )}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1:0)">
+
+// -----
+// Illegal expressed type: f33
+// expected-error@+1 {{unrecognized expressed type: f33}}
+!qalias = type !quant<"uniform[i8(-4:3):f33]{0.99872:127}">
+
+// -----
+// Illegal scale: negative
+// expected-error@+1 {{illegal scale: -1.000000}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{-1.0:127}">
+
+// -----
+// Illegal uniform params: missing quantized dimension
+// expected-error@+1 {{expected quantized dimension}}
+!qalias = type !quant<"uniform[i8(-4:3):f32:]{2.000000e+02:-19.987200e-01:1}">
+
+// -----
+// Illegal uniform params: unspecified quantized dimension, when multiple scales
+// provided.
+// expected-error@+1 {{multiple scales/zeroPoints provided, but quantizedDimension wasn't specified}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{2.000000e+02,-19.987200e-01:1}">
--- /dev/null
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// -----
+// All per-layer params specified:
+// [signed] storageType, storageTypeMin, storageTypeMax, expressedType, scale, zeroPoint
+// CHECK: !quant<"uniform[i8(-8:7):f32]{9.987200e-01:127}">
+!qalias = type !quant<"uniform[i8(-8:7):f32]{0.99872:127}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Trailing whitespace.
+// CHECK: !quant<"uniform[i8(-8:7):f32]{9.987200e-01:127}">
+!qalias = type !quant<"uniform[i8(-8:7):f32]{0.99872:127} ">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Required per-layer params specified:
+// [unsigned] storageType, expressedType, scale
+// CHECK: !quant<"uniform[u8:f32]{9.987200e-01}">
+!qalias = type !quant<"uniform[u8:f32]{0.99872}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Exponential scale (-)
+// CHECK: !quant<"uniform[u8:f32]{2.000000e-02}">
+!qalias = type !quant<"uniform[u8:f32]{2.0e-2}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Exponential scale (+)
+// CHECK: !quant<"uniform[u8:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[u8:f32]{2.0e+2}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Storage type: i16
+// CHECK: !quant<"uniform[i16:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[i16:f32]{2.0e+2}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Storage type: u16
+// CHECK: !quant<"uniform[u16:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[u16:f32]{2.0e+2}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Storage type: i32
+// CHECK: !quant<"uniform[i32:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[i32:f32]{2.0e+2}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Storage type: u32
+// CHECK: !quant<"uniform[u32:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[u32:f32]{2.0e+2}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Expressed type: f32
+// CHECK: !quant<"uniform[u8:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[u8:f32]{2.0e+2}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Expressed type: f16
+// CHECK: !quant<"uniform[u8:f16]{2.000000e+02}">
+!qalias = type !quant<"uniform[u8:f16]{2.0e+2}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Expressed type: f64
+// CHECK: !quant<"uniform[u8:f64]{2.000000e+02}">
+!qalias = type !quant<"uniform[u8:f64]{2.0e+2}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Expressed type: bf16
+// CHECK: !quant<"uniform[u8:bf16]{2.000000e+02}">
+!qalias = type !quant<"uniform[u8:bf16]{2.0e+2}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Per-axis scales and zero points (affine)
+// CHECK: !quant<"uniform[u8:f32:1]{2.000000e+02:-120,9.987200e-01:127}">
+!qalias = type !quant<"uniform[u8:f32:1]{2.0e+2:-120,0.99872:127}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Per-axis scales and no zero points (fixedpoint)
+// CHECK: !quant<"uniform[i8:f32:1]{2.000000e+02,9.987200e-01}">
+!qalias = type !quant<"uniform[i8:f32:1]{2.0e+2,0.99872}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
+// -----
+// Per-axis scales and zero points (mixed affine and fixedpoint)
+// CHECK: !quant<"uniform[i8:f32:1]{2.000000e+02,9.987200e-01:120}">
+!qalias = type !quant<"uniform[i8:f32:1]{2.0e+2,0.99872:120}">
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -verify -quant-lower-tf
+
+// -----
+// TODO(laurenzo): move this test to the TensorFlow/tf-ops-invalid.mlir
+// Verify that a mismatched range errors.
+func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+ // expected-error@+1 {{op range failed to straddle zero: [1.100000,1.500000]}}
+ %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+ min: 1.1, max: 1.5, num_bits: 8
+ } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+ return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verify that a valid range errors.
+func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+ // expected-error@+1 {{op range is invalid: [1.100000,1.000000}}
+ %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+ min: 1.1, max: 1.0, num_bits: 8
+ } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+ return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// TODO(laurenzo): move this test to the TensorFlow/tf-ops-invalid.mlir
+// Unsupported quantizable type (i1 is currently not a supported element type).
+func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> {
+^bb0(%arg0: tensor<8x4x3xi1>):
+ // expected-error@+1 {{op operand #0 must be tensor of 32-bit float values}}
+ %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+ min: 1.1, max: 1.0, num_bits: 8
+ } : (tensor<8x4x3xi1>) -> tensor<8x4x3xi1>
+ return %0 : tensor<8x4x3xi1>
+}
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -quant-lower-tf | FileCheck %s --dump-input=fail
+
+// -----
+// Verifies a quint8 asymmetric 0..1 range.
+// CHECK-LABEL: fakeQuantArgs_Quint8_0_1
+func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+ // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
+ // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>
+ // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
+ // CHECK-SAME: -> tensor<8x4x3xf32>
+ %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+ min: 0.0, max: 1.0, num_bits: 8
+ } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+ return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verifies a quint8 asymmetric 0..1 range (with narrow_range = true).
+// CHECK_LABEL: fakeQuantArgs_Quint8_NarrowRange
+func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+ // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
+ // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>
+ // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>)
+ // CHECK-SAME: -> tensor<8x4x3xf32>
+ %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+ min: 0.0, max: 1.0, num_bits: 8, narrow_range: true
+ } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+ return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verifies a quint8 symmetric range of -1..127/128.
+// CHECK_LABEL: fakeQuantArgs_Quint8_SymmetricRange
+func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+ // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
+ // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
+ // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>)
+ // CHECK-SAME: -> tensor<8x4x3xf32>
+ %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+ min: -1.0, max: 0.9921875, num_bits: 8, narrow_range: false
+ } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+ return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verifies a commonly used -1..1 symmetric 16bit range with a zero point of
+// 0 and range -1.0 .. 32767/32768.
+// CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric
+func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+ // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
+ // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">>
+ // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">>)
+ // CHECK-SAME: -> tensor<8x4x3xf32>
+ %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+ min: -1.0, max: 0.999969482, num_bits: 16
+ } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+ return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verify that lowering to barriers of unranked tensors functions.
+// CHECK-LABEL: fakeQuantArgs_UnrankedTensor
+func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
+^bb0(%arg0: tensor<f32>):
+ // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<f32>)
+ // CHECK-SAME: -> tensor<!quant<"uniform[u8:f32]{0.0039215686274509803}">>
+ // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
+ // CHECK-SAME: -> tensor<f32>
+ %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+ min: 0.0, max: 1.0, num_bits: 8
+ } : (tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
--- /dev/null
+//===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Quantization/QuantizeUtils.h"
+#include "mlir/Quantization/UniformSupport.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+// Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
+class TestUniformQuantizedValueConverter
+ : public UniformQuantizedValueConverter {
+public:
+ TestUniformQuantizedValueConverter(UniformQuantizedType type)
+ : UniformQuantizedValueConverter(type), qtype(type) {}
+ APInt quantizeFloatToInt(APFloat expressedValue) const {
+ return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L);
+ }
+
+private:
+ UniformQuantizedType qtype;
+};
+
+Attribute getTestFloatAttr(double value, MLIRContext *ctx) {
+ return FloatAttr::get(FloatType::getF32(ctx), value);
+}
+
+template <typename ConcreteAttrClass, typename... Arg>
+ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
+ Arg... value) {
+ auto eleType = FloatType::getF32(ctx);
+ VectorOrTensorType tensorType;
+ if (shape.size() == 1 && shape[0] == -1) {
+ tensorType = UnrankedTensorType::get(eleType);
+ } else {
+ tensorType = RankedTensorType::get(shape, eleType);
+ }
+ return ConcreteAttrClass::get(tensorType, value...);
+}
+
+ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
+ ArrayRef<int64_t> shape) {
+ auto eleType = FloatType::getF32(ctx);
+ VectorOrTensorType tensorType;
+ if (shape.size() == 1 && shape[0] == -1) {
+ tensorType = UnrankedTensorType::get(eleType);
+ } else {
+ tensorType = RankedTensorType::get(shape, eleType);
+ }
+ auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx));
+ auto indices =
+ DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
+ auto valuesType = RankedTensorType::get({1}, eleType);
+ auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)});
+ return SparseElementsAttr::get(tensorType, indices, values);
+}
+
+UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
+ return UniformQuantizedType::get(/*flags=*/false, storageType,
+ FloatType::getF32(ctx), /*scale=*/1.0,
+ /*zeroPoint=*/0, /*storageTypeMin=*/0,
+ /*storageTypeMax=*/255);
+}
+
+TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
+ MLIRContext ctx;
+ IntegerType convertedType = IntegerType::get(8, &ctx);
+ auto quantizedType = getTestQuantizedType(convertedType, &ctx);
+ TestUniformQuantizedValueConverter converter(quantizedType);
+
+ auto realValue = getTestFloatAttr(1.0, &ctx);
+ Type typeResult;
+ auto valueResult =
+ quantizeAttrUniform(realValue, quantizedType, converter, typeResult);
+
+ EXPECT_EQ(valueResult.cast<IntegerAttr>().getInt(), 5);
+ EXPECT_EQ(
+ valueResult.cast<IntegerAttr>().getType().cast<IntegerType>().getWidth(),
+ convertedType.getWidth());
+}
+
+TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
+ MLIRContext ctx;
+ IntegerType convertedType = IntegerType::get(8, &ctx);
+ auto quantizedType = getTestQuantizedType(convertedType, &ctx);
+ TestUniformQuantizedValueConverter converter(quantizedType);
+ auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
+ &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)});
+
+ Type returnedType;
+ auto returnedValue =
+ quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
+
+ // Check Elements attribute shape and kind are not changed.
+ auto tensorType = returnedType.cast<TensorType>();
+ auto expectedTensorType = realValue.getType().cast<TensorType>();
+ EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
+ EXPECT_EQ(tensorType.getElementType(), convertedType);
+ EXPECT_EQ(returnedValue.getKind(), Attribute::Kind::DenseIntElements);
+
+ // Check Elements attribute element value is expected.
+ auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
+ EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
+}
+
+TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
+ MLIRContext ctx;
+ IntegerType convertedType = IntegerType::get(8, &ctx);
+ auto quantizedType = getTestQuantizedType(convertedType, &ctx);
+ TestUniformQuantizedValueConverter converter(quantizedType);
+ auto realValue = getTestElementsAttr<SplatElementsAttr, Attribute>(
+ &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));
+
+ Type returnedType;
+ auto returnedValue =
+ quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
+
+ // Check Elements attribute shape and kind are not changed.
+ auto tensorType = returnedType.cast<TensorType>();
+ auto expectedTensorType = realValue.getType().cast<TensorType>();
+ EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
+ EXPECT_EQ(tensorType.getElementType(), convertedType);
+ EXPECT_EQ(returnedValue.getKind(), Attribute::Kind::SplatElements);
+
+ // Check Elements attribute element value is expected.
+ auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
+ EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
+}
+
+TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
+ MLIRContext ctx;
+ IntegerType convertedType = IntegerType::get(8, &ctx);
+ auto quantizedType = getTestQuantizedType(convertedType, &ctx);
+ TestUniformQuantizedValueConverter converter(quantizedType);
+ auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});
+
+ Type returnedType;
+ auto returnedValue =
+ quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
+
+ // Check Elements attribute shape and kind are not changed.
+ auto tensorType = returnedType.cast<TensorType>();
+ auto expectedTensorType = realValue.getType().cast<TensorType>();
+ EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
+ EXPECT_EQ(tensorType.getElementType(), convertedType);
+ EXPECT_EQ(returnedValue.getKind(), Attribute::Kind::SparseElements);
+
+ // Check Elements attribute element value is expected.
+ auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
+ EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
+}
+
+} // end namespace