Upstream the Quantizer tool (part 3).
authorStella Laurenzo <laurenzo@google.com>
Sat, 18 May 2019 00:43:50 +0000 (17:43 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 20 May 2019 20:46:43 +0000 (13:46 -0700)
    This upstreams the config and constraints for a reference quantization scheme based on the FxpMathOps dialect.

    There are probably two more CLs to get the rest: one with the passes/tests, and one with the tool main() itself.

--

PiperOrigin-RevId: 248817505

mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h [new file with mode: 0644]
mlir/include/mlir/Quantizer/Support/UniformConstraints.h [new file with mode: 0644]
mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp [new file with mode: 0644]
mlir/lib/Quantizer/Support/UniformConstraints.cpp [new file with mode: 0644]

diff --git a/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h b/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h
new file mode 100644 (file)
index 0000000..d0efe25
--- /dev/null
@@ -0,0 +1,49 @@
+//===- FxpMathConfig.h - Reference fixed point config -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a TargetConfiguration for reference fixed-point math
+// quantization scheme based on the FxpMathOps (plus a small category of
+// extension ops that can be added from other dialects).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H
+#define MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H
+
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/Metadata.h"
+
+namespace mlir {
+namespace quantizer {
+
+/// Target configuration for a reference affine/fixed-point quantization
+/// scheme defined in terms of the FxpMathOps dialect. This can be extended
+/// with select ops from other dialects by way of the following public
+/// methods:
+///   - addValueIdentityOp
+class FxpMathTargetConfig : public TargetConfiguration {
+ public:
+  /// Creates an FxpMathTargetConfig instance which can be further customized.
+  static std::unique_ptr<FxpMathTargetConfig> create(SolverContext &context);
+ protected:
+  FxpMathTargetConfig(SolverContext &context) : TargetConfiguration(context) {}
+};
+
+}  // namespace quantizer
+}  // namespace mlir
+
+#endif  // MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H
diff --git a/mlir/include/mlir/Quantizer/Support/UniformConstraints.h b/mlir/include/mlir/Quantizer/Support/UniformConstraints.h
new file mode 100644 (file)
index 0000000..beae3ed
--- /dev/null
@@ -0,0 +1,69 @@
+//===- UniformConstraints.h - Constraints for uniform quant -----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a builder that lets you attach constraints necessary to
+// perform a variety of uniform quantization conversions to CAG anchors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
+#define MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
+
+#include "mlir/Quantizer/Support/Statistics.h"
+
+namespace mlir {
+namespace quantizer {
+
+class CAGAnchorNode;
+class CAGSlice;
+
+/// Factory methods for adding CAG constraints of various kinds suitable
+/// for solving for uniform quantization.
+class UniformConstraintsBuilder {
+ public:
+  UniformConstraintsBuilder(CAGSlice &slice) : slice(slice) {}
+
+  /// Adds a coupling constraint between two nodes, effectively treating
+  /// them as a hard identity relationship.
+  void coupleAnchors(CAGAnchorNode *a, CAGAnchorNode *b);
+
+  /// Applies statistics constraints to the given anchor, such that the solver
+  /// ensures that the statistics are representable by chosen types.
+  void applyStats(CAGAnchorNode *a, TensorAxisStatistics stats);
+
+  /// Applies a constraint to a node which allows solutions that do not extend
+  /// beyond given min/max bounds (this is a hint that the tensor will not
+  /// take values outside of these bounds). If either minValue or maxValue is
+  /// NAN, then that side is considered open.
+  void clamp(CAGAnchorNode *a, APFloat minValue, APFloat maxValue);
+
+  /// Propagates an explicit scale from an anchor that may have a uniform
+  /// |selectedType| to the |explicitScaleZeroPoint| field of the to node.
+  /// This is typically used with a to node that has a candidate quantized
+  /// type of |UniformExplicitFixedPointScale|, indicating that it can be
+  /// an arbitrary (signed) type that is expected to share the same scale
+  /// as the originating node.
+  void propagateExplicitScale(CAGAnchorNode *from, CAGAnchorNode *to);
+
+ private:
+  CAGSlice &slice;
+};
+
+}  // namespace quantizer
+}  // namespace mlir
+
+#endif  // MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
diff --git a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp
new file mode 100644 (file)
index 0000000..8623df9
--- /dev/null
@@ -0,0 +1,289 @@
+//===- FxpMathConfig.cpp - Reference fixed point config -------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a TargetConfiguration for reference fixed-point math
+// quantization scheme based on the FxpMathOps (plus a small category of
+// extension ops that can be added from other dialects).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
+
+#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/Metadata.h"
+#include "mlir/Quantizer/Support/Statistics.h"
+#include "mlir/Quantizer/Support/UniformConstraints.h"
+#include "mlir/StandardOps/Ops.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::fxpmath;
+using namespace mlir::quant;
+using namespace std::placeholders;
+
+namespace {
+
+struct FxpMathTargetConfigImpl : public FxpMathTargetConfig {
+  FxpMathTargetConfigImpl(SolverContext &context)
+      : FxpMathTargetConfig(context) {
+    Builder b(&context.getMlirContext());
+    IntegerType i8Type = b.getIntegerType(8);
+    IntegerType i16Type = b.getIntegerType(16);
+    IntegerType i32Type = b.getIntegerType(32);
+
+    q8 = addCandidateType(
+        AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr,
+                              std::numeric_limits<int8_t>::min(),
+                              std::numeric_limits<int8_t>::max()),
+        CandidateQuantizedType::Scheme::UniformPerLayer);
+    q16 = addCandidateType(
+        AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr,
+                              std::numeric_limits<int16_t>::min(),
+                              std::numeric_limits<int16_t>::max()),
+        CandidateQuantizedType::Scheme::UniformPerLayer);
+    q32ExplicitFixedPoint = addCandidateType(
+        AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr,
+                              std::numeric_limits<int32_t>::min(),
+                              std::numeric_limits<int32_t>::max()),
+        CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale);
+
+    // Op handlers.
+    addOpHandler<ConstantOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2));
+    addOpHandler<ReturnOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2));
+    addOpHandler<quant::StatisticsOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2));
+
+    // FxpMathOps.
+    addOpHandler<RealAddEwOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2));
+    addOpHandler<RealMulEwOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2));
+    addOpHandler<RealMatMulOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2));
+    addOpHandler<RealMatMulBiasOp>(
+        std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2));
+
+    // Require stats ops.
+    addRequireStatsOp<RealAddEwOp>();
+    addRequireStatsOp<RealSubEwOp>();
+    addRequireStatsOp<RealDivEwOp>();
+    addRequireStatsOp<RealMulEwOp>();
+    addRequireStatsOp<RealMatMulOp>();
+    addRequireStatsOp<RealMatMulBiasOp>();
+  }
+
+  bool isHandledType(Type t) const final {
+    if (t.isa<FloatType>())
+      return true;
+    auto shapedType = t.dyn_cast<ShapedType>();
+    return (shapedType && shapedType.getElementType().isa<FloatType>() &&
+            (t.isa<VectorType>() || t.isa<TensorType>()));
+  }
+
+  void finalizeAnchors(CAGSlice &cag) const override {
+    cag.enumerateImpliedConnections(
+        [&](CAGAnchorNode *from, CAGAnchorNode *to) {
+          UniformConstraintsBuilder(cag).coupleAnchors(from, to);
+        });
+  }
+
+  void addValueIdentityOpByName(StringRef opName) override {
+    addOpHandlerByName(
+        opName,
+        std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2));
+  }
+
+  void handleValueIdentity(Operation *op, CAGSlice &cag) const {
+    assert(op->getNumResults() == 1);
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto resultNode = cag.getResultAnchor(op, 0);
+    resultNode->setTypeTransformRule(
+        CAGAnchorNode::TypeTransformRule::DirectStorage);
+
+    for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) {
+      if (!isHandledType(op->getOperand(opIdx)->getType()))
+        continue;
+      auto operandNode = cag.getOperandAnchor(op, opIdx);
+      operandNode->setTypeTransformRule(
+          CAGAnchorNode::TypeTransformRule::DirectStorage);
+      UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode);
+    }
+  }
+
+  void handleConstant(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto resultNode = cag.getResultAnchor(op, 0);
+    resultNode->setTypeTransformRule(
+        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
+    Attribute valueAttr;
+    if (!matchPattern(op, m_Constant(&valueAttr))) {
+      return;
+    }
+
+    AttributeTensorStatistics stats(valueAttr);
+    TensorAxisStatistics layerStats;
+    if (!stats.get(layerStats)) {
+      op->emitOpError("could not compute statistics");
+      return;
+    }
+
+    UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats);
+  }
+
+  void handleTerminal(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getOperand(0)->getType()))
+      return;
+    auto operandNode = cag.getOperandAnchor(op, 0);
+    operandNode->setTypeTransformRule(
+        CAGAnchorNode::TypeTransformRule::ExpressedOnly);
+  }
+
+  void handleStats(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto argNode = cag.getOperandAnchor(op, 0);
+    auto resultNode = cag.getResultAnchor(op, 0);
+    UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode);
+
+    TensorAxisStatistics layerStats;
+    auto statsOp = cast<quant::StatisticsOp>(op);
+    auto layerStatsAttr = statsOp.layerStats();
+    layerStats.minValue =
+        layerStatsAttr.getValue({0}).cast<FloatAttr>().getValueAsDouble();
+    layerStats.maxValue =
+        layerStatsAttr.getValue({1}).cast<FloatAttr>().getValueAsDouble();
+    UniformConstraintsBuilder(cag).applyStats(resultNode,
+                                              std::move(layerStats));
+  }
+
+  void handleAdd(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto lhs = cag.getOperandAnchor(op, 0);
+    auto rhs = cag.getOperandAnchor(op, 1);
+    auto resultNode = cag.getResultAnchor(op, 0);
+    // Add supports 8/16 bit math.
+    llvm::SmallBitVector disableMask =
+        getCandidateTypeDisabledExceptMask({q8, q16});
+    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+    // NOTE: We couple the add such that the scale/zeroPoint match between
+    // both args and the result. This is overly constrained in that it is
+    // possible to write efficient add kernels with a bit more freedom (i.e.
+    // zeroPoints can vary, scales can differ by a power of two, etc).
+    // However, fully coupled yields the simples solutions on the fast path.
+    // Further efficiency can be had by constraining the zeroPoint to 0, but
+    // there isn't a constraint for this yet (and there are tradeoffs).
+    UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode);
+    UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode);
+    addRealMathOptionalConstraints(op, resultNode, cag);
+  }
+
+  void handleMul(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto lhs = cag.getOperandAnchor(op, 0);
+    auto rhs = cag.getOperandAnchor(op, 1);
+    auto resultNode = cag.getResultAnchor(op, 0);
+    // Mul supports 8/16 bit math.
+    llvm::SmallBitVector disableMask =
+        getCandidateTypeDisabledExceptMask({q8, q16});
+    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+    addRealMathOptionalConstraints(op, resultNode, cag);
+  }
+
+  void handleMatMul(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto lhs = cag.getOperandAnchor(op, 0);
+    auto rhs = cag.getOperandAnchor(op, 1);
+    auto resultNode = cag.getResultAnchor(op, 0);
+    // Mul supports 8/16 bit math.
+    llvm::SmallBitVector disableMask =
+        getCandidateTypeDisabledExceptMask({q8, q16});
+    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+    addRealMathOptionalConstraints(op, resultNode, cag);
+  }
+
+  void handleMatMulBias(Operation *op, CAGSlice &cag) const {
+    if (!isHandledType(op->getResult(0)->getType()))
+      return;
+
+    auto lhs = cag.getOperandAnchor(op, 0);
+    auto rhs = cag.getOperandAnchor(op, 1);
+    auto bias = cag.getOperandAnchor(op, 2);
+    bias->getUniformMetadata().disabledCandidateTypes =
+        getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint});
+
+    auto resultNode = cag.getResultAnchor(op, 0);
+    UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias);
+
+    // Mul supports 8/16 bit math.
+    llvm::SmallBitVector disableMask =
+        getCandidateTypeDisabledExceptMask({q8, q16});
+    lhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    rhs->getUniformMetadata().disabledCandidateTypes = disableMask;
+    resultNode->getUniformMetadata().disabledCandidateTypes = disableMask;
+    addRealMathOptionalConstraints(op, resultNode, cag);
+  }
+
+  void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor,
+                                      CAGSlice &cag) const {
+    // TODO: It would be nice if these all extended some base trait instead
+    // of requiring name lookup.
+    auto clampMinAttr = op->getAttrOfType<FloatAttr>("clamp_min");
+    auto clampMaxAttr = op->getAttrOfType<FloatAttr>("clamp_max");
+
+    if (clampMinAttr || clampMaxAttr) {
+      auto nan = APFloat::getQNaN(APFloat::IEEEdouble());
+      auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan;
+      auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan;
+      UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax);
+    }
+  }
+
+  unsigned q8;
+  unsigned q16;
+  unsigned q32ExplicitFixedPoint;
+};
+
+} // anonymous namespace
+
+std::unique_ptr<FxpMathTargetConfig>
+FxpMathTargetConfig::create(SolverContext &context) {
+  return llvm::make_unique<FxpMathTargetConfigImpl>(context);
+}
diff --git a/mlir/lib/Quantizer/Support/UniformConstraints.cpp b/mlir/lib/Quantizer/Support/UniformConstraints.cpp
new file mode 100644 (file)
index 0000000..ab1ced1
--- /dev/null
@@ -0,0 +1,269 @@
+//===- UniformConstraints.cpp - Constraints for uniform quant -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantizer/Support/UniformConstraints.h"
+
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/Metadata.h"
+#include "mlir/Quantizer/Support/Rules.h"
+#include "mlir/Quantizer/Support/TypeUtils.h"
+#include "mlir/Quantizer/Support/UniformSolvers.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace {
+
+struct ClusteredFacts {
+  ExpandingMinMaxFact requiredRange;
+  DiscreteScaleZeroPointFact explicitScaleZeroPoint;
+};
+
+} // end anonymous namespace
+
+static QuantizedType solveUniformType(SolverContext &solverContext,
+                                      const ClusteredFacts &clusteredFacts,
+                                      const CandidateQuantizedType &ct,
+                                      Type originalElementType, Location loc) {
+  switch (ct.scheme) {
+  default:
+    solverContext.getMlirContext().emitError(
+        loc, "unsupported scheme for uniform type conversion");
+    return nullptr;
+
+  case CandidateQuantizedType::Scheme::UniformPerLayer: {
+    if (!clusteredFacts.requiredRange.hasValue()) {
+      // TODO: Issue some kind of diagnostic. This is not an error.
+      return nullptr;
+    }
+
+    uint64_t numLevels = ct.quantizedType.getStorageTypeMax() -
+                         ct.quantizedType.getStorageTypeMin();
+    UniformStorageParams params{numLevels,
+                                ct.quantizedType.getStorageTypeMin()};
+    UniformParamsFromMinMaxSolver solver(
+        params, clusteredFacts.requiredRange.getValue().first,
+        clusteredFacts.requiredRange.getValue().second);
+    if (!solver.compute()) {
+      solverContext.getMlirContext().emitWarning(loc)
+          << "unable to solve uniform type with "
+          << "UniformParamsFromMinMaxSolver";
+      return nullptr;
+    }
+
+    return UniformQuantizedType::getChecked(
+        ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(),
+        originalElementType, solver.getScale(), solver.getZp(),
+        ct.quantizedType.getStorageTypeMin(),
+        ct.quantizedType.getStorageTypeMax(), loc);
+  }
+  case CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale: {
+    if (!clusteredFacts.explicitScaleZeroPoint.hasValue()) {
+      solverContext.getMlirContext().emitRemark(loc)
+          << "unable to solve uniform type with UniformExplicitFixedPointScale "
+          << "(no explicitScaleZeroPoint)";
+      return nullptr;
+    }
+
+    const auto &scaleZp = clusteredFacts.explicitScaleZeroPoint.getValue();
+    assert(scaleZp.value && "optional value not set on fact");
+
+    if (scaleZp.conflict) {
+      solverContext.getMlirContext().emitWarning(loc)
+          << "conflicting explicit scale/zeroPoint on node cluster: "
+          << "an arbitrary scale/zeroPoint will be used";
+    }
+
+    return UniformQuantizedType::getChecked(
+        ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(),
+        originalElementType,
+        scaleZp.value->first, // scale
+        0, // zeroPoint (fixed point solutions only for this scheme)
+        ct.quantizedType.getStorageTypeMin(),
+        ct.quantizedType.getStorageTypeMax(), loc);
+
+    return nullptr;
+  }
+  }
+}
+
+namespace {
+
+class PropagateExplicitScale : public CAGConstraintNode {
+public:
+  PropagateExplicitScale()
+      : CAGConstraintNode(Kind::UniformPropagateExplicitScale) {}
+  static bool classof(const CAGNode *n) {
+    return n->getKind() == Kind::Constraint ||
+           n->getKind() == Kind::UniformPropagateExplicitScale;
+  }
+
+private:
+  void printLabel(llvm::raw_ostream &os) const override {
+    os << "PropagateExplicitScale";
+  }
+  void propagate(SolverContext &solverContext,
+                 const TargetConfiguration &config) {
+    DiscreteScaleZeroPointFact scaleZp;
+
+    // Get scale/zp from all parents.
+    for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
+      auto parentAnchor = llvm::cast<CAGAnchorNode>(*it);
+      auto selectedType = parentAnchor->getUniformMetadata().selectedType;
+      if (auto uqType = selectedType.dyn_cast_or_null<UniformQuantizedType>()) {
+        scaleZp.assertValue(
+            CAGUniformMetadata::SalienceRequired,
+            std::make_pair(uqType.getScale(), static_cast<int64_t>(0)));
+      }
+    }
+
+    // Propagate to children.
+    if (scaleZp.hasValue()) {
+      for (auto it = begin(), e = end(); it != e; ++it) {
+        auto childAnchor = llvm::cast<CAGAnchorNode>(*it);
+        if (modified(childAnchor->getUniformMetadata()
+                         .explicitScaleZeroPoint.mergeFrom(scaleZp))) {
+          childAnchor->markDirty();
+        }
+      }
+    }
+  }
+};
+
+/// A constraint node which will solve uniform quantization for all parents
+/// of the constraint, assuming that they are coupled.
+class SolveUniformConstraintNode : public CAGConstraintNode {
+public:
+  SolveUniformConstraintNode()
+      : CAGConstraintNode(Kind::SolveUniformConstraint) {
+    markDirty();
+  }
+  static bool classof(const CAGNode *n) {
+    return n->getKind() == Kind::Constraint ||
+           n->getKind() == Kind::SolveUniformConstraint;
+  }
+
+private:
+  void printLabel(llvm::raw_ostream &os) const override {
+    os << "SolveUniform";
+  }
+
+  void propagate(SolverContext &solverContext,
+                 const TargetConfiguration &config) {
+    // First determine the required min/max range and type constraints.
+    Location fusedLoc = UnknownLoc::get(&solverContext.getMlirContext());
+    llvm::SmallBitVector enabledCandidateTypesMask(
+        config.getAllCandidateTypesMask());
+    ClusteredFacts clusteredFacts;
+    Type originalElementType;
+    for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
+      auto parentAnchor = llvm::cast<CAGAnchorNode>(*it);
+      auto metadata = parentAnchor->getUniformMetadata();
+      // TODO: Possibly use a location that fuses all involved parents.
+      fusedLoc = parentAnchor->getOp()->getLoc();
+
+      // Shared element type.
+      auto parentOriginalElementType =
+          getElementOrPrimitiveType(parentAnchor->getOriginalType());
+      if (!originalElementType) {
+        originalElementType = parentOriginalElementType;
+      } else {
+        if (originalElementType != parentOriginalElementType) {
+          parentAnchor->getOp()->emitError()
+              << "cannot compute uniform type: parent element types mismatch";
+          return;
+        }
+      }
+      // Range.
+      clusteredFacts.requiredRange.mergeFrom(metadata.requiredRange);
+
+      // Explicit scale and zero point.
+      clusteredFacts.explicitScaleZeroPoint.mergeFrom(
+          metadata.explicitScaleZeroPoint);
+
+      // Shared candidate types.
+      enabledCandidateTypesMask.reset(metadata.disabledCandidateTypes);
+    }
+
+    // Find the first enabled candidate type.
+    const CandidateQuantizedType *bestCandidateType = nullptr;
+    for (auto &ct : config.getCandidateTypes()) {
+      if (enabledCandidateTypesMask.test(ct.ordinal)) {
+        bestCandidateType = &ct;
+        break;
+      }
+    }
+
+    if (!bestCandidateType || !originalElementType) {
+      solverContext.getMlirContext().emitRemark(fusedLoc)
+          << "not solving uniform type (no viable candidate type)";
+      return;
+    }
+
+    // Solve for the type.
+    QuantizedType selectedType =
+        solveUniformType(solverContext, clusteredFacts, *bestCandidateType,
+                         originalElementType, fusedLoc);
+
+    // Apply it to all parents.
+    for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) {
+      auto parentAnchor = llvm::cast<CAGAnchorNode>(*it);
+      auto &metadata = parentAnchor->getUniformMetadata();
+      if (metadata.selectedType != selectedType) {
+        metadata.selectedType = selectedType;
+        // And mark all children of the parent dirty (except us).
+        for (auto child : *parentAnchor) {
+          if (child != this) {
+            child->markDirty();
+          }
+        }
+      }
+    }
+  }
+};
+
+} // end anonymous namespace
+
+void UniformConstraintsBuilder::coupleAnchors(CAGAnchorNode *a,
+                                              CAGAnchorNode *b) {
+  slice.addClusteredConstraint<SolveUniformConstraintNode>({a, b});
+}
+
+void UniformConstraintsBuilder::applyStats(CAGAnchorNode *a,
+                                           TensorAxisStatistics stats) {
+  a->getUniformMetadata().requiredRange.assertValue(
+      CAGUniformMetadata::SalienceDefault, {stats.minValue, stats.maxValue});
+}
+
+void UniformConstraintsBuilder::clamp(CAGAnchorNode *a, APFloat minValue,
+                                      APFloat maxValue) {
+  a->getUniformMetadata().requiredRange.assertValue(
+      CAGUniformMetadata::SalienceDefault,
+      {minValue.convertToDouble(), maxValue.convertToDouble()});
+}
+
+void UniformConstraintsBuilder::propagateExplicitScale(CAGAnchorNode *from,
+                                                       CAGAnchorNode *to) {
+  slice.addUnidirectionalConstraint<PropagateExplicitScale>(from, {to});
+}