Initial release of the Quantization dialect
authorStella Laurenzo <laurenzo@google.com>
Wed, 3 Apr 2019 18:16:32 +0000 (11:16 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 4 Apr 2019 02:20:12 +0000 (19:20 -0700)
Includes a draft of documentation for the quantization setup.

Given how many comments such docs have garnered in the past, I've biased towards a lightly edited first-draft so that people can argue about terminology, approach and structure without having spent too much time on it.

Note that the sections under "Uniform quantization" were cribbed nearly verbatim from internal documentation that Daniel wrote.

PiperOrigin-RevId: 241768668

24 files changed:
mlir/g3doc/Quantization.md [new file with mode: 0644]
mlir/include/mlir/Quantization/FakeQuantSupport.h [new file with mode: 0644]
mlir/include/mlir/Quantization/Passes.h [new file with mode: 0644]
mlir/include/mlir/Quantization/QuantOps.h [new file with mode: 0644]
mlir/include/mlir/Quantization/QuantOps.td [new file with mode: 0644]
mlir/include/mlir/Quantization/QuantizeUtils.h [new file with mode: 0644]
mlir/include/mlir/Quantization/UniformSupport.h [new file with mode: 0644]
mlir/lib/Quantization/IR/DialectRegistration.cpp [new file with mode: 0644]
mlir/lib/Quantization/IR/FakeQuantSupport.cpp [new file with mode: 0644]
mlir/lib/Quantization/IR/QuantOps.cpp [new file with mode: 0644]
mlir/lib/Quantization/IR/TypeDetail.h [new file with mode: 0644]
mlir/lib/Quantization/IR/TypeParser.cpp [new file with mode: 0644]
mlir/lib/Quantization/IR/UniformSupport.cpp [new file with mode: 0644]
mlir/lib/Quantization/Transforms/ConvertConst.cpp [new file with mode: 0644]
mlir/lib/Quantization/Transforms/LowerTF.cpp [new file with mode: 0644]
mlir/lib/Quantization/Transforms/LowerUniformRealMath.cpp [new file with mode: 0644]
mlir/lib/Quantization/Utils/QuantizeUtils.cpp [new file with mode: 0644]
mlir/test/Quantization/convert-const.mlir [new file with mode: 0644]
mlir/test/Quantization/lower-uniform-real-math-addew.mlir [new file with mode: 0644]
mlir/test/Quantization/parse-uniform-invalid.mlir [new file with mode: 0644]
mlir/test/Quantization/parse-uniform.mlir [new file with mode: 0644]
mlir/test/Quantization/tf-lower-fakequant-invalid.mlir [new file with mode: 0644]
mlir/test/Quantization/tf-lower-fakequant.mlir [new file with mode: 0644]
mlir/unittests/Quantization/QuantizationUtilsTest.cpp [new file with mode: 0644]

diff --git a/mlir/g3doc/Quantization.md b/mlir/g3doc/Quantization.md
new file mode 100644 (file)
index 0000000..ad1775c
--- /dev/null
@@ -0,0 +1,330 @@
+# MLIR Quantization
+
+This document outlines the design of the MLIR quantization system. While the
+term "quantization" is highly overloaded, in this case, it refers to a fairly
+narrow scope of techniques in use to enable conversion of floating-point
+computations to corresponding and plausible variants expressed in integer math
+for inference, as has historically been supported by low-bit depth inference
+engines such as TFLite, various accelerator hardware, and many DSPs.
+
+Much of this is inspired by the approach taken
+[in this paper](https://arxiv.org/abs/1712.05877) with many extensions and
+adaptations folded in. It specifically documents the positions that MLIR has
+taken on the topic, and is not a general reference.
+
+[TOC]
+
+## Uniform quantization
+
+The primary quantization mechanism supported by MLIR is a scheme which can
+express fixed point and affine transformations via uniformly spaced point on the
+Real number line.
+
+Further, the scheme can be applied:
+
+*   *per-layer* : Applying to every value within the target type.
+*   *per-axis* (also called *per-channel*) : Applying individually to each index
+    along a specific axis of a tensor type.
+
+### Fixed point values {#fixed-point}
+
+[Fixed point](https://en.wikipedia.org/wiki/Fixed-point_arithmetic) values are a
+[Real](https://en.wikipedia.org/wiki/Real_number) number divided by a *scale*.
+We will call the result of the divided Real the *scaled value*.
+
+$$ real\_value = scaled\_value * scale $$
+
+The scale can be interpreted as the distance, in Real units, between neighboring
+scaled values. For example, if the scale is $$ \pi $$, then fixed point values
+with this scale can only represent multiples of $$ \pi $$, and nothing in
+between. The maximum rounding error to convert an arbitrary Real to a fixed
+point value with a given $$ scale $$ is $$ \frac{scale}{2} $$. Continuing the
+previous example, when $$ scale = \pi $$, the maximum rounding error will be $$
+\frac{\pi}{2} $$.
+
+Multiplication can be performed on scaled values with different scales, using
+the same algorithm as multiplication of Real values (note that product scaled
+value has $$ scale_{product} = scale_{left \mbox{ } operand} * scale_{right
+\mbox{ } operand} $$). Addition can be performed on scaled values, as long as
+they have the same scale, using the same algorithm as addition of Real values.
+This makes it convenient to represent scaled values on a computer as signed
+integers, and perform arithmetic on those signed integers, because the results
+will be correct scaled values.
+
+### Affine values {#affine}
+
+Mathematically speaking, affine values are the result of
+[adding a Real-valued *zero point*, to a scaled value](https://en.wikipedia.org/wiki/Affine_transformation#Representation).
+Or equivalently, subtracting a zero point from an affine value results in a
+scaled value:
+
+$$ real\_value = scaled\_value * scale = (affine\_value - zero\_point) * scale $$
+
+Essentially, affine values are a shifting of the scaled values by some constant
+amount. Arithmetic (i.e., addition, subtraction, multiplication, division)
+cannot, in general, be directly performed on affine values; you must first
+[convert](#affine-to-fixed-point) them to the equivalent scaled values.
+
+As alluded to above, the motivation for using affine values is to more
+efficiently represent the Real values that will actually be encountered during
+computation. Frequently, the Real values that will be encountered are not
+symmetric around the Real zero. We also make the assumption that the Real zero
+is encountered during computation, and should thus be represented.
+
+In this case, it's inefficient to store scaled values represented by signed
+integers, as some of the signed integers will never be used. The bit patterns
+corresponding to those signed integers are going to waste.
+
+In order to exactly represent the Real zero with an integral-valued affine
+value, the zero point must be an integer between the minimum and maximum affine
+value (inclusive). For example, given an affine value represented by an 8 bit
+unsigned integer, we have: $$ 0 \leq zero\_point \leq 255$$. This is important,
+because in deep neural networks's convolution-like operations, we frequently
+need to zero-pad inputs and outputs, so zero must be exactly representable, or
+the result will be biased.
+
+### Relation
+
+Real values, fixed point values, and affine values relate through the following
+equation, which demonstrates how to convert one type of number to another:
+
+$$ real\_value = scaled\_value * scale = (affine\_value - zero\_point) * scale $$
+
+Note that computers generally store mathematical values using a finite number of
+bits. Thus, while the above conversions are exact, to store the result in a
+finite number of bits, we must, in general, round the result of the conversion
+(this applies to both cases: storing using floating point and storing using
+fixed point). Note that a full discussion of rounding behavior is outside the
+scope of this document, and it is safe to assume unless otherwise stated that
+rounding should be according to the IEEE754 default of RNE (where hardware
+permits).
+
+### Converting between Real and fixed point or affine {#converting-between}
+
+To convert a Real value to a fixed point value, you must know the scale. To
+convert a Real value to an affine value, you must know the scale and zero point.
+
+#### Real to affine
+
+To convert an input tensor of Real-valued elements (usually represented by a
+floating point format, frequently
+[Single precision](https://en.wikipedia.org/wiki/Single-precision_floating-point_format))
+to a tensor of affine elements represented by an integral type (e.g. 8-bit
+unsigned integer), the following conversion can be performed (note that it is
+not required that all representable values of the integral type are used):
+
+$$
+\begin{align*}
+af&fine\_value_{uint8 \, or \, uint16} \\
+      &= clampToTargetSize(roundToNearestInteger( \frac{real\_value_{Single}}{scale_{Single}})_{sint32} + zero\_point_{uint8 \, or \, uint16})
+\end{align*}
+$$
+
+In the above, we assume that $$real\_value$$ is a Single, $$scale$$ is a Single,
+$$roundToNearestInteger$$ returns a signed 32 bit integer, and $$zero\_point$$
+is an unsigned 8 or 16 bit integer. Note that bit depth and number of fixed
+point values is indicative of common types on typical hardware but is not
+constrained to particular bit depths or a requirement that the entire range of
+an N-bit integer is used.
+
+#### Affine to Real {#affine-to-real}
+
+To convert an output tensor of affine elements represented by uint8
+or uint16 to a tensor of Real-valued elements (usually represented with a
+floating point format, frequently Single precision), the following conversion
+can be performed:
+
+$$
+\begin{align*}
+re&al\_value_{Single} \\
+      &= roundToNearestFloat((affine\_value_{uint8 \, or \, uint16} - zero\_point_{uint8 \, or \, uint16})_{sint32})_{Single} * scale_{Single}
+\end{align*}
+$$
+
+In the above, we assume that the result of subtraction is in 32-bit signed
+integer format, and that $$roundToNearestFloat$$ returns a Single.
+
+#### Affine to fixed point {#affine-to-fixed-point}
+
+When the affine and fixed point scales are the same, subtract the zero point
+from the affine value to get the equivalent fixed point value.
+
+$$
+scaled\_value = affine\_value_{non\mbox{-}negative} - zero\_point_{non\mbox{-}negative}
+$$
+
+#### Fixed point to affine {#fixed-point-to-affine}
+
+When the affine and fixed point scales are the same, add the zero point to the
+fixed point value to get the equivalent affine value.
+
+$$
+affine\_value_{non\mbox{-}negative} = scaled\_value + zero\_point_{non\mbox{-}negative}
+$$
+
+## Usage within MLIR {#usage-within-mlir}
+
+There are several components to the quantization system within MLIR:
+
+*   *Quantization* dialect containing:
+
+    *   A family of [QuantizedTypes](#quantized-type) which represent the
+        mapping between *expressed* values (typically of a floating point
+        computer type) and *storage* values (typically of an integral computer
+        type).
+    *   [Type conversion ops](#quantized-type-conversion-ops) for converting
+        between types based on a QuantizedType and its *expressed* and *storage*
+        sub-types.
+    *   [Instrumentation ops](#instrumentation-ops) for assigning
+        instrumentation points within the computation where runtime statistics
+        may help guide the quantization process.
+
+*   *QuantizedMath* dialect containing:
+
+    *   [Real math ops](#real-math-ops) representing common combinations of
+        arithmetic operations that closely match corresponding fixed-point math
+        concepts (as opposed to being spread across multiple ops as is typical
+        in source dialects).
+    *   [Fixed-point math ops](#fixed-point-math-ops) that for carrying out
+        computations on integers, as are typically needed by uniform
+        quantization schemes.
+    *   Passes to lower from real math ops to fixed-point math ops.
+
+*   [Solver tools](#solver-tools) which can generically operate on computations
+    expressed in the *QuantizedMath* dialect in order to convert from floating
+    point types to appropriate *QuantizedTypes*, allowing the computation to be
+    further lowered to integral math ops.
+
+Not every application of quantization will use all facilities. Specifically, the
+TensorFlow to TensorFlow Lite conversion uses the QuantizedTypes but has its own
+ops for type conversion and expression of the backing math.
+
+## Interactions with simulated quantization at training time {#training-time}
+
+TensorFlow has historically used the
+[tf.quantization.fake_quant_\*](https://www.tensorflow.org/api_docs/python/tf/quantization/fake_quant_with_min_max_args)
+family of operations to simulate the effect of quantization at training time.
+
+As originally implemented, TensorFlow Lite was the primary user of such
+operations at inference time. When quantized inference was enabled, if every
+eligible tensor passed through an appropriate fake_quant node (the rules of
+which tensors can have fake_quant applied are somewhat involved), then
+TensorFlow Lite would use the attributes of the fake_quant ops to make a
+judgment about how to convert to use kernels from its quantized ops subset.
+
+In MLIR-based quantization, fake_quant_\* ops are handled by converting them to
+a sequence of *qcast* (quantize) followed by *dcast* (dequantize) with an
+appropriate *UniformQuantizedType* as the target of the qbarrier operation.
+
+This allows subsequent compiler passes to preserve the knowledge that
+quantization was simulated in a certain way while giving the compiler
+flexibility to move the barriers as it simplifies the computation and converts
+it to a form based on integral arithmetic.
+
+This scheme also naturally allows computations that are *partially quantized*
+where the parts which could not be reduced to integral ops are still carried out
+in floating point with appropriate conversions at the boundaries.
+
+## Quantization Dialect
+
+### Quantized type {#quantized-type}
+
+TODO : Flesh this section out.
+
+*   QuantizedType base class
+*   UniformQuantizedType
+
+### Quantized type conversion ops {#quantized-type-conversion-ops}
+
+*   qcast : Convert from an expressed type to QuantizedType
+*   dcast : Convert from a QuantizedType to its expressed type
+*   scast : Convert between a QuantizedType and its storage type
+
+### Instrumentation and constraint ops {#instrumentation-ops}
+
+TODO : These ops are not defined yet
+
+*   instrument_stats : Assigns a unique id and signals that statistics should be
+    collected by the runtime when execution passes through this op.
+*   constrain_uniform : Constrains that for uniform quantization, the solver
+    should choose a type with certain characteristics such as the number of
+    fixed-point values, underlying storage type, or whether to constrain to
+    power of two scales.
+
+## QuantizedMath Dialect
+
+### Real math ops {#real-math-ops}
+
+Note that these all support explicit clamps, which allows for simple fusions and
+representation of some common sequences quantization-compatible math. Of
+addition, some support explicit biases, which are often represented as separate
+adds in source dialects.
+
+TODO: This op set is still evolving and needs to be completed.
+
+*   RealBinaryOp
+    *   RealAddEwOp
+    *   RealSubEwOp
+    *   RealMulEwOp
+    *   RealDivEwOp
+*   RealUnaryOp
+    *   IDENTITY
+    *   TANH
+    *   SIGMOID
+    *   EXP
+    *   LOG
+    *   NEG
+    *   RSQRT
+    *   SIN
+    *   SQUARE
+    *   SQRT
+    *   CMPZ
+    *   CMPNZ
+    *   CMPLZ
+    *   CMPGZ
+
+### Fixed-point math ops {#fixed-point-math-ops}
+
+TODO: This op set only has enough ops to lower a simple power-of-two
+RealAddEwOp.
+
+*   RoundingDivideByPotFxpOp
+*   SaturatingAddFxpOp
+
+## Solver tools {#solver-tools}
+
+Solver tools exist to analyze an MLIR-computation, expressed in either a
+supported source dialect or in the *real math ops* set and solve for appropriate
+QuantizedTypes that allow the computation to be lowered to integral math.
+
+These tools are an active area of work and may be expanded in the future to
+adjacent areas such as solving for transformations to other kinds of lower
+precision types (i.e. bfloat16 or fp16).
+
+Solver tools are expected to operate in several modes, depending on the
+computation and the manner in which it was trained:
+
+*   *Transform* : With all available information in the MLIR computation, infer
+    boundaries where the computation can be carried out with integral math and
+    change types accordingly to appropriate QuantizedTypes:
+
+    *   For passthrough ops which do not perform active math, change them to
+        operate directly on the storage type, converting in and out at the edges
+        via scast ops.
+    *   For ops that have the *Quantizable* trait, the type can be set directly.
+        This includes ops from the [real math ops set]{#real-math-ops}.
+    *   For others, encase them in appropriate dcast/qcast ops, presuming that
+        some follow-on pass will know what to do with them.
+
+*   *Instrument* : Most of the time, there are not sufficient implied
+    constraints within a computation to perform many transformations. For this
+    reason, the solver can insert instrumentation ops at points where additional
+    runtime statistics may yield solutions. It is expected that such
+    computations will be lowered as-is for execution, run over an appropriate
+    eval set, and statistics at each instrumentation point made available for a
+    future invocation of the solver.
+
+*   *Simplify* : A variety of passes and simplifications are applied once
+    QuantizedTypes are added in order to arrive at a computation that is
+    expressed in as much integral math, with the fewest number of casts as
+    possible.
diff --git a/mlir/include/mlir/Quantization/FakeQuantSupport.h b/mlir/include/mlir/Quantization/FakeQuantSupport.h
new file mode 100644 (file)
index 0000000..aa3b7b4
--- /dev/null
@@ -0,0 +1,67 @@
+//===- FakeQuantSupport.h - Support utilities for FakeQuant ops -*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines support utilities for interoperating with FakeQuant* based
+// QAT (Quantized Aware Training) computations, as implemented by TFLite. Note
+// that FakeQuant* operators mix multiple concerns specific to how TFLite
+// originally implemented quantization. As such, utilities here enforce
+// opinions taken by that codebase (vs providing any amount of genericity).
+//
+// Specifically, it combines the following concerns, each of which would be
+// independent variables in a more generic setup:
+//   - num_bits implies storage data type (quint8, int16)
+//   - num_bits < 8 is promoted to quint8
+//   - "narrow_range" narrows the lower bound of the storage type's range by
+//     1
+//   - the specified min/max values are "nudged" so that the result has a zero
+//     that can be exactly expressed
+//   - min=max=0 implies scale=0 and zero_point=0
+//
+// With the above assumptions applied, every conforming specified FakeQuant op
+// can be represented by a UniformQuantizedType. This scheme is not expected to
+// be generalized further in the future and should be considered to be a
+// legacy set of rules.
+//
+// As canonically used in TensorFlow graphs, the presence of a FakeQuant node
+// is a hint that the specific math represented here has been simulated at
+// training time. As such, it is usually not advised to arbitrarily change
+// quantization parameters derived from FakeQuant.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZATION_FAKEQUANTSUPPORT_H_
+#define MLIR_QUANTIZATION_FAKEQUANTSUPPORT_H_
+
+#include "mlir/Quantization/QuantOps.h"
+
+namespace mlir {
+namespace quant {
+
+/// Converts per-layer FakeQuant attributes to the corresponding type.
+/// In the event that the parameters cannot be converted, returns a nullptr
+/// convertible Type and issues an appropriate error.
+/// Note that there are multiple variants of a per-layer FakeQuant op, so
+/// this function takes the attributes discretely vs taking a reference to the
+/// originating op.
+UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
+                                          double rmin, double rmax,
+                                          bool narrowRange, Type expressedType);
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_QUANTIZATION_FAKEQUANTSUPPORT_H_
diff --git a/mlir/include/mlir/Quantization/Passes.h b/mlir/include/mlir/Quantization/Passes.h
new file mode 100644 (file)
index 0000000..090d21c
--- /dev/null
@@ -0,0 +1,59 @@
+//===- Passes.h - Quantization Passes ------ --------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines all of the passes owned by the quantization dialect. As
+// things mature, it is expected that passes specific to certain frontend or
+// backend dialects will move to those dialects directly. For now, they are
+// incubated here.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZATION_PASSES_H
+#define MLIR_QUANTIZATION_PASSES_H
+
+namespace mlir {
+class FunctionPassBase;
+
+namespace quant {
+
+/// Creates a pass that lowers quantization related TensorFlow ops into
+/// the quantization dialect so that express and implied constraints expressed
+/// at the TensorFlow source level can be represented to the quantization
+/// system. This will specially handle any TensorFlow op that is useful for
+/// guiding quantization.
+///
+/// Note that if your intent is to compile a TensorFlow graph for floating
+/// point inference, you should probably not use this pass.
+FunctionPassBase *createLowerTFPass();
+
+/// Creates a pass that converts constants followed by a qbarrier to a
+/// constant whose value is quantized. This is typically one of the last
+/// passes done when lowering to express actual quantized arithmetic in a
+/// low level representation. Because it modifies the constant, it is
+/// destructive and cannot be undone.
+FunctionPassBase *createConvertConstPass();
+
+/// Creates a pass that lowers uniform-quantized real math ops to integer
+/// arithmetic. This will leave unrecognized real math ops as-is and is
+/// typically followed by a pass that lowers any unrecognized ops to a pure
+/// floating point form.
+FunctionPassBase *createLowerUniformRealMathPass();
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_QUANTIZATION_PASSES_H
diff --git a/mlir/include/mlir/Quantization/QuantOps.h b/mlir/include/mlir/Quantization/QuantOps.h
new file mode 100644 (file)
index 0000000..dceb2d0
--- /dev/null
@@ -0,0 +1,373 @@
+//===- Quantization/QuantOps.h - Quantization Ops and Types -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_QUANTIZATION_QUANTOPS_H_
+#define MLIR_QUANTIZATION_QUANTOPS_H_
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace quant {
+
+class QuantizedIntegerType;
+
+namespace detail {
+
+struct QuantizedTypeStorage;
+struct UniformQuantizedTypeStorage;
+struct UniformQuantizedPerAxisTypeStorage;
+
+} // namespace detail
+
+namespace QuantizationTypes {
+enum Kind {
+  UniformQuantized = Type::FIRST_QUANTIZATION_TYPE,
+  UniformQuantizedPerAxis,
+  LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis,
+};
+} // namespace QuantizationTypes
+
+/// Enumeration of bit-mapped flags related to quantized types.
+namespace QuantizationFlags {
+enum FlagValue {
+  // Indicates that the storage type should be interpreted as a signed
+  // integer. The default is to interpret it as an unsigned value.
+  Signed = 1,
+};
+} // namespace QuantizationFlags
+
+/// Base class for all quantized types known to this dialect.
+/// All quantized types have:
+///   - storageType: The (narrower) numeric type that is being used to
+///     approximate some expressed type.
+///   - expressedType: The type that is being approximated.
+///
+/// The base class provides generic support for manipulating the types based
+/// on these fields.
+class QuantizedType : public Type {
+public:
+  using ImplType = detail::QuantizedTypeStorage;
+  using Type::Type;
+
+  /// The maximum number of bits supported for storage types.
+  static constexpr unsigned MaxStorageBits = 32;
+
+  static LogicalResult
+  verifyConstructionInvariants(llvm::Optional<Location> loc,
+                               MLIRContext *context, unsigned flags,
+                               Type storageType, Type expressedType,
+                               int64_t storageTypeMin, int64_t storageTypeMax);
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) {
+    return kind == QuantizationTypes::UniformQuantized;
+  }
+
+  /// Gets the minimum possible stored by a storageType. storageTypeMin must
+  /// be greater than or equal to this value.
+  static int64_t getDefaultMininumForInteger(bool isSigned,
+                                             unsigned integralWidth) {
+    if (isSigned) {
+      return llvm::minIntN(integralWidth);
+    }
+    return 0;
+  }
+
+  /// Gets the maximum possible stored by a storageType. storageTypeMax must
+  /// be less than or equal to this value.
+  static int64_t getDefaultMaxinumForInteger(bool isSigned,
+                                             unsigned integralWidth) {
+    if (isSigned) {
+      return llvm::maxIntN(integralWidth);
+    }
+    return llvm::maxUIntN(integralWidth);
+  }
+
+  /// Gets the original expressed type that this quantized type approximates.
+  /// Note that this presumes that the quantized type was always derived from
+  /// a floating point type, which in the broadest definition, is not true (i.e.
+  /// it could be some form of integral, fixed type or affine type in its own
+  /// right); however, at the high level, no examples of such usage are
+  /// presently known and the restriction serves some useful purposes (such as
+  /// always being able to reverse a transformation or measure error). In most
+  /// cases, this will be f32.
+  Type getExpressedType() const;
+
+  /// Gets the flags associated with this type. Typically a more specific
+  /// accessor is appropriate.
+  unsigned getFlags() const;
+
+  // Convenience helpers.
+  /// Whether the storage type should be interpreted as a signed quantity
+  /// (true) or an unsigned value (false).
+  bool isSigned() const {
+    return (getFlags() & QuantizationFlags::Signed) ==
+           QuantizationFlags::Signed;
+  }
+
+  /// Gets the underlying type used for to store values. Note that this may
+  /// be signed or unsigned. Use the isSigned() accessor to differentiate.
+  Type getStorageType() const;
+
+  /// The minimum value that storageType can take.
+  int64_t getStorageTypeMin() const;
+
+  /// The maximum value that storageType can take.
+  int64_t getStorageTypeMax() const;
+
+  /// Gets the integral bit width that the underlying storage type can exactly
+  /// represent. For integral storage types, this will just be their width.
+  unsigned getStorageTypeIntegralWidth() const;
+
+  /// Returns whether the candidateExpressedType is a match for this
+  /// QuantizedType. This will be true if the candidate type is either a
+  /// primitive type or a container type whose element type equals this
+  /// QuantizedType's expressed type.
+  /// Examples of compatible candidateExpressedType:
+  ///   !quant<"uniform[i8:f32]{1.0}"> =~ f32
+  ///   !quant<"uniform[i8:f32]{1.0}"> =~ tensor<4xf32>
+  bool isCompatibleExpressedType(Type candidateExpressedType);
+
+  /// Returns the element type as a QuantizedType or nullptr if it is not
+  /// a quantized type. If the type is primitive, returns that. If it is a
+  /// container (vector/tensor), return the element type.
+  /// Examples:
+  ///   !quant<"uniform[i8:f32]{1.0}"> -> !quant<"uniform[i8:f32]{1.0}">
+  ///   tensor<4x!quant<"uniform[i8:f32]{1.0}"> -> quant<"uniform[i8:f32]{1.0}">
+  static QuantizedType getQuantizedElementType(Type primitiveOrContainerType);
+
+  /// Casts from a type based on the storageType to a corresponding type based
+  /// on this type (returns nullptr if the cast is not valid).
+  /// Examples:
+  ///   i8 -> !quant<"uniform[i8:f32]{1.0}">
+  ///   tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+  ///   vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
+  Type castFromStorageType(Type candidateType);
+
+  /// Casts from a type based on a QuantizedType to a corresponding type based
+  /// on the storageType (returns nullptr if the cast is not valid).
+  /// This is the inverse of castFromStorageType().
+  static Type castToStorageType(Type quantizedType);
+
+  /// Casts from a type based on the expressedType to a corresponding type based
+  /// on this type (returns nullptr if the cast is not valid).
+  /// Examples:
+  ///   f32 -> !quant<"uniform[i8:f32]{1.0}">
+  ///   tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+  ///   vector<4xf32> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
+  Type castFromExpressedType(Type candidateType);
+
+  /// Casts from a type based on QuantizedType to a corresponding type based
+  /// on the expressedType (returns nullptr if the cast is not valid).
+  /// This is the inverse of castFromExpressedType.
+  static Type castToExpressedType(Type quantizedType);
+
+  /// Casts from a type based on the expressedType to the equivalent type
+  /// based on storageType by way of this QuantizedType. Equivalent to:
+  ///   QuantizedType::castToStorageType(castFromExpressedType(candidateType))
+  /// (but with validity checks).
+  /// Example (for this = !quant<"uniform[i8:f32]{1.0}">):
+  ///   tensor<4xf32> -> tensor<4xi8>
+  Type castExpressedToStorageType(Type candidateType);
+};
+
+/// Represents a family of uniform, quantized types.
+///
+/// Each instance of this type expresses a mapping between real values (most
+/// often expressed in floating point f32) and quantized values (either fixed
+/// point or affine).
+///
+/// The relationship is:
+///     real_value = scale * (quantized_value - zero_point)
+///
+/// It is used as part of high level graph transformations that have the goal
+/// of re-expressing parts of a computation in terms of this common form for
+/// more efficient execution at runtime. In addition, it is designed to be
+/// expressive enough to facilitate lowering to precise types and operations
+/// in target hardware.
+///
+/// As a high-level type, focused on intermediate passes, this type holds
+/// opinions consistent with high-level usage. If lowering math kernels below
+/// the high level arithmetic ops (i.e. to LLVM IR or hardware specific
+/// instruction sets), it is expected that the information expressed here
+/// will be used to drive low level codegen and target specific type selection,
+/// but this type will likely be erased in the process.
+///
+/// Syntax synopsis:
+///   Per-layer, all parameters expressed:
+///     !quant<uniform[StorageType:ExpressedType]{Scale:ZeroPoint}>
+///   Per-layer, optional parameters omitted:
+///     !quant<uniform[StorageType]{Scale}>
+///
+///   StorageType: 'i'|'u' NumBits
+///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+///   Scale: A legal double value
+///   ZeroPoint: An integer value
+class UniformQuantizedType
+    : public Type::TypeBase<UniformQuantizedType, QuantizedType,
+                            detail::UniformQuantizedTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Gets an instance of the type with all parameters specified but not
+  /// checked.
+  static UniformQuantizedType get(unsigned flags, Type storageType,
+                                  Type expressedType, double scale,
+                                  int64_t zeroPoint, int64_t storageTypeMin,
+                                  int64_t storageTypeMax);
+
+  /// Gets an instance of the type with all specified parameters checked.
+  /// Returns a nullptr convertible type on failure.
+  static UniformQuantizedType
+  getChecked(unsigned flags, Type storageType, Type expressedType, double scale,
+             int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax,
+             Location location);
+
+  /// Verifies construction invariants and issues errors/warnings.
+  static LogicalResult verifyConstructionInvariants(
+      llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+      Type storageType, Type expressedType, double scale, int64_t zeroPoint,
+      int64_t storageTypeMin, int64_t storageTypeMax);
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) {
+    return kind == QuantizationTypes::UniformQuantized;
+  }
+
+  /// Gets the scale term. The scale designates the difference between the real
+  /// values corresponding to consecutive quantized values differing by 1.
+  double getScale() const;
+
+  /// Gets the storage value corresponding to the real value 0 in the affine
+  /// equation.
+  int64_t getZeroPoint() const;
+
+  // Fixed point values are real numbers divided by a scale.
+  // Currently, only signed storage types are treated as fixed point.
+  // A fixed point value can be obtained from an affine value by subtracting
+  // the zeroPoint.
+  // In the future, this may be explicit versus implied by type and zeroPoint.
+  bool isFixedPoint() const { return isSigned() && getZeroPoint() == 0; }
+};
+
+/// Represents per-axis (also known as per-channel quantization).
+///
+/// Syntax synopsis:
+///   Per-axis, all parameters expressed:
+///     !quant<uniform[StorageType:ExpressedType:QuantizedDim]{QuantParams}>
+///   Per-axis, optional parameters omitted:
+///     !quant<uniform[StorageType]{Scale}>
+///
+///   StorageType: 'i'|'u' NumBits
+///   ExpressedType: 'f16', 'f32', 'bf16', 'f64'
+///   QuantizedDim: An integer value
+///   QuantParams: (Scale ':' ZeroPoint)+
+///   Scale: A legal double value
+///   ZeroPoint: An integer value
+class UniformQuantizedPerAxisType
+    : public Type::TypeBase<UniformQuantizedPerAxisType, QuantizedType,
+                            detail::UniformQuantizedPerAxisTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Gets an instance of the type with all parameters specified but not
+  /// checked.
+  static UniformQuantizedPerAxisType
+  get(unsigned flags, Type storageType, Type expressedType,
+      ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+      int32_t quantizedDimension, int64_t storageTypeMin,
+      int64_t storageTypeMax);
+
+  /// Gets an instance of the type with all specified parameters checked.
+  /// Returns a nullptr convertible type on failure.
+  static UniformQuantizedPerAxisType
+  getChecked(unsigned flags, Type storageType, Type expressedType,
+             ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+             int32_t quantizedDimension, int64_t storageTypeMin,
+             int64_t storageTypeMax, Location location);
+
+  /// Verifies construction invariants and issues errors/warnings.
+  static LogicalResult verifyConstructionInvariants(
+      llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+      Type storageType, Type expressedType, ArrayRef<double> scales,
+      ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
+      int64_t storageTypeMin, int64_t storageTypeMax);
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) {
+    return kind == QuantizationTypes::UniformQuantizedPerAxis;
+  }
+
+  /// Gets the quantization scales. The scales designate the difference between
+  /// the real values corresponding to consecutive quantized values differing
+  /// by 1. The ith scale corresponds to the ith slice in the
+  /// quantized_dimension.
+  ArrayRef<double> getScales() const;
+
+  /// Gets the storage values corresponding to the real value 0 in the affine
+  /// equation. The ith zero point corresponds to the ith slice in the
+  /// quantized_dimension.
+  ArrayRef<int64_t> getZeroPoints() const;
+
+  /// Specifies the dimension of the Tensor's shape that the scales and
+  /// zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1]
+  /// with quantization params:
+  ///   scales=[1.0, 2.0, 3.0], zeroPoints=[1, 2, 3], quantizedDimension=1
+  /// will be quantized across the second dimension of t.
+  ///   t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1
+  ///   t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2
+  ///   t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3
+  int32_t getQuantizedDimension() const;
+
+  /// Fixed point values are real numbers divided by a scale.
+  /// Currently, only signed storage types are treated as fixed point.
+  /// A fixed point value can be obtained from an affine value by subtracting
+  /// the zeroPoint.
+  /// In the future, this may be explicit versus implied by type and zeroPoint.
+  bool isFixedPoint() const {
+    if (!isSigned())
+      return false;
+    return llvm::all_of(getZeroPoints(),
+                        [](int64_t zeroPoint) { return zeroPoint != 0; });
+  }
+};
+
+/// Defines the 'Quantization' dialect
+class QuantizationDialect : public Dialect {
+public:
+  QuantizationDialect(MLIRContext *context);
+
+  /// Parse a type registered to this dialect.
+  Type parseType(StringRef spec, Location loc) const override;
+
+  /// Print a type registered to this dialect.
+  void printType(Type type, raw_ostream &os) const override;
+};
+
+#define GET_OP_CLASSES
+#include "mlir/Quantization/QuantOps.h.inc"
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_QUANTIZATION_QUANTOPS_H_
diff --git a/mlir/include/mlir/Quantization/QuantOps.td b/mlir/include/mlir/Quantization/QuantOps.td
new file mode 100644 (file)
index 0000000..8c247a3
--- /dev/null
@@ -0,0 +1,285 @@
+//===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This is the operation definition file for Quantization.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef QUANTIZATION_OPS
+#else
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+//===----------------------------------------------------------------------===//
+// Quantization type definitions
+//===----------------------------------------------------------------------===//
+
+class quant_TypedPrimitiveOrContainer<Type etype> :
+    Type<AnyOf<[etype.predicate,
+                TypedTensor<etype>.predicate,
+                TypedVector<etype>.predicate]>,
+         "primitive/tensor/vector of " # etype.description>;
+
+// An implementation of QuantizedType.
+def quant_QuantizedType :
+    Type<CPred<"{0}.isa<QuantizedType>()">, "QuantizedType">;
+
+// A primitive type that can represent a real value. This is either a
+// floating point value or a quantized type.
+def quant_RealPrimitiveType :
+    Type<AnyOf<[Float.predicate, quant_QuantizedType.predicate]>,
+    "real valued primitive (float or quantized type)">;
+
+// A primitive type that can represent a storage value. This is either an
+// integer or quantized type.
+def quant_StoragePrimitiveType :
+    Type<AnyOf<[Integer.predicate, quant_QuantizedType.predicate]>,
+    "quantized storage primitive (integer or quantized type)">;
+
+// A primitive or container of RealPrimitiveType.
+def quant_RealValueType :
+    quant_TypedPrimitiveOrContainer<quant_RealPrimitiveType>;
+
+// A primitive or container of StoragePrimitiveType.
+def quant_StorageValueType :
+    quant_TypedPrimitiveOrContainer<quant_StoragePrimitiveType>;
+
+// Either a real valued or storage primitive or container type.
+def quant_RealOrStorageValueType :
+    Type<AnyOf<[quant_RealValueType.predicate,
+                quant_StorageValueType.predicate]>>;
+
+// An implementation of UniformQuantizedType.
+def quant_UniformQuantizedType :
+    Type<CPred<"{0}.isa<UniformQuantizedType>()">, "UniformQuantizedType">;
+
+// Predicate for detecting a container or primitive of UniformQuantizedType.
+def quant_UniformQuantizedValueType :
+    quant_TypedPrimitiveOrContainer<quant_UniformQuantizedType>;
+
+//===----------------------------------------------------------------------===//
+// Attributes
+//===----------------------------------------------------------------------===//
+
+// Real value for an (inclusive) min/max clamp limit.
+def quant_ClampValueAttr : OptionalAttr<F64Attr>;
+
+// Element-wise activation function to apply.
+// Note that RELU activations are not here: they are expressed as clamps.
+def quant_EwUnaryFnAttr :
+    StringBasedAttr<CPred<"true">, "element-wise unary function"> {
+  let returnType = [{ StringRef }];
+  let defaultValue = "IDENTITY";
+}
+
+class quant_ConstEwUnaryFn<string val> : ConstantAttr<quant_EwUnaryFnAttr, val>;
+def quant_EwUnaryFn_Identity: quant_ConstEwUnaryFn<"IDENTITY">;
+def quant_EwUnaryFn_Tanh    : quant_ConstEwUnaryFn<"TANH">;
+def quant_EwUnaryFn_Sigmoid : quant_ConstEwUnaryFn<"SIGMOID">;
+def quant_EwUnaryFn_Exp     : quant_ConstEwUnaryFn<"EXP">;
+def quant_EwUnaryFn_Log     : quant_ConstEwUnaryFn<"LOG">;
+def quant_EwUnaryFn_Neg     : quant_ConstEwUnaryFn<"NEG">;
+def quant_EwUnaryFn_Rsqrt   : quant_ConstEwUnaryFn<"RSQRT">;
+def quant_EwUnaryFn_Sin     : quant_ConstEwUnaryFn<"SIN">;
+def quant_EwUnaryFn_Square  : quant_ConstEwUnaryFn<"SQUARE">;
+def quant_EwUnaryFn_Sqrt    : quant_ConstEwUnaryFn<"SQRT">;
+def quant_EwUnaryFn_CmpZ    : quant_ConstEwUnaryFn<"CMPZ">;
+def quant_EwUnaryFn_CmpNZ   : quant_ConstEwUnaryFn<"CMPNZ">;
+def quant_EwUnaryFn_CmpLZ   : quant_ConstEwUnaryFn<"CMPLZ">;
+def quant_EwUnaryFn_CmpGZ   : quant_ConstEwUnaryFn<"CMPGZ">;
+
+//===----------------------------------------------------------------------===//
+// Base classes
+//===----------------------------------------------------------------------===//
+
+class quant_Op<string mnemonic, list<OpTrait> traits> :
+    Op<!strconcat("quant.", mnemonic), traits>;
+
+//===----------------------------------------------------------------------===//
+// Quantization barriers
+//===----------------------------------------------------------------------===//
+class quant_BarrierOp<string mnemonic, list<OpTrait> traits> :
+    quant_Op<mnemonic, traits>, Arguments<(ins quant_RealValueType:$arg)>,
+    Results<(outs quant_RealValueType)>;
+
+// A QuantizeBarrier (qbarrier) represents a potential type shift from a
+// quantizable type to a quantized type.
+//
+// At runtime, a qbarrier will apply the transformation expressed by its
+// operand and result type. For flexibility during transformation, it is also
+// possible to have a qbarrier that performs no transformation (both its
+// operand and result type are quantizable).
+//
+// A qbarrier will typically originate from either:
+//   a) An expressed or implied constraint in the source dialect which signals
+//      that a certain level of quantization is possible or required.
+//   b) An inference made by a quantization algorithm indicating that a
+//      quantized representation may be acceptable.
+//
+// Especially early in transformation, it is common to have pairs of
+// qbarrier/dbarrier at points where a transition to a quantized type is
+// required. In addition, it is also common to have an identity qbarrier
+// (where the operand and result type are not quantized) at all points where
+// it is legal to use a quantized representation (but is not known to be
+// acceptable).
+def quant_QuantizeBarrierOp : quant_BarrierOp<"qbarrier", [NoSideEffect]>;
+
+// A DequantizeBarrier (dbarrier) represents the inverse of a qbarrier,
+// converting back from a quantized to quantizable (expressed) type.
+//
+// Like qbarriers, a dbarrier is allowed to have both its operand and result
+// as non quantized types. This facilitates transformations and marks edges
+// where the computation must be carried out in the expressed type.
+//
+// Especially early in transformation, it is common to have dbarriers on
+// all operands to ops that must operate with the expressed type (typically
+// math ops prior to lowering to target-specific, quantized kernels).
+def quant_DequantizeBarrierOp : quant_BarrierOp<"dbarrier", [NoSideEffect]>;
+
+// A StorageCast (scast) represents a cast from or to a type based on the
+// storage type and a type based on a corresponding quantized type.
+//
+// This op exists to ensure type coherency for between parts of the computation
+// which are operating directly on an underlying storage type and those which
+// operate on quantized values.
+//
+// Examples from storage to quantized type:
+//   i8 -> !quant<"uniform[i8:f32]{1.0}">
+//   tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+//   vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">>
+def quant_StorageCastOp :
+    quant_Op<"scast", [NoSideEffect]>,
+    Arguments<(ins quant_RealOrStorageValueType:$arg)>,
+    Results<(outs quant_RealOrStorageValueType)>;
+
+//===----------------------------------------------------------------------===//
+// Integral arithmetic ops used by kernels.
+//===----------------------------------------------------------------------===//
+
+def quant_RoundingDivideByPotIOp :
+    quant_Op<"rounding_divide_by_poti", [NoSideEffect, SameValueType]>,
+    Arguments<(ins quant_StorageValueType:$x, I32Attr:$exponent)>,
+    Results<(outs quant_StorageValueType:$y)> {
+  let description = [{
+    Computes integer division by a power-of-two, correctly rounded-to-nearest.
+    Also known as a rounding arithmetic right shift. See
+    gemmlowp::RoundingDivideByPOT for a reference implementation.
+  }];
+
+  let verifier = [{
+    auto verifyExponent = exponent().getSExtValue();
+    if (verifyExponent < 0 || verifyExponent > 31) {
+      return emitOpError("exponent must be in range [0..31]");
+    }
+    return success();
+  }];
+}
+
+def quant_SaturatingAddIOp :
+    quant_Op<"saturating_addi", [NoSideEffect, SameValueType]>,
+    Arguments<(ins quant_StorageValueType:$x,
+                   quant_StorageValueType:$y,
+                   I32Attr:$clamp_min,
+                   I32Attr:$clamp_max)>,
+    Results<(outs quant_StorageValueType:$sum)> {
+  let description = [{
+    Computes saturating addition of two operands, saturating to the given min
+    and max value. The implementation is responsible for choosing an
+    intermediate register size appropriate to carry out the operation without
+    overflow. See gemmlowp::SaturatingAdd for a reference implementation.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Real math ops.
+//
+// Math ops on real numbers which may have a representation in quantized
+// arithmetic. It is expected that eligible ops are lowered from a source
+// dialect to this set of ops prior to the process of converting a compuation
+// to a quantized form. It is a non-goal of these ops to preserve enough
+// information to convert back to the higher level, source dialect.
+//
+// These ops support either real/floating point or QuantizedTypes as operands
+// and results. Since not all transformations are supported (globally or
+// sometimes for specific targets), a computation may end up with
+// untransformable RealMathOps, in which case they need to be lowered as is
+// (using floating point math).
+//
+// This op set takes advantage of the fact that it is typically trivial to
+// combine a math function with a compatible bias addition and real-valued
+// clamp (which can be done at a higher accumulation bit depth).
+//
+// In addition, all element-wise unary functions are collapsed into a single
+// quant_RealUnaryEwOp and selected via an enum-like attribute. Especially at
+// low bit depths, this makes matching simpler and allows the construction of
+// generic LUT-based implementations. It also allows specific lowering rules
+// to consolidate runs of chained unary ops and fuse them to preceding math
+// ops, potentially allowing them to operate directly on higher precision
+// intermediates without resorting to lots of custom kernels for common
+// formulas that can suffer from insufficient precision at low bit depths.
+//
+// Comparison operators are modeled as element-wise unary functions (i.e.
+// CMPZ, CMPNZ, CMPLZ, CMPGZ) intended to follow a sub and output a 1bit
+// quantized value. It is expected that lowering rules can fuse them with
+// the preceding sub.
+//===----------------------------------------------------------------------===//
+
+class quant_RealMathOp<string mnemonic, list<OpTrait> traits = [], dag args> :
+    quant_Op<mnemonic, traits>,
+    Arguments<!con(args, (ins
+        quant_ClampValueAttr:$clamp_min, quant_ClampValueAttr:$clamp_max))>;
+
+//===----------------------------------------------------------------------===//
+// Element wise binary real math ops.
+//===----------------------------------------------------------------------===//
+
+class quant_RealBinaryOp<string mnemonic, list<OpTrait> traits = []> :
+    quant_RealMathOp<mnemonic, traits,
+                     (ins quant_RealValueType:$x, quant_RealValueType:$y)>,
+    Results<(outs quant_RealValueType:$r)>;
+
+class quant_RealBinaryBiasOp<string mnemonic, list<OpTrait> traits = []> :
+    quant_RealMathOp<mnemonic, traits,
+                     (ins quant_RealValueType:$x, quant_RealValueType:$y,
+                          quant_RealValueType:$bias)>,
+    Results<(outs quant_RealValueType:$r)>;
+
+def quant_RealAddEwOp :
+    quant_RealBinaryOp<"real_add_ew", [NoSideEffect]>;
+
+def quant_RealSubEwOp :
+    quant_RealBinaryOp<"real_sub_ew", [NoSideEffect]>;
+
+def quant_RealMulEwOp :
+    quant_RealBinaryOp<"real_mul_ew", [NoSideEffect]>;
+
+def quant_RealDivEwOp :
+    quant_RealBinaryOp<"real_div_ew", [NoSideEffect]>;
+
+//===----------------------------------------------------------------------===//
+// Element wise unary real math op.
+//===----------------------------------------------------------------------===//
+
+def quant_RealUnaryEwOp :
+    quant_RealMathOp<"real_unary_ew", [NoSideEffect],
+        (ins quant_RealValueType:$x, quant_EwUnaryFnAttr:$fn)>,
+    Results<(outs quant_RealValueType:$r)>;
+
+#endif  // QUANTIZATION_OPS
diff --git a/mlir/include/mlir/Quantization/QuantizeUtils.h b/mlir/include/mlir/Quantization/QuantizeUtils.h
new file mode 100644 (file)
index 0000000..0e4d04a
--- /dev/null
@@ -0,0 +1,70 @@
+//===- QuantizeUtils.h - Support utilities for quantization -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_QUANTIZATION_QUANTIZEUTILS_H_
+#define MLIR_QUANTIZATION_QUANTIZEUTILS_H_
+
+namespace mlir {
+class Attribute;
+class Type;
+
+namespace quant {
+class QuantizedType;
+class UniformQuantizedType;
+class UniformQuantizedValueConverter;
+
+/// Converts an attribute from a type based on
+/// quantizedElementType.getExpressedType() to one based on
+/// quantizedElementType.getStorageType(), where quantizedElementType is as from
+/// QuantizedType::getQuantizedElementType().
+/// Returns nullptr if the conversion is not supported. On success, stores the
+/// converted type in outConvertedType.
+///
+/// Examples:
+/// 1. realValue is a primitive value attribute:
+/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32])
+///   -> (IntegerAttr, outConvertedType: i8)
+/// 2. realValue is an elements attribute:
+/// (realValue: DenseElementsAttr[tensor<2x2xf32>],
+///  quantizedElementType: UniformQuantizedType[i8:f32])
+///   -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>)
+Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
+                       Type &outConvertedType);
+
+/// Converts an attribute from a type based on
+/// quantizedElementType.getExpressedType() to one based on
+/// quantizedElementType.getStorageType(), where quantizedElementType is as from
+/// QuantizedType::getQuantizedElementType() and casted to an
+/// UniformQuantizedType. Returns nullptr if the conversion is not supported. On
+/// success, stores the converted type in outConvertedType.
+///
+/// Examples:
+/// 1. realValue is a primitive value attribute:
+/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32])
+///   -> (IntegerAttr, outConvertedType: i8)
+/// 2. realValue is an elements attribute:
+/// (realValue: DenseElementsAttr[tensor<2x2xf32>],
+///  quantizedElementType: UniformQuantizedType[i8:f32])
+///   -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>)
+Attribute quantizeAttrUniform(Attribute realValue,
+                              UniformQuantizedType quantizedElementType,
+                              const UniformQuantizedValueConverter &converter,
+                              Type &outConvertedType);
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_QUANTIZATION_QUANTIZEUTILS_H_
diff --git a/mlir/include/mlir/Quantization/UniformSupport.h b/mlir/include/mlir/Quantization/UniformSupport.h
new file mode 100644 (file)
index 0000000..a2055ee
--- /dev/null
@@ -0,0 +1,119 @@
+//===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef MLIR_QUANTIZATION_UNIFORMSUPPORT_H_
+#define MLIR_QUANTIZATION_UNIFORMSUPPORT_H_
+
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Quantization/QuantOps.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/APSInt.h"
+
+namespace mlir {
+namespace quant {
+
+/// Performs type conversion from an arbitrary input type to a type
+/// that is expressed by a UniformQuantizedType.
+///
+/// This handles cases where the inputType is a supported primitive type
+/// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported
+/// elemental type.
+///
+/// Since conversion often involves introspecting some attributes of the
+/// input type in order to determine how to represent it, this is a two step
+/// process.
+struct ExpressedToUniformQuantizedConverter {
+  /// Creates a converter for the given input type.
+  static const ExpressedToUniformQuantizedConverter
+  forInputType(Type inputType);
+
+  /// Converts the inputType to be based on the given elemental type,
+  /// returning the new type (or nullptr and emit an error on failure).
+  Type convert(UniformQuantizedType elementalType) const;
+
+  /// Whether the conversion is legal.
+  explicit operator bool() const { return (bool)expressedType; }
+
+  /// The input type that is being converted from.
+  /// This may be an elemental or composite type.
+  const Type inputType;
+
+  /// Supported, elemental expressed type (i.e. f32).
+  /// Will be nullptr if conversion is not supported.
+  const Type expressedType;
+};
+
+/// Reference implementation of converting between real numbers and values
+/// represented by a UniformQuantizedType.
+/// Note that this is not expected to be speedy and may be superceded eventually
+/// by a more optimal implementation.
+/// Also, the interface assumes that quantization is done per-layer and will
+/// need to be wider for various per-channel schemes. As such, this is a
+/// placeholder.
+class UniformQuantizedValueConverter {
+public:
+  UniformQuantizedValueConverter(UniformQuantizedType uniformType)
+      : scale(uniformType.getScale()),
+        zeroPoint(static_cast<double>(uniformType.getZeroPoint())),
+        clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
+        clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
+        storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
+        isSigned(uniformType.isSigned()) {
+    assert(uniformType.getExpressedType().isa<FloatType>());
+    assert(uniformType.getStorageType().isa<IntegerType>());
+  }
+
+  virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
+    bool lossy;
+    expressedValue.convert(scale.getSemantics(), APFloat::rmNearestTiesToEven,
+                           &lossy);
+    // fixedpoint = clamp(clampMin, clampMax, (
+    //   roundHalfToEven(expressed / scale) + zeroPoint))
+    APFloat scaled = (expressedValue / scale);
+    scaled.roundToIntegral(APFloat::rmNearestTiesToEven);
+    scaled.add(zeroPoint, APFloat::rmNearestTiesToEven);
+    APFloat fixedpoint = llvm::minimum(scaled, clampMax);
+    fixedpoint = llvm::maximum(fixedpoint, clampMin);
+
+    llvm::APSInt result(storageBitWidth, !isSigned);
+    fixedpoint.convertToInteger(result, APFloat::rmNearestTiesToEven, &lossy);
+
+    return result;
+  }
+
+  int64_t quantizeFloatToInt64(APFloat expressedValue) const {
+    APInt qValue = quantizeFloatToInt(expressedValue);
+    return isSigned ? qValue.getSExtValue() : qValue.getZExtValue();
+  }
+
+  virtual ~UniformQuantizedValueConverter() {}
+
+private:
+  const APFloat scale;
+  const APFloat zeroPoint;
+  const APFloat clampMin;
+  const APFloat clampMax;
+  const uint32_t storageBitWidth;
+  const bool isSigned;
+};
+
+} // namespace quant
+} // namespace mlir
+
+#endif // MLIR_QUANTIZATION_UNIFORMSUPPORT_H_
diff --git a/mlir/lib/Quantization/IR/DialectRegistration.cpp b/mlir/lib/Quantization/IR/DialectRegistration.cpp
new file mode 100644 (file)
index 0000000..6beb193
--- /dev/null
@@ -0,0 +1,24 @@
+//===- DialectRegistration.cpp - Register Quantization dialect ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantization/QuantOps.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+// Static initialization for Quantization dialect registration.
+static mlir::DialectRegistration<QuantizationDialect> QuantizationOps;
diff --git a/mlir/lib/Quantization/IR/FakeQuantSupport.cpp b/mlir/lib/Quantization/IR/FakeQuantSupport.cpp
new file mode 100644 (file)
index 0000000..34457d9
--- /dev/null
@@ -0,0 +1,94 @@
+#include "mlir/Quantization/FakeQuantSupport.h"
+#include "mlir/Quantization/QuantOps.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+UniformQuantizedType mlir::quant::fakeQuantAttrsToType(Location loc,
+                                                       unsigned numBits,
+                                                       double rmin, double rmax,
+                                                       bool narrowRange,
+                                                       Type expressedType) {
+  MLIRContext *ctx = expressedType.getContext();
+  Type storageType;
+  unsigned flags;
+  int64_t qmin;
+  int64_t qmax;
+
+  // Hard-coded type mapping from TFLite.
+  if (numBits <= 8) {
+    storageType = IntegerType::get(8, ctx);
+    flags = 0;
+    qmin = 0;
+    qmax = 255;
+  } else if (numBits <= 16) {
+    storageType = IntegerType::get(16, ctx);
+    flags = QuantizationFlags::Signed;
+    qmin = -32768;
+    qmax = 32767;
+  } else {
+    ctx->emitError(loc,
+                   "unsupported FakeQuant number of bits: " + Twine(numBits));
+    return nullptr;
+  }
+
+  // Handle narrowRange.
+  if (narrowRange) {
+    qmin += 1;
+  }
+
+  // Range must straddle zero.
+  if (rmin > 0.0 || rmax < 0.0) {
+    return (ctx->emitError(loc, "FakeQuant range must straddle zero: [" +
+                                    Twine(std::to_string(rmin)) + "," +
+                                    Twine(std::to_string(rmax)) + "]"),
+            nullptr);
+  }
+
+  // Special case where min/max is a point. Must be 0.
+  if (rmin == rmax) {
+    return UniformQuantizedType::getChecked(flags, storageType, expressedType,
+                                            0.0, 0, qmin, qmax, loc);
+  }
+
+  // Determine the scale.
+  const double qminDouble = qmin;
+  const double qmaxDouble = qmax;
+  const double scale = (rmax - rmin) / (qmaxDouble - qminDouble);
+
+  // Zero point computation.
+  // In float, solve the affine equation for any known pair
+  // (real value, corresponding quantized value), of which, two such pairs
+  // are known: (rmin, qmin), (rmax, qmax).
+  // The arithmetic error on the zero point computed from either pair will be
+  // roughly machine_epsilon * (sum of absolute values of terms).
+  // Use the variant that adds the smaller error.
+  const double zeroPointFromMin = qminDouble - rmin / scale;
+  const double zeroPointFromMinError =
+      std::abs(qminDouble) + std::abs(rmin / scale);
+  const double zeroPointFromMax = qmaxDouble - rmax / scale;
+  const double zeroPointFromMaxError =
+      std::abs(qmaxDouble) + std::abs(rmax / scale);
+
+  const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
+                                     ? zeroPointFromMin
+                                     : zeroPointFromMax;
+
+  // Now nudge the zero point to be an integer.
+  int64_t nudgedZeroPoint = 0;
+  if (zeroPointDouble < qminDouble) {
+    nudgedZeroPoint = qmin;
+  } else if (zeroPointDouble > qmaxDouble) {
+    nudgedZeroPoint = qmax;
+  } else {
+    nudgedZeroPoint = round(zeroPointDouble);
+  }
+
+  // By construction, the nudged zero point should always be in range.
+  assert(nudgedZeroPoint >= qmin);
+  assert(nudgedZeroPoint <= qmax);
+
+  return UniformQuantizedType::getChecked(flags, storageType, expressedType,
+                                          scale, nudgedZeroPoint, qmin, qmax,
+                                          loc);
+}
diff --git a/mlir/lib/Quantization/IR/QuantOps.cpp b/mlir/lib/Quantization/IR/QuantOps.cpp
new file mode 100644 (file)
index 0000000..05a5162
--- /dev/null
@@ -0,0 +1,360 @@
+//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantization/QuantOps.h"
+#include "TypeDetail.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/StandardTypes.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+using namespace mlir::quant::detail;
+
+unsigned QuantizedType::getFlags() const {
+  return static_cast<ImplType *>(type)->flags;
+}
+
+LogicalResult QuantizedType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+    Type storageType, Type expressedType, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
+  // Verify that the expressed type is floating point.
+  // If this restriction is ever eliminated, the parser/printer must be
+  // extended.
+  if (!expressedType.isa<FloatType>()) {
+    if (loc) {
+      context->emitError(*loc, "expressed type must be floating point");
+    }
+    return failure();
+  }
+
+  // Verify that the storage type is integral.
+  // This restriction may be lifted at some point in favor of using bf16
+  // or f16 as exact representations on hardware where that is advantageous.
+  auto intStorageType = storageType.dyn_cast<IntegerType>();
+  if (!intStorageType) {
+    if (loc) {
+      context->emitError(*loc, "storage type must be integral");
+    }
+    return failure();
+  }
+  unsigned integralWidth = intStorageType.getWidth();
+
+  // Verify storage width.
+  if (integralWidth == 0 || integralWidth > MaxStorageBits) {
+    if (loc) {
+      context->emitError(*loc,
+                         "illegal storage type size: " + Twine(integralWidth));
+    }
+    return failure();
+  }
+
+  // Verify storageTypeMin and storageTypeMax.
+  bool isSigned =
+      (flags & QuantizationFlags::Signed) == QuantizationFlags::Signed;
+  int64_t defaultIntegerMin =
+      getDefaultMininumForInteger(isSigned, integralWidth);
+  int64_t defaultIntegerMax =
+      getDefaultMaxinumForInteger(isSigned, integralWidth);
+  if (storageTypeMax - storageTypeMin <= 0 ||
+      storageTypeMin < defaultIntegerMin ||
+      storageTypeMax > defaultIntegerMax) {
+    if (loc) {
+      context->emitError(*loc, "illegal storage min and storage max: (" +
+                                   Twine(storageTypeMin) + ":" +
+                                   Twine(storageTypeMax) + ")");
+    }
+    return failure();
+  }
+  return success();
+}
+
+Type QuantizedType::getStorageType() const {
+  return static_cast<ImplType *>(type)->storageType;
+}
+
+int64_t QuantizedType::getStorageTypeMin() const {
+  return static_cast<ImplType *>(type)->storageTypeMin;
+}
+
+int64_t QuantizedType::getStorageTypeMax() const {
+  return static_cast<ImplType *>(type)->storageTypeMax;
+}
+
+unsigned QuantizedType::getStorageTypeIntegralWidth() const {
+  // NOTE: If ever supporting non-integral storage types, some other scheme
+  // for determining the width will be needed.
+  return static_cast<ImplType *>(type)->storageType.getIntOrFloatBitWidth();
+}
+
+Type QuantizedType::getExpressedType() const {
+  return static_cast<ImplType *>(type)->expressedType;
+}
+
+bool QuantizedType::isCompatibleExpressedType(Type candidateExpressedType) {
+  if (candidateExpressedType.isa<VectorOrTensorType>()) {
+    return candidateExpressedType.cast<VectorOrTensorType>().getElementType() ==
+           getExpressedType();
+  }
+  return candidateExpressedType == getExpressedType();
+}
+
+QuantizedType
+QuantizedType::getQuantizedElementType(Type primitiveOrContainerType) {
+  if (primitiveOrContainerType.isa<VectorOrTensorType>()) {
+    Type elementType =
+        primitiveOrContainerType.cast<VectorOrTensorType>().getElementType();
+    return elementType.dyn_cast<QuantizedType>();
+  }
+  return primitiveOrContainerType.dyn_cast<QuantizedType>();
+}
+
+Type QuantizedType::castFromStorageType(Type candidateType) {
+  if (candidateType == getStorageType()) {
+    // i.e. i32 -> quant<"uniform[i8:f32]{1.0}">
+    return *this;
+  } else if (candidateType.isa<RankedTensorType>()) {
+    // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+    return RankedTensorType::get(
+        candidateType.cast<RankedTensorType>().getShape(), getStorageType());
+  } else if (candidateType.isa<UnrankedTensorType>()) {
+    // i.e. tensor<i8> -> tensor<!quant<"uniform[i8:f32]{1.0}">>
+    return UnrankedTensorType::get(getStorageType());
+  } else if (candidateType.isa<VectorType>()) {
+    // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+    return VectorType::get(candidateType.cast<VectorType>().getShape(),
+                           getStorageType());
+  }
+
+  return nullptr;
+}
+
+Type QuantizedType::castToStorageType(Type quantizedType) {
+  if (quantizedType.isa<QuantizedType>()) {
+    // i.e. quant<"uniform[i8:f32]{1.0}"> -> i8
+    return quantizedType.cast<QuantizedType>().getStorageType();
+  } else if (quantizedType.isa<VectorOrTensorType>()) {
+    // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+    VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>();
+    if (!vtType.getElementType().isa<QuantizedType>()) {
+      return nullptr;
+    }
+    Type storageType =
+        vtType.getElementType().cast<QuantizedType>().getStorageType();
+    if (quantizedType.isa<RankedTensorType>()) {
+      return RankedTensorType::get(vtType.getShape(), storageType);
+    } else if (quantizedType.isa<UnrankedTensorType>()) {
+      return UnrankedTensorType::get(storageType);
+    } else if (quantizedType.isa<VectorType>()) {
+      return VectorType::get(vtType.getShape(), storageType);
+    }
+  }
+
+  return nullptr;
+}
+
+Type QuantizedType::castFromExpressedType(Type candidateType) {
+  if (candidateType == getExpressedType()) {
+    // i.e. f32 -> quant<"uniform[i8:f32]{1.0}">
+    return *this;
+  } else if (candidateType.isa<VectorOrTensorType>()) {
+    VectorOrTensorType candidateVtType =
+        candidateType.cast<VectorOrTensorType>();
+    if (candidateVtType.getElementType() != getExpressedType()) {
+      return nullptr;
+    }
+
+    if (candidateType.isa<RankedTensorType>()) {
+      // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+      return RankedTensorType::get(candidateVtType.getShape(), *this);
+    } else if (candidateType.isa<UnrankedTensorType>()) {
+      // i.e. tensor<xf32> -> tensor<x!quant<"uniform[i8:f32]{1.0}">>
+      return UnrankedTensorType::get(*this);
+    } else if (candidateType.isa<VectorType>()) {
+      // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+      return VectorType::get(candidateVtType.getShape(), *this);
+    }
+  }
+
+  return nullptr;
+}
+
+Type QuantizedType::castToExpressedType(Type quantizedType) {
+  if (quantizedType.isa<QuantizedType>()) {
+    // i.e. quant<"uniform[i8:f32]{1.0}"> -> f32
+    return quantizedType.cast<QuantizedType>().getExpressedType();
+  } else if (quantizedType.isa<VectorOrTensorType>()) {
+    // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
+    VectorOrTensorType vtType = quantizedType.cast<VectorOrTensorType>();
+    if (!vtType.getElementType().isa<QuantizedType>()) {
+      return nullptr;
+    }
+    Type expressedType =
+        vtType.getElementType().cast<QuantizedType>().getExpressedType();
+    if (quantizedType.isa<RankedTensorType>()) {
+      return RankedTensorType::get(vtType.getShape(), expressedType);
+    } else if (quantizedType.isa<UnrankedTensorType>()) {
+      return UnrankedTensorType::get(expressedType);
+    } else if (quantizedType.isa<VectorType>()) {
+      return VectorType::get(vtType.getShape(), expressedType);
+    }
+  }
+
+  return nullptr;
+}
+
+Type QuantizedType::castExpressedToStorageType(Type candidateType) {
+  Type expressedQuantizedType = castFromExpressedType(candidateType);
+  if (!expressedQuantizedType) {
+    return nullptr;
+  }
+  return QuantizedType::castToStorageType(expressedQuantizedType);
+}
+
+UniformQuantizedType UniformQuantizedType::get(unsigned flags, Type storageType,
+                                               Type expressedType, double scale,
+                                               int64_t zeroPoint,
+                                               int64_t storageTypeMin,
+                                               int64_t storageTypeMax) {
+  return Base::get(storageType.getContext(),
+                   QuantizationTypes::UniformQuantized, flags, storageType,
+                   expressedType, scale, zeroPoint, storageTypeMin,
+                   storageTypeMax);
+}
+
+UniformQuantizedType
+UniformQuantizedType::getChecked(unsigned flags, Type storageType,
+                                 Type expressedType, double scale,
+                                 int64_t zeroPoint, int64_t storageTypeMin,
+                                 int64_t storageTypeMax, Location location) {
+  return Base::getChecked(location, storageType.getContext(),
+                          QuantizationTypes::UniformQuantized, flags,
+                          storageType, expressedType, scale, zeroPoint,
+                          storageTypeMin, storageTypeMax);
+}
+
+LogicalResult UniformQuantizedType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+    Type storageType, Type expressedType, double scale, int64_t zeroPoint,
+    int64_t storageTypeMin, int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyConstructionInvariants(
+          loc, context, flags, storageType, expressedType, storageTypeMin,
+          storageTypeMax))) {
+    return failure();
+  }
+
+  // Verify scale.
+  if (scale <= 0.0 || isinf(scale) || isnan(scale)) {
+    if (loc) {
+      context->emitError(*loc,
+                         "illegal scale: " + Twine(std::to_string(scale)));
+    }
+    return failure();
+  }
+
+  return success();
+}
+
+double UniformQuantizedType::getScale() const { return getImpl()->scale; }
+
+int64_t UniformQuantizedType::getZeroPoint() const {
+  return getImpl()->zeroPoint;
+}
+
+UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get(
+    unsigned flags, Type storageType, Type expressedType,
+    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+    int32_t quantizedDimension, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
+  return Base::get(storageType.getContext(),
+                   QuantizationTypes::UniformQuantizedPerAxis, flags,
+                   storageType, expressedType, scales, zeroPoints,
+                   quantizedDimension, storageTypeMin, storageTypeMax);
+}
+
+UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
+    unsigned flags, Type storageType, Type expressedType,
+    ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+    int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
+    Location location) {
+  return Base::getChecked(location, storageType.getContext(),
+                          QuantizationTypes::UniformQuantizedPerAxis, flags,
+                          storageType, expressedType, scales, zeroPoints,
+                          quantizedDimension, storageTypeMin, storageTypeMax);
+}
+
+LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
+    llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
+    Type storageType, Type expressedType, ArrayRef<double> scales,
+    ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
+    int64_t storageTypeMin, int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyConstructionInvariants(
+          loc, context, flags, storageType, expressedType, storageTypeMin,
+          storageTypeMax))) {
+    return failure();
+  }
+
+  // Ensure that the number of scales and zeroPoints match.
+  if (scales.size() != zeroPoints.size()) {
+    if (loc) {
+      context->emitError(*loc, "illegal number of scales and zeroPoints: " +
+                                   Twine(scales.size()) + ", " +
+                                   Twine(zeroPoints.size()));
+    }
+    return failure();
+  }
+
+  // Verify scale.
+  for (double scale : scales) {
+    if (scale <= 0.0 || isinf(scale) || isnan(scale)) {
+      if (loc) {
+        context->emitError(*loc,
+                           "illegal scale: " + Twine(std::to_string(scale)));
+      }
+      return failure();
+    }
+  }
+
+  return success();
+}
+
+ArrayRef<double> UniformQuantizedPerAxisType::getScales() const {
+  return getImpl()->getScales();
+}
+
+ArrayRef<int64_t> UniformQuantizedPerAxisType::getZeroPoints() const {
+  return getImpl()->getZeroPoints();
+}
+
+int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const {
+  return getImpl()->quantizedDimension;
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Quantization/QuantOps.cpp.inc"
+
+QuantizationDialect::QuantizationDialect(MLIRContext *context)
+    : Dialect(/*name=*/"quant", context) {
+  addTypes<UniformQuantizedType, UniformQuantizedPerAxisType>();
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Quantization/QuantOps.cpp.inc"
+      >();
+}
diff --git a/mlir/lib/Quantization/IR/TypeDetail.h b/mlir/lib/Quantization/IR/TypeDetail.h
new file mode 100644 (file)
index 0000000..d3db91e
--- /dev/null
@@ -0,0 +1,219 @@
+//===- Quantization/IR/TypeDetail.h - Type detail ---------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef TYPE_DETAIL_H_
+#define TYPE_DETAIL_H_
+
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeSupport.h"
+#include "mlir/IR/Types.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/bit.h"
+
+namespace mlir {
+namespace quant {
+namespace detail {
+
+struct QuantizedTypeStorage : public mlir::TypeStorage {
+  QuantizedTypeStorage(unsigned flags, Type storageType, Type expressedType,
+                       int64_t storageTypeMin, int64_t storageTypeMax)
+      : flags(flags), storageType(storageType), expressedType(expressedType),
+        storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
+
+  /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
+  unsigned flags;
+
+  // Integral type for the storage point representation.
+  Type storageType;
+
+  // Floating point type that the quantized type approximates.
+  Type expressedType;
+
+  // The minimum value storageType can take.
+  int64_t storageTypeMin;
+
+  // The maximum value storageType can take.
+  int64_t storageTypeMax;
+};
+
+struct UniformQuantizedTypeStorage : public QuantizedTypeStorage {
+  struct KeyTy {
+    KeyTy(unsigned flags, Type storageType, Type expressedType, double scale,
+          int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax)
+        : flags(flags), storageType(storageType), expressedType(expressedType),
+          scale(scale), zeroPoint(zeroPoint), storageTypeMin(storageTypeMin),
+          storageTypeMax(storageTypeMax) {}
+    /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
+    unsigned flags;
+
+    // Integral type for the storage point representation.
+    Type storageType;
+
+    // Floating point type that the quantized type approximates.
+    Type expressedType;
+
+    double scale;
+    int64_t zeroPoint;
+    int64_t storageTypeMin;
+    int64_t storageTypeMax;
+
+    // Check for equality of two structures that share KeyTy data members
+    // (by name).
+    template <typename T, typename U>
+    static bool genericIsEqual(const T &lhs, const U &rhs) {
+      return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
+             lhs.expressedType == rhs.expressedType && lhs.scale == rhs.scale &&
+             lhs.zeroPoint == rhs.zeroPoint &&
+             lhs.storageTypeMin == rhs.storageTypeMin &&
+             lhs.storageTypeMax == rhs.storageTypeMax;
+    }
+
+    bool operator==(const KeyTy &other) const {
+      return genericIsEqual(*this, other);
+    }
+
+    unsigned getHashValue() const {
+      int64_t scaleBits = llvm::bit_cast<int64_t>(scale);
+      return llvm::hash_combine(flags, storageType, expressedType, scaleBits,
+                                zeroPoint, storageTypeMin, storageTypeMax);
+    }
+  };
+
+  UniformQuantizedTypeStorage(const KeyTy &key)
+      : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
+                             key.storageTypeMin, key.storageTypeMax),
+        scale(key.scale), zeroPoint(key.zeroPoint) {}
+
+  bool operator==(const KeyTy &key) const {
+    return KeyTy::genericIsEqual(*this, key);
+  }
+
+  /// Construction.
+  static UniformQuantizedTypeStorage *construct(TypeStorageAllocator &allocator,
+                                                const KeyTy &key) {
+    return new (allocator.allocate<UniformQuantizedTypeStorage>())
+        UniformQuantizedTypeStorage(key);
+  }
+
+  static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
+
+  double scale;
+  int64_t zeroPoint;
+};
+
+struct UniformQuantizedPerAxisTypeStorage : public QuantizedTypeStorage {
+  struct KeyTy {
+    KeyTy(unsigned flags, Type storageType, Type expressedType,
+          ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+          int32_t quantizedDimension, int64_t storageTypeMin,
+          int64_t storageTypeMax)
+        : flags(flags), storageType(storageType), expressedType(expressedType),
+          scales(scales), zeroPoints(zeroPoints),
+          quantizedDimension(quantizedDimension),
+          storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {}
+    /// Flags corresponding to the bitmapped enum QuantizationFlags::FlagValue.
+    unsigned flags;
+
+    // Integral type for the storage point representation.
+    Type storageType;
+
+    // Floating point type that the quantized type approximates.
+    Type expressedType;
+
+    ArrayRef<double> scales;
+    ArrayRef<int64_t> zeroPoints;
+    int32_t quantizedDimension;
+    int64_t storageTypeMin;
+    int64_t storageTypeMax;
+
+    ArrayRef<double> getScales() const { return scales; }
+
+    ArrayRef<int64_t> getZeroPoints() const { return zeroPoints; }
+
+    // Check for equality of two structures that share KeyTy data members
+    // (by name).
+    template <typename T, typename U>
+    static bool genericIsEqual(const T &lhs, const U &rhs) {
+      return lhs.flags == rhs.flags && lhs.storageType == rhs.storageType &&
+             lhs.expressedType == rhs.expressedType &&
+             lhs.getScales() == rhs.getScales() &&
+             lhs.getZeroPoints() == rhs.getZeroPoints() &&
+             lhs.quantizedDimension == rhs.quantizedDimension &&
+             lhs.storageTypeMin == rhs.storageTypeMin &&
+             lhs.storageTypeMax == rhs.storageTypeMax;
+    }
+
+    bool operator==(const KeyTy &other) const {
+      return genericIsEqual(*this, other);
+    }
+
+    unsigned getHashValue() const {
+      int64_t *scalesCast = llvm::bit_cast<int64_t *>(scales.data());
+      ArrayRef<int64_t> scalesBits(scalesCast, scales.size());
+      return llvm::hash_combine(
+          flags, storageType, expressedType,
+          llvm::hash_combine_range(scalesBits.begin(), scalesBits.end()),
+          llvm::hash_combine_range(zeroPoints.begin(), zeroPoints.end()),
+          storageTypeMin, storageTypeMax);
+    }
+  };
+
+  // We pass scales and zeroPoints in directly rather than relying on KeyTy
+  // because we have to create new reallocated versions in `constrcut` below.
+  UniformQuantizedPerAxisTypeStorage(const KeyTy &key, ArrayRef<double> scales,
+                                     ArrayRef<int64_t> zeroPoints)
+      : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType,
+                             key.storageTypeMin, key.storageTypeMax),
+        scaleElements(scales.data()), zeroPointElements(zeroPoints.data()),
+        quantParamsSize(scales.size()),
+        quantizedDimension(key.quantizedDimension) {}
+
+  bool operator==(const KeyTy &key) const {
+    return KeyTy::genericIsEqual(*this, key);
+  }
+
+  /// Construction.
+  static UniformQuantizedPerAxisTypeStorage *
+  construct(TypeStorageAllocator &allocator, const KeyTy &key) {
+    ArrayRef<double> scales = allocator.copyInto(key.scales);
+    ArrayRef<int64_t> zeroPoints = allocator.copyInto(key.zeroPoints);
+    return new (allocator.allocate<UniformQuantizedPerAxisTypeStorage>())
+        UniformQuantizedPerAxisTypeStorage(key, scales, zeroPoints);
+  }
+
+  static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); }
+
+  ArrayRef<double> getScales() const {
+    return ArrayRef<double>(scaleElements, quantParamsSize);
+  }
+
+  ArrayRef<int64_t> getZeroPoints() const {
+    return ArrayRef<int64_t>(zeroPointElements, quantParamsSize);
+  }
+
+  const double *scaleElements;
+  const int64_t *zeroPointElements;
+  unsigned quantParamsSize;
+  int32_t quantizedDimension;
+};
+
+} // namespace detail
+} // namespace quant
+} // namespace mlir
+
+#endif // TYPE_DETAIL_H_
diff --git a/mlir/lib/Quantization/IR/TypeParser.cpp b/mlir/lib/Quantization/IR/TypeParser.cpp
new file mode 100644 (file)
index 0000000..352e952
--- /dev/null
@@ -0,0 +1,653 @@
+//===- Quantization/IR/TypeParser.h - Quantization Type Parser --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Location.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Quantization/QuantOps.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace quant {
+
+/// Print a floating point value in a way that the parser will be able to
+/// round-trip losslessly.
+static void printStabilizedFloat(const APFloat &apValue, raw_ostream &os) {
+  // We would like to output the FP constant value in exponential notation,
+  // but we cannot do this if doing so will lose precision.  Check here to
+  // make sure that we only output it in exponential format if we can parse
+  // the value back and get the same value.
+  bool isInf = apValue.isInfinity();
+  bool isNaN = apValue.isNaN();
+  if (!isInf && !isNaN) {
+    SmallString<128> strValue;
+    apValue.toString(strValue, 6, 0, false);
+
+    // Check to make sure that the stringized number is not some string like
+    // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
+    // that the string matches the "[-+]?[0-9]" regex.
+    assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
+            ((strValue[0] == '-' || strValue[0] == '+') &&
+             (strValue[1] >= '0' && strValue[1] <= '9'))) &&
+           "[-+]?[0-9] regex does not match!");
+    // Reparse stringized version!
+    if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
+      os << strValue;
+      return;
+    }
+  }
+
+  SmallVector<char, 16> str;
+  apValue.toString(str);
+  os << str;
+}
+
+namespace {
+
+enum class TokenKind {
+  error,
+  eof,
+  l_bracket,
+  r_bracket,
+  l_brace,
+  r_brace,
+  l_paren,
+  r_paren,
+  colon,
+  comma,
+  alpha_ident,
+  integer_literal,
+  float_literal,
+};
+
+struct Token {
+  TokenKind kind;
+  StringRef spelling;
+};
+
+class Lexer {
+public:
+  Lexer(StringRef source) : curBuffer(source), curPtr(curBuffer.begin()) {}
+
+  Token lexToken();
+
+private:
+  Token formToken(TokenKind kind, const char *tokStart) {
+    return Token{kind, StringRef(tokStart, curPtr - tokStart)};
+  }
+
+  Token emitError(const char *loc, const Twine &message) {
+    return formToken(TokenKind::error, loc);
+  }
+
+  bool isEnd() const { return curPtr == curBuffer.end(); }
+
+  // Lexer implementation methods
+  Token lexalpha_ident(const char *tokStart);
+  Token lexNumber(const char *tokStart);
+
+  StringRef curBuffer;
+  const char *curPtr;
+};
+
+} // namespace
+
+Token Lexer::lexToken() {
+  // Ignore whitespace.
+  while (!isEnd()) {
+    switch (*curPtr) {
+    case ' ':
+    case '\t':
+    case '\n':
+    case '\r':
+      ++curPtr;
+      continue;
+    default:
+      break;
+    }
+    break;
+  }
+
+  if (isEnd()) {
+    return Token{TokenKind::eof, ""};
+  }
+
+  const char *tokStart = curPtr;
+  switch (*curPtr++) {
+  default:
+    if (isalpha(*tokStart)) {
+      return lexalpha_ident(tokStart);
+    }
+    if (isdigit(*tokStart)) {
+      return lexNumber(tokStart);
+    }
+
+    return emitError(tokStart, "unexpected character");
+
+  case '[':
+    return formToken(TokenKind::l_bracket, tokStart);
+  case ']':
+    return formToken(TokenKind::r_bracket, tokStart);
+  case '{':
+    return formToken(TokenKind::l_brace, tokStart);
+  case '}':
+    return formToken(TokenKind::r_brace, tokStart);
+  case '(':
+    return formToken(TokenKind::l_paren, tokStart);
+  case ')':
+    return formToken(TokenKind::r_paren, tokStart);
+  case ':':
+    return formToken(TokenKind::colon, tokStart);
+  case ',':
+    return formToken(TokenKind::comma, tokStart);
+  case '-':
+    return lexNumber(tokStart);
+  case '+':
+    return lexNumber(tokStart);
+  }
+}
+
+/// Lex a bare alpha identifier. Since this DSL often contains identifiers with
+/// trailing numeric components, this only matches alphas. It is up to the
+/// parser to handle identifiers that can be mixed alphanum.
+///
+///   alpha-ident ::= (letter)(letter)*
+Token Lexer::lexalpha_ident(const char *tokStart) {
+  while (!isEnd() && isalpha(*curPtr)) {
+    ++curPtr;
+  }
+  return formToken(TokenKind::alpha_ident, tokStart);
+}
+
+/// Lex a number.
+///
+///   integer-literal ::= [-+]?digit+
+///   float-literal ::= [-+]?[0-9]+[.][0-9]*([eE][-+]?[0-9]+)?
+Token Lexer::lexNumber(const char *tokStart) {
+  // Leading '+', '-' or digit has already been consumed.
+  while (!isEnd() && isdigit(*curPtr)) {
+    ++curPtr;
+  }
+  // If not a decimal point, treat as integer.
+  if (isEnd() || *curPtr != '.') {
+    return formToken(TokenKind::integer_literal, tokStart);
+  }
+  ++curPtr;
+
+  // Skip over [0-9]*([eE][-+]?[0-9]+)?
+  // Leading digits.
+  while (!isEnd() && isdigit(*curPtr)) {
+    ++curPtr;
+  }
+
+  // [eE][-+]?[0-9]+
+  if (!isEnd() && (*curPtr == 'e' || *curPtr == 'E')) {
+    auto remaining = curBuffer.end() - curPtr;
+    if (remaining > 2 && isdigit(curPtr[1])) {
+      // Lookahead 2 for digit.
+      curPtr += 2;
+      while (!isEnd() && isdigit(*curPtr)) {
+        ++curPtr;
+      }
+    } else if (remaining > 3 && (curPtr[1] == '-' || curPtr[1] == '+') &&
+               isdigit(curPtr[2])) {
+      // Lookahead 3 for [+-] digit.
+      curPtr += 3;
+      while (!isEnd() && isdigit(*curPtr)) {
+        ++curPtr;
+      }
+    }
+  }
+  return formToken(TokenKind::float_literal, tokStart);
+} // end namespace
+
+// --- TypeParser ---
+namespace {
+
+class TypeParser {
+public:
+  TypeParser(StringRef source, MLIRContext *context, Location location)
+      : context(context), location(location), lexer(source),
+        curToken(lexer.lexToken()) {}
+
+  /// Attempts to parse the source as a type, returning the unknown
+  /// type on error.
+  Type parseType();
+
+private:
+  /// Unconditionally consumes the current token.
+  void consumeToken() {
+    assert(curToken.kind != TokenKind::eof &&
+           "should not advance past EOF or errors");
+    curToken = lexer.lexToken();
+  }
+
+  /// Unconditionally consumes the current token, asserting that it is of the
+  /// specified kind.
+  void consumeToken(TokenKind kind) {
+    assert(curToken.kind == kind && "consumed an unexpected token");
+    consumeToken();
+  }
+
+  /// Conditionally consumes a token if of the specified kind.
+  /// Returns true if consumed.
+  bool consumeIf(TokenKind kind) {
+    if (curToken.kind == kind) {
+      consumeToken();
+      return true;
+    }
+    return false;
+  }
+
+  /// Emits an error at the current location with a message.
+  void emitError(const Twine &message) {
+    // TODO: All errors show up at the beginning of the extended type location.
+    // Figure out how to make this location relative to where the error occurred
+    // in this instance.
+    context->emitError(location, message);
+  }
+
+  // Parsers.
+  Type parseUniformType();
+  IntegerType parseStorageType(bool &isSigned);
+  FloatType parseExpressedType();
+  bool parseQuantParams(double &scale, int64_t &zeroPoint);
+
+  MLIRContext *context;
+  Location location;
+  Lexer lexer;
+
+  // The next token that has not yet been consumed.
+  Token curToken;
+};
+
+} // namespace
+
+Type TypeParser::parseType() {
+  // All types start with an identifier that we switch on.
+  if (curToken.kind == TokenKind::alpha_ident) {
+    StringRef typeNameSpelling = curToken.spelling;
+    consumeToken();
+
+    Type result;
+    if (typeNameSpelling == "uniform") {
+      result = parseUniformType();
+      if (!result) {
+        return nullptr;
+      }
+    } else {
+      return (emitError("unknown quantized type " + typeNameSpelling), nullptr);
+    }
+
+    // Make sure the entire input was consumed.
+    if (curToken.kind != TokenKind::eof) {
+      return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+    }
+
+    return result;
+  } else {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+}
+
+/// Parses a UniformQuantizedType.
+///
+///   uniform_type ::= `uniform` type_spec quant_param_spec
+///
+///   type_spec ::= `[` storage-spec `:` expressed-type (quant-dim)? `]`
+///   quant-dim ::= `:` integer-literal
+///   storage-spec ::= storage-type (`(` storage-range `)`)?
+///   storage-range ::= integer-literal `:` integer-literal
+///   storage-type ::= (`i` | `u`) integer-literal
+///   expressed-type ::= (`f16` | `f32` | `f64` | `bf16`)
+///
+///   quant_param_spec ::= `{` scale-zero (`,` scale-zero )* `}`
+///   scale-zero ::= float-literal `:` integer-literal
+Type TypeParser::parseUniformType() {
+  IntegerType storageType;
+  FloatType expressedType;
+  unsigned typeFlags = 0;
+  int64_t storageTypeMin;
+  int64_t storageTypeMax;
+  bool isPerAxis = false;
+  int32_t quantizedDimension;
+  SmallVector<double, 1> scales;
+  SmallVector<int64_t, 1> zeroPoints;
+
+  // Type specification.
+  if (!consumeIf(TokenKind::l_bracket)) {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+
+  // Storage type.
+  bool isSigned = false;
+  storageType = parseStorageType(isSigned);
+  if (!storageType) {
+    return nullptr;
+  }
+  if (isSigned) {
+    typeFlags |= QuantizationFlags::Signed;
+  }
+
+  // Storage type range.
+  int64_t defaultIntegerMin = QuantizedType::getDefaultMininumForInteger(
+      isSigned, storageType.getWidth());
+  int64_t defaultIntegerMax = QuantizedType::getDefaultMaxinumForInteger(
+      isSigned, storageType.getWidth());
+  if (consumeIf(TokenKind::l_paren)) {
+    // Explicit storage min and storage max.
+    if (curToken.kind != TokenKind::integer_literal) {
+      return (emitError("expected storage type minimum"), nullptr);
+    }
+    if (curToken.spelling.getAsInteger(10, storageTypeMin) ||
+        storageTypeMin < defaultIntegerMin) {
+      return (emitError("illegal storage type minimum: " + curToken.spelling),
+              nullptr);
+    }
+    consumeToken(TokenKind::integer_literal);
+
+    if (!consumeIf(TokenKind::colon)) {
+      return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+    }
+
+    if (curToken.kind != TokenKind::integer_literal) {
+      return (emitError("expected storage type maximum"), nullptr);
+    }
+    if (curToken.spelling.getAsInteger(10, storageTypeMax) ||
+        storageTypeMax > defaultIntegerMax) {
+      return (emitError("illegal storage type maximum: " + curToken.spelling),
+              nullptr);
+    }
+    consumeToken(TokenKind::integer_literal);
+
+    if (!consumeIf(TokenKind::r_paren)) {
+      return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+    }
+  } else {
+    storageTypeMin = defaultIntegerMin;
+    storageTypeMax = defaultIntegerMax;
+  }
+
+  // Repr type.
+  if (!consumeIf(TokenKind::colon)) {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+  expressedType = parseExpressedType();
+  if (!expressedType) {
+    return nullptr;
+  }
+
+  // Optionally parse quantized dimension for per-axis quantization.
+  if (consumeIf(TokenKind::colon)) {
+    if (curToken.kind != TokenKind::integer_literal) {
+      return (emitError("expected quantized dimension"), nullptr);
+    }
+    if (curToken.spelling.getAsInteger(10, quantizedDimension)) {
+      return (emitError("illegal quantized dimension: " + curToken.spelling),
+              nullptr);
+    }
+    consumeToken(TokenKind::integer_literal);
+    isPerAxis = true;
+  }
+
+  if (!consumeIf(TokenKind::r_bracket)) {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+
+  // Parameter specification.
+  if (!consumeIf(TokenKind::l_brace)) {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+
+  // Parse scales/zeroPoints.
+  do {
+    scales.resize(scales.size() + 1);
+    zeroPoints.resize(zeroPoints.size() + 1);
+    if (parseQuantParams(scales.back(), zeroPoints.back())) {
+      return nullptr;
+    }
+  } while (consumeIf(TokenKind::comma));
+
+  if (!consumeIf(TokenKind::r_brace)) {
+    return (emitError("unrecognized token: " + curToken.spelling), nullptr);
+  }
+
+  if (!isPerAxis && scales.size() > 1) {
+    return (emitError("multiple scales/zeroPoints provided, but "
+                      "quantizedDimension wasn't specified"),
+            nullptr);
+  }
+
+  if (isPerAxis) {
+    ArrayRef<double> scalesRef(scales.begin(), scales.end());
+    ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
+    return UniformQuantizedPerAxisType::getChecked(
+        typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
+        quantizedDimension, storageTypeMin, storageTypeMax, location);
+  }
+
+  return UniformQuantizedType::getChecked(
+      typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
+      storageTypeMin, storageTypeMax, location);
+}
+
+IntegerType TypeParser::parseStorageType(bool &isSigned) {
+  // Parse storage type (alpha_ident, integer_literal).
+  StringRef storageTypePrefix = curToken.spelling;
+  unsigned storageTypeWidth;
+  if (curToken.kind != TokenKind::alpha_ident) {
+    return (emitError("expected storage type prefix"), nullptr);
+  }
+  consumeToken();
+  if (curToken.kind != TokenKind::integer_literal) {
+    return (emitError("expected storage type width"), nullptr);
+  }
+  if (curToken.spelling.getAsInteger(10, storageTypeWidth) ||
+      storageTypeWidth == 0 ||
+      storageTypeWidth > QuantizedType::MaxStorageBits) {
+    return (emitError("illegal storage type size: " + Twine(curToken.spelling)),
+            nullptr);
+  }
+  consumeToken();
+
+  if (storageTypePrefix == "i") {
+    isSigned = true;
+    return IntegerType::get(storageTypeWidth, context);
+  } else if (storageTypePrefix == "u") {
+    isSigned = false;
+    return IntegerType::get(storageTypeWidth, context);
+  } else {
+    return (
+        emitError("illegal storage type prefix: " + Twine(storageTypePrefix)),
+        nullptr);
+  }
+}
+
+FloatType TypeParser::parseExpressedType() {
+  // Expect an alpha_ident followed by integer literal that we concat back
+  // together.
+  StringRef prefix = curToken.spelling;
+  if (!consumeIf(TokenKind::alpha_ident)) {
+    return (emitError("expected expressed type"), nullptr);
+  }
+  StringRef suffix = curToken.spelling;
+  if (!consumeIf(TokenKind::integer_literal)) {
+    return (emitError("expected expressed type"), nullptr);
+  }
+
+  SmallVector<char, 4> holder;
+  StringRef typeName = (Twine(prefix) + Twine(suffix)).toStringRef(holder);
+  if (typeName == "f32")
+    return FloatType::getF32(context);
+  if (typeName == "f16")
+    return FloatType::getF16(context);
+  if (typeName == "bf16")
+    return FloatType::getBF16(context);
+  if (typeName == "f64")
+    return FloatType::getF64(context);
+
+  return (emitError("unrecognized expressed type: " + typeName), nullptr);
+}
+
+bool TypeParser::parseQuantParams(double &scale, int64_t &zeroPoint) {
+  // scale[:zeroPoint]?
+  // scale.
+  StringRef scaleSpelling = curToken.spelling;
+  if (!consumeIf(TokenKind::float_literal) ||
+      scaleSpelling.getAsDouble(scale)) {
+    return (
+        emitError("expected valid uniform scale. got: " + Twine(scaleSpelling)),
+        true);
+  }
+
+  // zero point.
+  zeroPoint = 0;
+  if (!consumeIf(TokenKind::colon)) {
+    // Default zero point.
+    return false;
+  }
+  StringRef zeroPointSpelling = curToken.spelling;
+  if (!consumeIf(TokenKind::integer_literal) ||
+      zeroPointSpelling.getAsInteger(10, zeroPoint)) {
+    return (emitError("expected integer uniform zero point. got: " +
+                      Twine(zeroPointSpelling)),
+            true);
+  }
+
+  return false;
+}
+
+/// Parse a type registered to this dialect.
+Type QuantizationDialect::parseType(StringRef spec, Location loc) const {
+  TypeParser parser(spec, getContext(), loc);
+  Type parsedType = parser.parseType();
+  if (parsedType == nullptr) {
+    // Error.
+    // TODO(laurenzo): Do something?
+    return parsedType;
+  }
+
+  return parsedType;
+}
+
+static void printStorageType(QuantizedType type, raw_ostream &out) {
+  // storage type
+  unsigned storageWidth = type.getStorageTypeIntegralWidth();
+  bool isSigned = type.isSigned();
+  if (isSigned) {
+    out << "i" << storageWidth;
+  } else {
+    out << "u" << storageWidth;
+  }
+
+  // storageTypeMin and storageTypeMax if not default.
+  int64_t defaultIntegerMin =
+      QuantizedType::getDefaultMininumForInteger(isSigned, storageWidth);
+  int64_t defaultIntegerMax =
+      QuantizedType::getDefaultMaxinumForInteger(isSigned, storageWidth);
+  if (defaultIntegerMin != type.getStorageTypeMin() ||
+      defaultIntegerMax != type.getStorageTypeMax()) {
+    out << "(" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
+        << ")";
+  }
+}
+
+static void printExpressedType(QuantizedType type, raw_ostream &out) {
+  // repr type
+  Type expressedType = type.getExpressedType();
+  if (expressedType.isF32()) {
+    out << "f32";
+  } else if (expressedType.isF64()) {
+    out << "f64";
+  } else if (expressedType.isF16()) {
+    out << "f16";
+  } else if (expressedType.isBF16()) {
+    out << "bf16";
+  } else {
+    out << "unknown";
+  }
+}
+
+static void printQuantParams(double scale, int64_t zeroPoint,
+                             raw_ostream &out) {
+  printStabilizedFloat(APFloat(scale), out);
+  if (zeroPoint != 0) {
+    out << ":" << zeroPoint;
+  }
+}
+
+/// Helper that prints a UniformQuantizedType.
+static void printUniformQuantizedType(UniformQuantizedType type,
+                                      raw_ostream &out) {
+  out << "uniform[";
+  printStorageType(type, out);
+  out << ":";
+  printExpressedType(type, out);
+  out << "]";
+
+  // scheme specific parameters
+  out << "{";
+  printQuantParams(type.getScale(), type.getZeroPoint(), out);
+  out << "}";
+}
+
+/// Helper that prints a UniformQuantizedPerAxisType.
+static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
+                                             raw_ostream &out) {
+  out << "uniform[";
+  printStorageType(type, out);
+  out << ":";
+  printExpressedType(type, out);
+  out << ":";
+  out << type.getQuantizedDimension();
+  out << "]";
+
+  // scheme specific parameters
+  ArrayRef<double> scales = type.getScales();
+  ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
+  out << "{";
+  for (unsigned i = 0; i < scales.size(); ++i) {
+    printQuantParams(scales[i], zeroPoints[i], out);
+    if (i != scales.size() - 1) {
+      out << ",";
+    }
+  }
+  out << "}";
+}
+
+/// Print a type registered to this dialect.
+void QuantizationDialect::printType(Type type, raw_ostream &os) const {
+  switch (type.getKind()) {
+  default:
+    llvm_unreachable("Unhandled quantized type");
+  case QuantizationTypes::UniformQuantized:
+    printUniformQuantizedType(type.cast<UniformQuantizedType>(), os);
+    break;
+  case QuantizationTypes::UniformQuantizedPerAxis:
+    printUniformQuantizedPerAxisType(type.cast<UniformQuantizedPerAxisType>(),
+                                     os);
+    break;
+  }
+}
+
+} // namespace quant
+} // namespace mlir
diff --git a/mlir/lib/Quantization/IR/UniformSupport.cpp b/mlir/lib/Quantization/IR/UniformSupport.cpp
new file mode 100644 (file)
index 0000000..d9549bb
--- /dev/null
@@ -0,0 +1,73 @@
+//===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantization/UniformSupport.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+static bool isQuantizablePrimitiveType(Type inputType) {
+  return inputType.isa<FloatType>();
+}
+
+const ExpressedToUniformQuantizedConverter
+ExpressedToUniformQuantizedConverter::forInputType(Type inputType) {
+  switch (inputType.getKind()) {
+  default:
+    if (isQuantizablePrimitiveType(inputType)) {
+      // Supported primitive type (which just is the expressed type).
+      return ExpressedToUniformQuantizedConverter{inputType, inputType};
+    }
+    // Unsupported.
+    return ExpressedToUniformQuantizedConverter{inputType, nullptr};
+  case StandardTypes::RankedTensor:
+  case StandardTypes::UnrankedTensor:
+  case StandardTypes::Vector: {
+    Type elementType = inputType.cast<VectorOrTensorType>().getElementType();
+    if (!isQuantizablePrimitiveType(elementType)) {
+      // Unsupported.
+      return ExpressedToUniformQuantizedConverter{inputType, nullptr};
+    }
+    return ExpressedToUniformQuantizedConverter{
+        inputType, inputType.cast<VectorOrTensorType>().getElementType()};
+  }
+  }
+}
+
+Type ExpressedToUniformQuantizedConverter::convert(
+    UniformQuantizedType elementalType) const {
+  assert(expressedType && "convert() on unsupported conversion");
+
+  switch (inputType.getKind()) {
+  default:
+    if (isQuantizablePrimitiveType(elementalType)) {
+      // For primitives, just use the new elemental type.
+      return elementalType;
+    }
+    // Unsupported.
+    return nullptr;
+  case StandardTypes::RankedTensor:
+    return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
+                                 elementalType);
+  case StandardTypes::UnrankedTensor:
+    return UnrankedTensorType::get(elementalType);
+  case StandardTypes::Vector:
+    return VectorType::get(inputType.cast<VectorType>().getShape(),
+                           elementalType);
+  }
+}
diff --git a/mlir/lib/Quantization/Transforms/ConvertConst.cpp b/mlir/lib/Quantization/Transforms/ConvertConst.cpp
new file mode 100644 (file)
index 0000000..ec947f2
--- /dev/null
@@ -0,0 +1,133 @@
+//===- ConvertConst.cpp - Quantizes constant ops --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Quantization/Passes.h"
+#include "mlir/Quantization/QuantOps.h"
+#include "mlir/Quantization/QuantizeUtils.h"
+#include "mlir/Quantization/UniformSupport.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+class ConvertConstPass : public FunctionPass<ConvertConstPass> {
+public:
+  void runOnFunction() override;
+};
+
+class QuantizedConstRewrite : public RewritePattern {
+public:
+  struct State : PatternState {
+    QuantizedType quantizedElementType;
+    Attribute value;
+  };
+
+  QuantizedConstRewrite(MLIRContext *context)
+      : RewritePattern(QuantizeBarrierOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult match(Operation *op) const override;
+  void rewrite(Operation *op, std::unique_ptr<PatternState> baseState,
+               PatternRewriter &rewriter) const override;
+};
+
+} // end anonymous namespace
+
+/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
+/// quantized and the operand type is quantizable.
+PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
+  State state;
+
+  // Is the operand a constant?
+  auto qbarrier = op->cast<QuantizeBarrierOp>();
+  if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
+    return matchFailure();
+  }
+  // Does the qbarrier convert to a quantized type. This will not be true
+  // if a quantized type has not yet been chosen or if the cast to an equivalent
+  // storage type is not supported.
+  Type qbarrierResultType = qbarrier.getResult()->getType();
+  state.quantizedElementType =
+      QuantizedType::getQuantizedElementType(qbarrierResultType);
+  if (!state.quantizedElementType) {
+    return matchFailure();
+  }
+  if (!QuantizedType::castToStorageType(qbarrierResultType)) {
+    return matchFailure();
+  }
+
+  // Is the operand type compatible with the expressed type of the quantized
+  // type? This will not be true if the qbarrier is superfluous (converts
+  // from and to a quantized type).
+  if (!state.quantizedElementType.isCompatibleExpressedType(
+          qbarrier.arg()->getType())) {
+    return matchFailure();
+  }
+
+  // Is the constant value a type expressed in a way that we support?
+  if (!state.value.isa<FloatAttr>() && !state.value.isa<SplatElementsAttr>() &&
+      !state.value.isa<DenseElementsAttr>() &&
+      !state.value.isa<SparseElementsAttr>()) {
+    return matchFailure();
+  }
+
+  return matchSuccess(llvm::make_unique<State>(std::move(state)));
+}
+
+void QuantizedConstRewrite::rewrite(Operation *op,
+                                    std::unique_ptr<PatternState> baseState,
+                                    PatternRewriter &rewriter) const {
+  auto state = static_cast<State *>(baseState.get());
+
+  Type newConstValueType;
+  Attribute newConstValue = quantizeAttr(
+      state->value, state->quantizedElementType, newConstValueType);
+  if (!newConstValue) {
+    return;
+  }
+
+  auto *origConstOp = op->getOperand(0);
+  // When creating the new const op, use a fused location that combines the
+  // original const and the qbarrier that led to the quantization.
+  auto fusedLoc =
+      FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()},
+                    rewriter.getContext());
+  auto newConstOp =
+      rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
+  rewriter.replaceOpWithNewOp<StorageCastOp>(
+      op, {origConstOp}, *op->result_type_begin(), newConstOp);
+}
+
+void ConvertConstPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto &func = getFunction();
+  auto *context = &getContext();
+  patterns.push_back(llvm::make_unique<QuantizedConstRewrite>(context));
+  applyPatternsGreedily(func, std::move(patterns));
+}
+
+FunctionPassBase *createConvertConstPass() { return new ConvertConstPass(); }
+
+static PassRegistration<ConvertConstPass>
+    pass("quant-convert-const",
+         "Converts constants followed by qbarrier to actual quantized values");
diff --git a/mlir/lib/Quantization/Transforms/LowerTF.cpp b/mlir/lib/Quantization/Transforms/LowerTF.cpp
new file mode 100644 (file)
index 0000000..24a35c9
--- /dev/null
@@ -0,0 +1,112 @@
+//===- LowerTF.cpp - Passes for lowering from TensorFlow ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Quantization/FakeQuantSupport.h"
+#include "mlir/Quantization/Passes.h"
+#include "mlir/Quantization/QuantOps.h"
+#include "mlir/Quantization/UniformSupport.h"
+#include "mlir/TensorFlow/TFOps.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+class LowerTFPass : public FunctionPass<LowerTFPass> {
+public:
+  void runOnFunction() override;
+};
+
+} // end anonymous namespace
+
+/// Rewrites TensorFlow FakeQuantWithMinMaxArgs into a qbarrier/dbarrier pair.
+class FakeQuantWithMinMaxArgsRewrite : public RewritePattern {
+public:
+  bool *hadFailure;
+
+  FakeQuantWithMinMaxArgsRewrite(MLIRContext *context, bool *hadFailure)
+      : RewritePattern(TF::FakeQuantWithMinMaxArgsOp::getOperationName(), 1,
+                       context),
+        hadFailure(hadFailure) {}
+
+  PatternMatchResult match(Operation *op) const override {
+    return matchSuccess();
+  }
+
+  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
+    // TODO: If this pattern comes up more frequently, consider adding core
+    // support for failable rewrites.
+    if (failableRewrite(op, rewriter)) {
+      *hadFailure = true;
+    }
+  }
+
+  bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
+    auto fqOp = op->template cast<TF::FakeQuantWithMinMaxArgsOp>();
+
+    auto converter =
+        ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
+    if (!converter) {
+      return (op->emitError("unsupported quantized type conversion"), true);
+    }
+
+    UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
+        fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+        fqOp.min().convertToDouble(), fqOp.max().convertToDouble(),
+        fqOp.narrow_range(), converter.expressedType);
+
+    if (!uniformElementType) {
+      // Note that the fakeQuantAttrsToType will have emitted the error.
+      return true;
+    }
+
+    Type quantizedType = converter.convert(uniformElementType);
+    assert(quantizedType &&
+           "Converter accepted a type that it did not convert");
+
+    // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
+    // this is a forced/hard-coded constraint.
+    auto qbarrier = rewriter.create<QuantizeBarrierOp>(
+        op->getLoc(), quantizedType, fqOp.inputs());
+    rewriter.replaceOpWithNewOp<DequantizeBarrierOp>(op, converter.inputType,
+                                                     qbarrier.getResult());
+
+    return false;
+  }
+};
+
+void LowerTFPass::runOnFunction() {
+  bool hadFailure = false;
+  OwningRewritePatternList patterns;
+  auto &func = getFunction();
+  auto *context = &getContext();
+  patterns.push_back(
+      llvm::make_unique<FakeQuantWithMinMaxArgsRewrite>(context, &hadFailure));
+  applyPatternsGreedily(func, std::move(patterns));
+  if (hadFailure)
+    signalPassFailure();
+}
+
+FunctionPassBase *createLowerTFPass() { return new LowerTFPass(); }
+
+static PassRegistration<LowerTFPass>
+    pass("quant-lower-tf",
+         "Lowers TensorFlow constraint ops to the quantization dialect");
diff --git a/mlir/lib/Quantization/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Quantization/Transforms/LowerUniformRealMath.cpp
new file mode 100644 (file)
index 0000000..9ce9264
--- /dev/null
@@ -0,0 +1,259 @@
+//===- LowerUniformRealMath.cpp  ------------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Quantization/Passes.h"
+#include "mlir/Quantization/QuantOps.h"
+#include "mlir/Quantization/UniformSupport.h"
+
+#include <functional>
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+struct LowerUniformRealMathPass
+    : public FunctionPass<LowerUniformRealMathPass> {
+  void runOnFunction() override;
+};
+
+UniformQuantizedType getUniformElementType(Type t) {
+  return QuantizedType::getQuantizedElementType(t)
+      .dyn_cast_or_null<UniformQuantizedType>();
+}
+
+/// Computes the log2(x), rounded to an integral value. Returns whether 'x' can
+/// be considered an exact integral value.
+template <typename F> bool integralLog2(F x, int &log2Result) {
+  const F xLog2 = std::log(x) * (1.0 / std::log(2.0));
+  const F xLog2Rounded = std::round(xLog2);
+  const F xLog2Frac = xLog2 - xLog2Rounded;
+  log2Result = static_cast<int>(xLog2Rounded);
+  // Allow small comparison slop below the level that would make a difference
+  // for 2^16 levels.
+  return std::abs(xLog2Frac) < 1e-6;
+}
+
+/// Helper class for operating on binary operations where all operands
+/// and the result are a UniformQuantizedType.
+struct RealBinaryOpInfo {
+  RealBinaryOpInfo(Operation *op, Value *lhs, Value *rhs,
+                   Optional<APFloat> clampMin, Optional<APFloat> clampMax)
+      : op(op), lhs(lhs), rhs(rhs), clampMin(clampMin), clampMax(clampMax),
+        lhsType(getUniformElementType(lhs->getType())),
+        rhsType(getUniformElementType(rhs->getType())),
+        resultType(getUniformElementType(*op->result_type_begin())),
+        lhsStorageType(QuantizedType::castToStorageType(lhs->getType())),
+        rhsStorageType(QuantizedType::castToStorageType(rhs->getType())),
+        resultStorageType(
+            QuantizedType::castToStorageType(*op->result_type_begin())) {}
+
+  /// Returns whether this info is valid (all types defined, etc).
+  bool isValid() const {
+    return lhsType && rhsType && resultType && lhsStorageType &&
+           rhsStorageType && resultStorageType;
+  }
+
+  /// Returns whether the storage type of all operands is identical.
+  bool isSameStorageType() const {
+    return lhsType.getStorageType() == rhsType.getStorageType() &&
+           lhsType.getStorageType() == resultType.getStorageType();
+  }
+
+  /// Returns whether all operands and result are considered fixedpoint power
+  /// of two, setting the lhs, rhs, and result log2 scale references.
+  bool isFixedPointPOT(int &lhsLog2Scale, int &rhsLog2Scale,
+                       int &resultLog2Scale) const {
+    if (!lhsType.isFixedPoint() || !rhsType.isFixedPoint() ||
+        !resultType.isFixedPoint()) {
+      return false;
+    }
+
+    if (!integralLog2(lhsType.getScale(), lhsLog2Scale) ||
+        !integralLog2(rhsType.getScale(), rhsLog2Scale) ||
+        !integralLog2(resultType.getScale(), resultLog2Scale)) {
+      return false;
+    }
+
+    return true;
+  }
+
+  /// Gets the result integer clamp range given the result quantized type
+  // and any explicit clamp provided as attributes.
+  std::pair<IntegerAttr, IntegerAttr> getClampMinMax() const {
+    int64_t typeMin = resultType.getStorageTypeMin();
+    int64_t typeMax = resultType.getStorageTypeMax();
+
+    if (clampMin || clampMax) {
+      UniformQuantizedValueConverter conv(resultType);
+      if (clampMin) {
+        typeMin = std::max(typeMin, conv.quantizeFloatToInt64(*clampMin));
+      }
+      if (clampMax) {
+        typeMax = std::min(typeMax, conv.quantizeFloatToInt64(*clampMax));
+      }
+    }
+
+    // The quantized, integral ops expect clamps as 32bit ints.
+    return {
+        IntegerAttr::get(IntegerType::get(32, resultType.getContext()),
+                         typeMin),
+        IntegerAttr::get(IntegerType::get(32, resultType.getContext()),
+                         typeMax),
+    };
+  }
+
+  Operation *op;
+  Value *lhs;
+  Value *rhs;
+  Optional<APFloat> clampMin;
+  Optional<APFloat> clampMax;
+
+  // Element UniformQuantizedType for operands/result.
+  UniformQuantizedType lhsType;
+  UniformQuantizedType rhsType;
+  UniformQuantizedType resultType;
+
+  // Full storage-based types.
+  Type lhsStorageType;
+  Type rhsStorageType;
+  Type resultStorageType;
+};
+
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// Elementwise add
+//===----------------------------------------------------------------------===//
+/// Attempts to rewrite a fixed point power-of-two addition of two integers.
+/// This supports a limited number of cases, but when supported, represents
+/// the simplest computation.
+static LogicalResult tryRewriteFixedPOTAddEw(const RealBinaryOpInfo &constInfo,
+                                             PatternRewriter &rewriter) {
+  if (!constInfo.isSameStorageType()) {
+    return failure();
+  }
+
+  int lhsLog2Scale;
+  int rhsLog2Scale;
+  int resultLog2Scale;
+  if (!constInfo.isFixedPointPOT(lhsLog2Scale, rhsLog2Scale, resultLog2Scale)) {
+    return failure();
+  }
+
+  // Adjust shifts to be relative to the output.
+  // Left shift of one input scale is supported. The other must match the result
+  // scale.
+  int lhsScaleShift = lhsLog2Scale - resultLog2Scale;
+  int rhsScaleShift = rhsLog2Scale - resultLog2Scale;
+  if (lhsScaleShift != 0 && rhsScaleShift != 0) {
+    return failure();
+  }
+  if (lhsScaleShift > 0 || rhsScaleShift > 0) {
+    return failure();
+  }
+
+  // State accessed by the closure.
+  Operation *mathOp = constInfo.op;
+  const auto clampMinMax = constInfo.getClampMinMax();
+  Value *lhs = constInfo.lhs;
+  Value *rhs = constInfo.rhs;
+  Type lhsStorageType = constInfo.lhsStorageType;
+  Type rhsStorageType = constInfo.rhsStorageType;
+
+  // If the lhs operand is the one requiring a shift, swap it so that the shift
+  // happens the rhs operand.
+  if (lhsScaleShift != 0) {
+    std::swap(lhs, rhs);
+    std::swap(lhsStorageType, rhsStorageType);
+    std::swap(lhsScaleShift, rhsScaleShift);
+  }
+  int rhsRightShift = -rhsScaleShift;
+
+  // Cast operands to storage type.
+  Value *lhsStorageValue =
+      rewriter.create<StorageCastOp>(mathOp->getLoc(), lhsStorageType, lhs)
+          .getResult();
+  Value *rhsStorageValue =
+      rewriter.create<StorageCastOp>(mathOp->getLoc(), rhsStorageType, rhs)
+          .getResult();
+
+  // Rescale the rhs operand if needed.
+  if (rhsRightShift != 0) {
+    rhsStorageValue =
+        rewriter
+            .create<RoundingDivideByPotIOp>(
+                mathOp->getLoc(), rhsStorageValue,
+                IntegerAttr::get(IntegerType::get(32, rewriter.getContext()),
+                                 rhsRightShift))
+            .getResult();
+  }
+
+  // Add.
+  Value *sumValue = rewriter.create<SaturatingAddIOp>(
+      mathOp->getLoc(), lhsStorageValue, rhsStorageValue, clampMinMax.first,
+      clampMinMax.second);
+
+  // Cast back for new result.
+  rewriter.replaceOpWithNewOp<StorageCastOp>(
+      mathOp, *mathOp->result_type_begin(), sumValue);
+  return success();
+}
+
+namespace {
+
+struct UniformRealAddEwPattern : public RewritePattern {
+  UniformRealAddEwPattern(MLIRContext *context)
+      : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const {
+    auto addOp = op->cast<RealAddEwOp>();
+    const RealBinaryOpInfo info(op, addOp.x(), addOp.y(), addOp.clamp_min(),
+                                addOp.clamp_max());
+    if (!info.isValid()) {
+      return matchFailure();
+    }
+
+    // Try all of the permutations we support.
+    if (succeeded(tryRewriteFixedPOTAddEw(info, rewriter))) {
+      return matchSuccess();
+    }
+
+    return matchFailure();
+  }
+};
+
+} // end anonymous namespace
+
+void LowerUniformRealMathPass::runOnFunction() {
+  auto &fn = getFunction();
+  OwningRewritePatternList patterns;
+  auto *context = &getContext();
+  patterns.push_back(llvm::make_unique<UniformRealAddEwPattern>(context));
+  applyPatternsGreedily(fn, std::move(patterns));
+}
+
+FunctionPassBase *createLowerUniformRealMathPass() {
+  return new LowerUniformRealMathPass();
+}
+
+static PassRegistration<LowerUniformRealMathPass>
+    pass("quant-lower-uniform-real-math",
+         "Lowers uniform-quantized real math ops to integer arithmetic.");
diff --git a/mlir/lib/Quantization/Utils/QuantizeUtils.cpp b/mlir/lib/Quantization/Utils/QuantizeUtils.cpp
new file mode 100644 (file)
index 0000000..159d6eb
--- /dev/null
@@ -0,0 +1,186 @@
+//===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantization/QuantizeUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Quantization/UniformSupport.h"
+
+namespace mlir {
+namespace quant {
+/// Converts a possible primitive, real expressed value attribute to a
+/// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
+/// quantizedElementType is the QuantizedType that describes the expressed
+/// origValue.
+/// Returns a converter Attribute or nullptr if conversion is not possible.
+static Attribute convertPrimitiveValueAttr(
+    Attribute origRealValue, QuantizedType quantizedElementType,
+    const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
+  if (origRealValue.isa<FloatAttr>()) {
+    FloatAttr floatAttr = origRealValue.cast<FloatAttr>();
+    outConvertedType = quantizedElementType.getStorageType();
+    return IntegerAttr::get(quantizedElementType.getStorageType(),
+                            converter.quantizeFloatToInt(floatAttr.getValue()));
+  }
+
+  return nullptr;
+}
+
+/// Converts a real expressed DenseFPElementsAttr to a corresponding
+/// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized
+/// storage values assuming the given quantizedElementType and converter.
+static DenseElementsAttr
+convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
+                           QuantizedType quantizedElementType,
+                           const UniformQuantizedValueConverter &converter) {
+  // Read real expressed values.
+  SmallVector<APFloat, 8> realValues;
+  realValues.reserve(realFPElementsAttr.getType().getNumElements());
+  realFPElementsAttr.getValues(realValues);
+
+  // Convert to corresponding quantized value attributes.
+  SmallVector<APInt, 8> quantValues(realValues.size());
+  for (size_t i = 0, e = realValues.size(); i < e; ++i) {
+    quantValues[i] = converter.quantizeFloatToInt(realValues[i]);
+  }
+
+  // Cast from an expressed-type-based type to storage-type-based type,
+  // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+  VectorOrTensorType newDenseType =
+      quantizedElementType
+          .castExpressedToStorageType(realFPElementsAttr.getType())
+          .dyn_cast_or_null<VectorOrTensorType>();
+  if (!newDenseType) {
+    return nullptr;
+  }
+  return DenseIntElementsAttr::get(newDenseType, quantValues);
+}
+
+/// Converts a real expressed SplatElementsAttr to a corresponding
+/// SplatElementsAttr containing quantized storage values assuming the given
+/// quantizedElementType and converter.
+static SplatElementsAttr
+convertSplatElementsAttr(SplatElementsAttr realSplatAttr,
+                         QuantizedType quantizedElementType,
+                         const UniformQuantizedValueConverter &converter) {
+  // Since the splat just references a single primitive value, use the
+  // function for converting primitives.
+  // NOTE: When implementing per-channel, we will need to promote the
+  // splat to a dense and handle channels individually.
+  Type unusedPrimitiveType;
+  auto elementAttr =
+      convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType,
+                                converter, unusedPrimitiveType);
+  if (!elementAttr) {
+    return nullptr;
+  }
+
+  // Cast from an expressed-type-based type to storage-type-based type,
+  // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+  VectorOrTensorType newSplatType =
+      quantizedElementType.castExpressedToStorageType(realSplatAttr.getType())
+          .dyn_cast_or_null<VectorOrTensorType>();
+  if (!newSplatType) {
+    return nullptr;
+  }
+  return SplatElementsAttr::get(newSplatType, elementAttr);
+}
+
+/// Converts a real expressed SplatElementsAttr to a corresponding
+/// SplatElementsAttr containing quantized storage values assuming the given
+/// quantizedElementType and converter.
+static SparseElementsAttr
+convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
+                          QuantizedType quantizedElementType,
+                          const UniformQuantizedValueConverter &converter) {
+  DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
+  if (!realDenseAttr.isa<DenseFPElementsAttr>()) {
+    return nullptr;
+  }
+  DenseElementsAttr quantDenseAttr =
+      convertDenseFPElementsAttr(realDenseAttr.cast<DenseFPElementsAttr>(),
+                                 quantizedElementType, converter);
+  if (!quantDenseAttr) {
+    return nullptr;
+  }
+
+  // Cast from an expressed-type-based type to storage-type-based type,
+  // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+  VectorOrTensorType newSparseType =
+      quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
+          .dyn_cast_or_null<VectorOrTensorType>();
+  if (!newSparseType) {
+    return nullptr;
+  }
+  return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
+                                 quantDenseAttr);
+}
+
+/// Converts a real expressed Attribute to a corresponding Attribute containing
+/// quantized storage values assuming the given uniform quantizedElementType and
+/// converter.
+Attribute quantizeAttrUniform(Attribute realValue,
+                              UniformQuantizedType quantizedElementType,
+                              const UniformQuantizedValueConverter &converter,
+                              Type &outConvertedType) {
+  // Fork to handle different variants of constants supported.
+  if (realValue.isa<SplatElementsAttr>()) {
+    // Splatted tensor or vector constant.
+    auto converted = convertSplatElementsAttr(
+        realValue.cast<SplatElementsAttr>(), quantizedElementType, converter);
+    outConvertedType = converted.getType();
+    return converted;
+  } else if (realValue.isa<DenseFPElementsAttr>()) {
+    // Dense tensor or vector constant.
+    auto converted = convertDenseFPElementsAttr(
+        realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
+    outConvertedType = converted.getType();
+    return converted;
+  } else if (realValue.isa<SparseElementsAttr>()) {
+    // Sparse tensor or vector constant.
+    auto converted = convertSparseElementsAttr(
+        realValue.cast<SparseElementsAttr>(), quantizedElementType, converter);
+    outConvertedType = converted.getType();
+    return converted;
+  } else {
+    // Nothing else matched: try to convert a primitive.
+    return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
+                                     outConvertedType);
+  }
+}
+
+/// Convert an attribute from a type based on
+/// quantizedElementType.getExpressedType() to one based on
+/// quantizedElementType.getStorageType().
+/// Returns nullptr if the conversion is not supported.
+/// On success, stores the converted type in outConvertedType.
+Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType,
+                       Type &outConvertedType) {
+  // Hard-coded to just support UniformQuantizedType. This will need to
+  // be generalized when there is more than one.
+  auto uniformQuantizedType =
+      quantizedElementType.dyn_cast<UniformQuantizedType>();
+  if (!uniformQuantizedType) {
+    return nullptr;
+  }
+  UniformQuantizedValueConverter converter(uniformQuantizedType);
+  return quantizeAttrUniform(realValue, uniformQuantizedType, converter,
+                             outConvertedType);
+}
+
+} // namespace quant
+} // namespace mlir
diff --git a/mlir/test/Quantization/convert-const.mlir b/mlir/test/Quantization/convert-const.mlir
new file mode 100644 (file)
index 0000000..d0ac5d7
--- /dev/null
@@ -0,0 +1,140 @@
+// RUN: mlir-opt %s -split-input-file -quant-convert-const | FileCheck %s --dump-input=fail
+
+// Magic numbers:
+//   7.8125e-03 = 1/128 = 2/256 : real range = [-1.0, 0.9921875] (for 8bit, zeroPoint=128)
+//   1.250000e-01 = 1/8 = 2/16  : real range = [-1.0, 0.875] (for 4bit, zeroPoint=8)
+
+// -----
+// Verifies u8 affine quantization on a splat tensor.
+// Note that MLIR prints int attributes as signed, so the constant, when
+// quantized, is the signed printed version of an unsigned quantity
+// (-64 signed == 192 unsigned).
+// CHECK-LABEL: constant_splat_tensor_u8_affine
+func @constant_splat_tensor_u8_affine() -> tensor<4xf32> {
+  // CHECK: %cst = constant splat<tensor<4xi8>, -64> : tensor<4xi8>
+  // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
+  %cst = constant splat<tensor<4xf32>, 0.5> : tensor<4xf32>
+  %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
+  %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[u8:f32]{7.812500e-03:128}">>) -> (tensor<4xf32>)
+  return %2 : tensor<4xf32>
+}
+
+// -----
+// Verifies i8 affine quantization on a splat tensor.
+// CHECK-LABEL: constant_splat_tensor_i8_affine
+func @constant_splat_tensor_i8_affine() -> tensor<4xf32> {
+  // CHECK: %cst = constant splat<tensor<4xi8>, 63> : tensor<4xi8>
+  // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>
+  %cst = constant splat<tensor<4xf32>, 0.5> : tensor<4xf32>
+  %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>
+  %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03:-1}">>) -> (tensor<4xf32>)
+  return %2 : tensor<4xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a splat tensor.
+// CHECK-LABEL: const_splat_tensor_i8_fixedpoint
+func @const_splat_tensor_i8_fixedpoint() -> tensor<4xf32> {
+  // CHECK: %cst = constant splat<tensor<4xi8>, 64> : tensor<4xi8>
+  // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>
+  %cst = constant splat<tensor<4xf32>, 0.5> : tensor<4xf32>
+  %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>
+  %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>)
+  return %2 : tensor<4xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a splat tensor resulting in a negative storage value.
+// CHECK-LABEL: const_splat_tensor_i8_fixedpoint_neg
+func @const_splat_tensor_i8_fixedpoint_neg() -> tensor<4xf32> {
+  // CHECK: %cst = constant splat<tensor<4xi8>, -64> : tensor<4xi8>
+  %cst = constant splat<tensor<4xf32>, -0.5> : tensor<4xf32>
+  %1 = "quant.qbarrier"(%cst) : (tensor<4xf32>) -> tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>
+  %2 = "quant.dbarrier"(%1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<4xf32>)
+  return %2 : tensor<4xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a dense tensor, sweeping values.
+// CHECK-LABEL: const_dense_tensor_i8_fixedpoint
+func @const_dense_tensor_i8_fixedpoint() -> tensor<7xf32> {
+  // CHECK: %cst = constant dense<tensor<7xi8>, [-128, -128, -64, 0, 64, 127, 127]> : tensor<7xi8>
+  %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
+  %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">>
+  %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7xf32>)
+  return %2 : tensor<7xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a sparse tensor, sweeping values.
+// CHECK-LABEL: const_sparse_tensor_i8_fixedpoint
+func @const_sparse_tensor_i8_fixedpoint() -> tensor<7x2xf32> {
+  // NOTE: Ugly regex match pattern for opening "[[" of indices tensor.
+  // CHECK: %cst = constant sparse<tensor<7x2xi8>, {{\[}}[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], [-128, -128, -64, 0, 64, 127, 127]> : tensor<7x2xi8>
+  %cst = constant sparse<tensor<7x2xf32>,
+      [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]],
+      [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7x2xf32>
+  %1 = "quant.qbarrier"(%cst) : (tensor<7x2xf32>) -> tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">>
+  %2 = "quant.dbarrier"(%1) : (tensor<7x2x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> (tensor<7x2xf32>)
+  return %2 : tensor<7x2xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a primitive const.
+// CHECK-LABEL: const_primitive_float_i8_fixedpoint
+func @const_primitive_float_i8_fixedpoint() -> f32 {
+  // CHECK: %c64_i8 = constant 64 : i8
+  // CHECK-NEXT: %0 = "quant.scast"(%c64_i8) : (i8) -> !quant<"uniform[i8:f32]{7.812500e-03}">
+  %cst = constant 0.5 : f32
+  %1 = "quant.qbarrier"(%cst) : (f32) -> !quant<"uniform[i8:f32]{7.812500e-03}">
+  %2 = "quant.dbarrier"(%1) : (!quant<"uniform[i8:f32]{7.812500e-03}">) -> (f32)
+  return %2 : f32
+}
+
+// -----
+// Verifies u4 affine quantization on a dense tensor, sweeping values.
+// CHECK-LABEL: const_dense_tensor_u4_affine
+func @const_dense_tensor_u4_affine() -> tensor<7xf32> {
+  // NOTE: Unsigned quantities printed by MLIR as signed.
+  // CHECK: %cst = constant dense<tensor<7xi4>, [0, 0, 4, -8, -4, -1, -1]> : tensor<7xi4>
+  %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
+  %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">>
+  %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[u4:f32]{1.250000e-01:8}">>) -> (tensor<7xf32>)
+  return %2 : tensor<7xf32>
+}
+
+// -----
+// Verifies i4 affine quantization on a dense tensor, sweeping values.
+// CHECK-LABEL: const_dense_tensor_i4_affine
+func @const_dense_tensor_i4_affine() -> tensor<7xf32> {
+  // NOTE: Unsigned quantities printed by MLIR as signed.
+  // CHECK: %cst = constant dense<tensor<7xi4>, [-8, -8, -5, -1, 3, 7, 7]> : tensor<7xi4>
+  %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
+  %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">>
+  %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01:-1}">>) -> (tensor<7xf32>)
+  return %2 : tensor<7xf32>
+}
+
+// -----
+// Verifies i4 fixed point quantization on a dense tensor, sweeping values.
+// CHECK-LABEL: const_dense_tensor_i4_fixedpoint
+func @const_dense_tensor_i4_fixedpoint() -> tensor<7xf32> {
+  // CHECK: %cst = constant dense<tensor<7xi4>, [-8, -8, -4, 0, 4, 7, 7]> : tensor<7xi4>
+  %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
+  %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">>
+  %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i4:f32]{1.250000e-01}">>) -> (tensor<7xf32>)
+  return %2 : tensor<7xf32>
+}
+
+// -----
+// Verifies i8 fixedpoint quantization on a dense tensor, sweeping values, and
+// custom storage range. (the -128 should be clamped to -100, and the 127 should
+// be clamped to 100).
+// CHECK-LABEL: const_custom_storage_range_i8_fixedpoint
+func @const_custom_storage_range_i8_fixedpoint() -> tensor<7xf32> {
+  // CHECK: %cst = constant dense<tensor<7xi8>, [-100, -100, -64, 0, 64, 100, 100]> : tensor<7xi8>
+  %cst = constant dense<tensor<7xf32>, [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32>
+  %1 = "quant.qbarrier"(%cst) : (tensor<7xf32>) -> tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">>
+  %2 = "quant.dbarrier"(%1) : (tensor<7x!quant<"uniform[i8(-100:100):f32]{7.812500e-03}">>) -> (tensor<7xf32>)
+  return %2 : tensor<7xf32>
+}
diff --git a/mlir/test/Quantization/lower-uniform-real-math-addew.mlir b/mlir/test/Quantization/lower-uniform-real-math-addew.mlir
new file mode 100644 (file)
index 0000000..96c0886
--- /dev/null
@@ -0,0 +1,106 @@
+// RUN: mlir-opt %s -split-input-file -quant-lower-uniform-real-math | FileCheck %s --dump-input=fail
+
+// -----
+// Verify lowering when operands and result have the same fixedpoint pot scale.
+// CHECK-LABEL: real_addew_fixedpoint_same_scale
+//      CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+// CHECK-NEXT: %2 = "quant.saturating_addi"(%0, %1) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %3 = "quant.scast"(%2) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+// CHECK-NEXT: return %3 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_fixedpoint_same_scale(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
+
+// -----
+// Verify lowering when the rhs is a shifted pot scale compared to lhs and result.
+// CHECK-LABEL: real_addew_fixedpoint_rhs_shift
+//      CHECK: %0 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+// CHECK-NEXT: %1 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
+// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_fixedpoint_rhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
+
+// -----
+// Verify lowering when the lhs is a shifted pot scale compared to lhs and result.
+// CHECK-LABEL: real_addew_fixedpoint_lhs_shift
+//      CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
+// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 127 : i32, clamp_min: -128 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_fixedpoint_lhs_shift(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
+
+// -----
+// The RHS quant parameters proscribe a range of [-8..8) so an explicit clamp
+// of [-4..4] should result in an integral clamp range of [-64..64].
+// CHECK-LABEL: real_addew_fixedpoint_clamp
+//      CHECK: %0 = "quant.scast"(%arg1) : (tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>) -> tensor<4xi8>
+// CHECK-NEXT: %1 = "quant.scast"(%arg0) : (tensor<4x!quant<"uniform[i8:f32]{7.812500e-03}">>) -> tensor<4xi8>
+// CHECK-NEXT: %2 = "quant.rounding_divide_by_poti"(%1) {exponent: 3 : i32} : (tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %3 = "quant.saturating_addi"(%0, %2) {clamp_max: 64 : i32, clamp_min: -64 : i32} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi8>
+// CHECK-NEXT: %4 = "quant.scast"(%3) : (tensor<4xi8>) -> tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+// CHECK-NEXT: return %4 : tensor<4x!quant<"uniform[i8:f32]{6.250000e-02}">>
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{7.8125e-03}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_fixedpoint_clamp(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  %0 = "quant.real_add_ew"(%arg0, %arg1) { clamp_min:-4.0, clamp_max:4.0 }
+      : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: real_addew_unquantized_lhs
+// Verifies that leaves as-is for unquantized lhs.
+!type_lhs = type tensor<4xf32>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_unquantized_lhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1)
+  %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: real_addew_unquantized_rhs
+// Verifies that leaves as-is for unquantized rhs.
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4xf32>
+!type_result = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+func @real_addew_unquantized_rhs(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1)
+  %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
+
+// -----
+// CHECK-LABEL: real_addew_unquantized_result
+// Verifies that leaves as-is for unquantized result.
+!type_lhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_rhs = type tensor<4x!quant<"uniform[i8:f32]{6.25e-2}">>
+!type_result = type tensor<4xf32>
+func @real_addew_unquantized_result(%arg0 : !type_lhs, %arg1: !type_rhs) -> !type_result {
+  // CHECK: %0 = "quant.real_add_ew"(%arg0, %arg1)
+  %0 = "quant.real_add_ew"(%arg0, %arg1) : (!type_lhs, !type_rhs) -> (!type_result)
+  return %0 : !type_result
+}
diff --git a/mlir/test/Quantization/parse-uniform-invalid.mlir b/mlir/test/Quantization/parse-uniform-invalid.mlir
new file mode 100644 (file)
index 0000000..9168566
--- /dev/null
@@ -0,0 +1,137 @@
+// RUN: mlir-opt %s -split-input-file -verify
+
+// -----
+// Unknown type.
+// expected-error@+1 {{unknown quantized type foobar}}
+!qalias = type !quant<"foobar">
+
+// -----
+// Unrecognized token: illegal token
+// expected-error@+1 {{unrecognized token: %}}
+!qalias = type !quant<"%%">
+
+// -----
+// Unrecognized token: trailing
+// expected-error@+1 {{unrecognized token: 23}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{0.99872:127} 23">
+
+// -----
+// Unrecognized token: type open
+// expected-error@+1 {{unrecognized token: (}}
+!qalias = type !quant<"uniform(i8(-4:3):f32){0.99872:127}">
+
+// -----
+// Unrecognized token: missing storage type maximum
+// expected-error@+1 {{expected storage type maximum}}
+!qalias = type !quant<"uniform[i8(16:f32]{0.99872:127}">
+
+// -----
+// Unrecognized token: missing closing paren
+// expected-error@+1 {{unrecognized token: :}}
+!qalias = type !quant<"uniform[i8(-4:3:f32]{0.99872:127}">
+
+// -----
+// Unrecognized token: missing type colon
+// expected-error@+1 {{unrecognized token: f}}
+!qalias = type !quant<"uniform[i8(-4:3)f32]{0.99872:127}">
+
+// -----
+// Unrecognized token: missing closing bracket
+// expected-error@+1 {{unrecognized token: {}}
+!qalias = type !quant<"uniform[i8(-4:3):f32{0.99872:127}">
+
+// -----
+// Unrecognized token: wrong opening brace
+// expected-error@+1 {{unrecognized token: (}}
+!qalias = type !quant<"uniform[i8(-4:3):f32](0.99872:127}">
+
+// -----
+// Unrecognized storage type: illegal prefix
+// expected-error@+1 {{illegal storage type prefix: int}}
+!qalias = type !quant<"uniform[int8(-4:3):f32]{0.99872:127}">
+
+// -----
+// Unrecognized storage type: no width
+// expected-error@+1 {{expected storage type width}}
+!qalias = type !quant<"uniform[i(-4:3):f32]{0.99872:127}">
+
+// -----
+// Unrecognized storage type: storage size > 32
+// expected-error@+1 {{illegal storage type size: 33}}
+!qalias = type !quant<"uniform[i33:f32]{0.99872:127}">
+
+// -----
+// Unrecognized storage type: storage size < 0
+// expected-error@+1 {{illegal storage type size: -1}}
+!qalias = type !quant<"uniform[i-1(-4:3):f32]{0.99872:127}">
+
+// -----
+// Unrecognized storage type: storage size == 0
+// expected-error@+1 {{illegal storage type size: 0}}
+!qalias = type !quant<"uniform[i0(-4:3):f32]{0.99872:127}">
+
+// -----
+// Illegal storage min/max: max - min < 0
+// expected-error@+1 {{illegal storage min and storage max: (2:1)}}
+!qalias = type !quant<"uniform[i8(2:1):f32]{0.99872:127}">
+
+// -----
+// Illegal storage min/max: max - min == 0
+// expected-error@+1 {{illegal storage min and storage max: (1:1)}}
+!qalias = type !quant<"uniform[i8(1:1):f32]{0.99872:127}">
+
+// -----
+// Illegal storage min/max: max > defaultMax
+// expected-error@+1 {{illegal storage type maximum: 9}}
+!qalias = type !quant<"uniform[i4(-1:9):f32]{0.99872:127}">
+
+// -----
+// Illegal storage min/max: min < defaultMin
+// expected-error@+1 {{illegal storage type minimum: -9}}
+!qalias = type !quant<"uniform[i4(-9:1):f32]{0.99872:127}">
+
+// -----
+// Illegal uniform params: invalid scale
+// expected-error@+1 {{expected valid uniform scale. got: abc}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{abc:127}">
+
+// -----
+// Illegal uniform params: invalid zero point separator
+// expected-error@+1 {{unrecognized token: abc}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1abc}">
+
+// -----
+// Illegal uniform params: missing zero point
+// expected-error@+1 {{expected integer uniform zero point. got: }}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1:}">
+
+// -----
+// Illegal uniform params: invalid zero point
+// expected-error@+1 {{expected integer uniform zero point. got: abc}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1:abc}">
+
+// -----
+// Illegal uniform params: missing closing brace
+// expected-error@+1 {{unrecognized token: )}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{0.1:0)">
+
+// -----
+// Illegal expressed type: f33
+// expected-error@+1 {{unrecognized expressed type: f33}}
+!qalias = type !quant<"uniform[i8(-4:3):f33]{0.99872:127}">
+
+// -----
+// Illegal scale: negative
+// expected-error@+1 {{illegal scale: -1.000000}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{-1.0:127}">
+
+// -----
+// Illegal uniform params: missing quantized dimension
+// expected-error@+1 {{expected quantized dimension}}
+!qalias = type !quant<"uniform[i8(-4:3):f32:]{2.000000e+02:-19.987200e-01:1}">
+
+// -----
+// Illegal uniform params: unspecified quantized dimension, when multiple scales
+// provided.
+// expected-error@+1 {{multiple scales/zeroPoints provided, but quantizedDimension wasn't specified}}
+!qalias = type !quant<"uniform[i8(-4:3):f32]{2.000000e+02,-19.987200e-01:1}">
diff --git a/mlir/test/Quantization/parse-uniform.mlir b/mlir/test/Quantization/parse-uniform.mlir
new file mode 100644 (file)
index 0000000..f29a93d
--- /dev/null
@@ -0,0 +1,147 @@
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
+
+// -----
+// All per-layer params specified:
+//   [signed] storageType, storageTypeMin, storageTypeMax, expressedType, scale, zeroPoint
+// CHECK: !quant<"uniform[i8(-8:7):f32]{9.987200e-01:127}">
+!qalias = type !quant<"uniform[i8(-8:7):f32]{0.99872:127}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Trailing whitespace.
+// CHECK: !quant<"uniform[i8(-8:7):f32]{9.987200e-01:127}">
+!qalias = type !quant<"uniform[i8(-8:7):f32]{0.99872:127}  ">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Required per-layer params specified:
+//   [unsigned] storageType, expressedType, scale
+// CHECK: !quant<"uniform[u8:f32]{9.987200e-01}">
+!qalias = type !quant<"uniform[u8:f32]{0.99872}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Exponential scale (-)
+// CHECK: !quant<"uniform[u8:f32]{2.000000e-02}">
+!qalias = type !quant<"uniform[u8:f32]{2.0e-2}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Exponential scale (+)
+// CHECK: !quant<"uniform[u8:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[u8:f32]{2.0e+2}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Storage type: i16
+// CHECK: !quant<"uniform[i16:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[i16:f32]{2.0e+2}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Storage type: u16
+// CHECK: !quant<"uniform[u16:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[u16:f32]{2.0e+2}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Storage type: i32
+// CHECK: !quant<"uniform[i32:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[i32:f32]{2.0e+2}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Storage type: u32
+// CHECK: !quant<"uniform[u32:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[u32:f32]{2.0e+2}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Expressed type: f32
+// CHECK: !quant<"uniform[u8:f32]{2.000000e+02}">
+!qalias = type !quant<"uniform[u8:f32]{2.0e+2}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Expressed type: f16
+// CHECK: !quant<"uniform[u8:f16]{2.000000e+02}">
+!qalias = type !quant<"uniform[u8:f16]{2.0e+2}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Expressed type: f64
+// CHECK: !quant<"uniform[u8:f64]{2.000000e+02}">
+!qalias = type !quant<"uniform[u8:f64]{2.0e+2}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Expressed type: bf16
+// CHECK: !quant<"uniform[u8:bf16]{2.000000e+02}">
+!qalias = type !quant<"uniform[u8:bf16]{2.0e+2}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Per-axis scales and zero points (affine)
+// CHECK: !quant<"uniform[u8:f32:1]{2.000000e+02:-120,9.987200e-01:127}">
+!qalias = type !quant<"uniform[u8:f32:1]{2.0e+2:-120,0.99872:127}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Per-axis scales and no zero points (fixedpoint)
+// CHECK: !quant<"uniform[i8:f32:1]{2.000000e+02,9.987200e-01}">
+!qalias = type !quant<"uniform[i8:f32:1]{2.0e+2,0.99872}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
+// -----
+// Per-axis scales and zero points (mixed affine and fixedpoint)
+// CHECK: !quant<"uniform[i8:f32:1]{2.000000e+02,9.987200e-01:120}">
+!qalias = type !quant<"uniform[i8:f32:1]{2.0e+2,0.99872:120}">
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
diff --git a/mlir/test/Quantization/tf-lower-fakequant-invalid.mlir b/mlir/test/Quantization/tf-lower-fakequant-invalid.mlir
new file mode 100644 (file)
index 0000000..193522a
--- /dev/null
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -split-input-file -verify -quant-lower-tf
+
+// -----
+// TODO(laurenzo): move this test to the TensorFlow/tf-ops-invalid.mlir
+// Verify that a mismatched range errors.
+func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // expected-error@+1 {{op range failed to straddle zero: [1.100000,1.500000]}}
+  %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+    min: 1.1, max: 1.5, num_bits: 8
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verify that a valid range errors.
+func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // expected-error@+1 {{op range is invalid: [1.100000,1.000000}}
+  %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+    min: 1.1, max: 1.0, num_bits: 8
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// TODO(laurenzo): move this test to the TensorFlow/tf-ops-invalid.mlir
+// Unsupported quantizable type (i1 is currently not a supported element type).
+func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> {
+^bb0(%arg0: tensor<8x4x3xi1>):
+  // expected-error@+1 {{op operand #0 must be tensor of 32-bit float values}}
+  %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+    min: 1.1, max: 1.0, num_bits: 8
+  } : (tensor<8x4x3xi1>) -> tensor<8x4x3xi1>
+  return %0 : tensor<8x4x3xi1>
+}
diff --git a/mlir/test/Quantization/tf-lower-fakequant.mlir b/mlir/test/Quantization/tf-lower-fakequant.mlir
new file mode 100644 (file)
index 0000000..a6c572e
--- /dev/null
@@ -0,0 +1,77 @@
+// RUN: mlir-opt %s -split-input-file -quant-lower-tf | FileCheck %s --dump-input=fail
+
+// -----
+// Verifies a quint8 asymmetric 0..1 range.
+// CHECK-LABEL: fakeQuantArgs_Quint8_0_1
+func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>
+  // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+  %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+    min: 0.0, max: 1.0, num_bits: 8
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verifies a quint8 asymmetric 0..1 range (with narrow_range = true).
+// CHECK_LABEL: fakeQuantArgs_Quint8_NarrowRange
+func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>
+  // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8(1:255):f32]{0.003937007874015748:1}">>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+  %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+    min: 0.0, max: 1.0, num_bits: 8, narrow_range: true
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verifies a quint8 symmetric range of -1..127/128.
+// CHECK_LABEL: fakeQuantArgs_Quint8_SymmetricRange
+func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>
+  // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[u8:f32]{7.812500e-03:128}">>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+  %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+    min: -1.0, max: 0.9921875, num_bits: 8, narrow_range: false
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verifies a commonly used -1..1 symmetric 16bit range with a zero point of
+// 0 and range -1.0 .. 32767/32768.
+// CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric
+func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+  // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">>
+  // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<8x4x3x!quant<"uniform[i16:f32]{3.05175781185626E-5}">>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+  %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+    min: -1.0, max: 0.999969482, num_bits: 16
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// Verify that lowering to barriers of unranked tensors functions.
+// CHECK-LABEL: fakeQuantArgs_UnrankedTensor
+func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
+^bb0(%arg0: tensor<f32>):
+  // CHECK: %0 = "quant.qbarrier"(%arg0) : (tensor<f32>)
+  // CHECK-SAME: -> tensor<!quant<"uniform[u8:f32]{0.0039215686274509803}">>
+  // CHECK-NEXT: %1 = "quant.dbarrier"(%0) : (tensor<!quant<"uniform[u8:f32]{0.0039215686274509803}">>)
+  // CHECK-SAME: -> tensor<f32>
+  %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) {
+    min: 0.0, max: 1.0, num_bits: 8
+  } : (tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
diff --git a/mlir/unittests/Quantization/QuantizationUtilsTest.cpp b/mlir/unittests/Quantization/QuantizationUtilsTest.cpp
new file mode 100644 (file)
index 0000000..9d30d28
--- /dev/null
@@ -0,0 +1,173 @@
+//===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Quantization/QuantizeUtils.h"
+#include "mlir/Quantization/UniformSupport.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+using namespace mlir::quant;
+
+namespace {
+
+// Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
+class TestUniformQuantizedValueConverter
+    : public UniformQuantizedValueConverter {
+public:
+  TestUniformQuantizedValueConverter(UniformQuantizedType type)
+      : UniformQuantizedValueConverter(type), qtype(type) {}
+  APInt quantizeFloatToInt(APFloat expressedValue) const {
+    return APInt(qtype.getStorageType().cast<IntegerType>().getWidth(), 5L);
+  }
+
+private:
+  UniformQuantizedType qtype;
+};
+
+Attribute getTestFloatAttr(double value, MLIRContext *ctx) {
+  return FloatAttr::get(FloatType::getF32(ctx), value);
+}
+
+template <typename ConcreteAttrClass, typename... Arg>
+ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
+                                      Arg... value) {
+  auto eleType = FloatType::getF32(ctx);
+  VectorOrTensorType tensorType;
+  if (shape.size() == 1 && shape[0] == -1) {
+    tensorType = UnrankedTensorType::get(eleType);
+  } else {
+    tensorType = RankedTensorType::get(shape, eleType);
+  }
+  return ConcreteAttrClass::get(tensorType, value...);
+}
+
+ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx,
+                                       ArrayRef<int64_t> shape) {
+  auto eleType = FloatType::getF32(ctx);
+  VectorOrTensorType tensorType;
+  if (shape.size() == 1 && shape[0] == -1) {
+    tensorType = UnrankedTensorType::get(eleType);
+  } else {
+    tensorType = RankedTensorType::get(shape, eleType);
+  }
+  auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(64, ctx));
+  auto indices =
+      DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)});
+  auto valuesType = RankedTensorType::get({1}, eleType);
+  auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)});
+  return SparseElementsAttr::get(tensorType, indices, values);
+}
+
+UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
+  return UniformQuantizedType::get(/*flags=*/false, storageType,
+                                   FloatType::getF32(ctx), /*scale=*/1.0,
+                                   /*zeroPoint=*/0, /*storageTypeMin=*/0,
+                                   /*storageTypeMax=*/255);
+}
+
+TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
+  MLIRContext ctx;
+  IntegerType convertedType = IntegerType::get(8, &ctx);
+  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
+  TestUniformQuantizedValueConverter converter(quantizedType);
+
+  auto realValue = getTestFloatAttr(1.0, &ctx);
+  Type typeResult;
+  auto valueResult =
+      quantizeAttrUniform(realValue, quantizedType, converter, typeResult);
+
+  EXPECT_EQ(valueResult.cast<IntegerAttr>().getInt(), 5);
+  EXPECT_EQ(
+      valueResult.cast<IntegerAttr>().getType().cast<IntegerType>().getWidth(),
+      convertedType.getWidth());
+}
+
+TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
+  MLIRContext ctx;
+  IntegerType convertedType = IntegerType::get(8, &ctx);
+  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
+  TestUniformQuantizedValueConverter converter(quantizedType);
+  auto realValue = getTestElementsAttr<DenseElementsAttr, ArrayRef<Attribute>>(
+      &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)});
+
+  Type returnedType;
+  auto returnedValue =
+      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
+
+  // Check Elements attribute shape and kind are not changed.
+  auto tensorType = returnedType.cast<TensorType>();
+  auto expectedTensorType = realValue.getType().cast<TensorType>();
+  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
+  EXPECT_EQ(tensorType.getElementType(), convertedType);
+  EXPECT_EQ(returnedValue.getKind(), Attribute::Kind::DenseIntElements);
+
+  // Check Elements attribute element value is expected.
+  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
+  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
+}
+
+TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
+  MLIRContext ctx;
+  IntegerType convertedType = IntegerType::get(8, &ctx);
+  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
+  TestUniformQuantizedValueConverter converter(quantizedType);
+  auto realValue = getTestElementsAttr<SplatElementsAttr, Attribute>(
+      &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));
+
+  Type returnedType;
+  auto returnedValue =
+      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
+
+  // Check Elements attribute shape and kind are not changed.
+  auto tensorType = returnedType.cast<TensorType>();
+  auto expectedTensorType = realValue.getType().cast<TensorType>();
+  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
+  EXPECT_EQ(tensorType.getElementType(), convertedType);
+  EXPECT_EQ(returnedValue.getKind(), Attribute::Kind::SplatElements);
+
+  // Check Elements attribute element value is expected.
+  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
+  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
+}
+
+TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
+  MLIRContext ctx;
+  IntegerType convertedType = IntegerType::get(8, &ctx);
+  auto quantizedType = getTestQuantizedType(convertedType, &ctx);
+  TestUniformQuantizedValueConverter converter(quantizedType);
+  auto realValue = getTestSparseElementsAttr(&ctx, {1, 2});
+
+  Type returnedType;
+  auto returnedValue =
+      quantizeAttrUniform(realValue, quantizedType, converter, returnedType);
+
+  // Check Elements attribute shape and kind are not changed.
+  auto tensorType = returnedType.cast<TensorType>();
+  auto expectedTensorType = realValue.getType().cast<TensorType>();
+  EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
+  EXPECT_EQ(tensorType.getElementType(), convertedType);
+  EXPECT_EQ(returnedValue.getKind(), Attribute::Kind::SparseElements);
+
+  // Check Elements attribute element value is expected.
+  auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
+  EXPECT_EQ(firstValue.cast<IntegerAttr>().getInt(), 5);
+}
+
+} // end namespace