From 8e5bfb85c44916ef581675ea5a68e062d85624fd Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 17 May 2019 17:43:50 -0700 Subject: [PATCH] Upstream the Quantizer tool (part 3). This upstreams the config and constraints for a reference quantization scheme based on the FxpMathOps dialect. There are probably two more CLs to get the rest: one with the passes/tests, and one with the tool main() itself. -- PiperOrigin-RevId: 248817505 --- .../mlir/Quantizer/Configurations/FxpMathConfig.h | 49 ++++ .../mlir/Quantizer/Support/UniformConstraints.h | 69 +++++ .../lib/Quantizer/Configurations/FxpMathConfig.cpp | 289 +++++++++++++++++++++ mlir/lib/Quantizer/Support/UniformConstraints.cpp | 269 +++++++++++++++++++ 4 files changed, 676 insertions(+) create mode 100644 mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h create mode 100644 mlir/include/mlir/Quantizer/Support/UniformConstraints.h create mode 100644 mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp create mode 100644 mlir/lib/Quantizer/Support/UniformConstraints.cpp diff --git a/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h b/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h new file mode 100644 index 0000000..d0efe25 --- /dev/null +++ b/mlir/include/mlir/Quantizer/Configurations/FxpMathConfig.h @@ -0,0 +1,49 @@ +//===- FxpMathConfig.h - Reference fixed point config -----------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines a TargetConfiguration for reference fixed-point math +// quantization scheme based on the FxpMathOps (plus a small category of +// extension ops that can be added from other dialects). +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H +#define MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H + +#include "mlir/Quantizer/Support/Configuration.h" +#include "mlir/Quantizer/Support/Metadata.h" + +namespace mlir { +namespace quantizer { + +/// Target configuration for a reference affine/fixed-point quantization +/// scheme defined in terms of the FxpMathOps dialect. This can be extended +/// with select ops from other dialects by way of the following public +/// methods: +/// - addValueIdentityOp +class FxpMathTargetConfig : public TargetConfiguration { + public: + /// Creates an FxpMathTargetConfig instance which can be further customized. + static std::unique_ptr create(SolverContext &context); + protected: + FxpMathTargetConfig(SolverContext &context) : TargetConfiguration(context) {} +}; + +} // namespace quantizer +} // namespace mlir + +#endif // MLIR_QUANTIZER_CONFIGURATIONS_FXPMATHCONFIG_H diff --git a/mlir/include/mlir/Quantizer/Support/UniformConstraints.h b/mlir/include/mlir/Quantizer/Support/UniformConstraints.h new file mode 100644 index 0000000..beae3ed --- /dev/null +++ b/mlir/include/mlir/Quantizer/Support/UniformConstraints.h @@ -0,0 +1,69 @@ +//===- UniformConstraints.h - Constraints for uniform quant -----*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines a builder that lets you attach constraints necessary to +// perform a variety of uniform quantization conversions to CAG anchors. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H +#define MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H + +#include "mlir/Quantizer/Support/Statistics.h" + +namespace mlir { +namespace quantizer { + +class CAGAnchorNode; +class CAGSlice; + +/// Factory methods for adding CAG constraints of various kinds suitable +/// for solving for uniform quantization. +class UniformConstraintsBuilder { + public: + UniformConstraintsBuilder(CAGSlice &slice) : slice(slice) {} + + /// Adds a coupling constraint between two nodes, effectively treating + /// them as a hard identity relationship. + void coupleAnchors(CAGAnchorNode *a, CAGAnchorNode *b); + + /// Applies statistics constraints to the given anchor, such that the solver + /// ensures that the statistics are representable by chosen types. + void applyStats(CAGAnchorNode *a, TensorAxisStatistics stats); + + /// Applies a constraint to a node which allows solutions that do not extend + /// beyond given min/max bounds (this is a hint that the tensor will not + /// take values outside of these bounds). If either minValue or maxValue is + /// NAN, then that side is considered open. + void clamp(CAGAnchorNode *a, APFloat minValue, APFloat maxValue); + + /// Propagates an explicit scale from an anchor that may have a uniform + /// |selectedType| to the |explicitScaleZeroPoint| field of the to node. + /// This is typically used with a to node that has a candidate quantized + /// type of |UniformExplicitFixedPointScale|, indicating that it can be + /// an arbitrary (signed) type that is expected to share the same scale + /// as the originating node. + void propagateExplicitScale(CAGAnchorNode *from, CAGAnchorNode *to); + + private: + CAGSlice &slice; +}; + +} // namespace quantizer +} // namespace mlir + +#endif // MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H diff --git a/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp new file mode 100644 index 0000000..8623df9 --- /dev/null +++ b/mlir/lib/Quantizer/Configurations/FxpMathConfig.cpp @@ -0,0 +1,289 @@ +//===- FxpMathConfig.cpp - Reference fixed point config -------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines a TargetConfiguration for reference fixed-point math +// quantization scheme based on the FxpMathOps (plus a small category of +// extension ops that can be added from other dialects). +// +//===----------------------------------------------------------------------===// + +#include "mlir/Quantizer/Configurations/FxpMathConfig.h" + +#include "mlir/Dialect/FxpMathOps/FxpMathOps.h" +#include "mlir/Dialect/QuantOps/QuantOps.h" +#include "mlir/Dialect/QuantOps/QuantTypes.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" +#include "mlir/Quantizer/Support/Metadata.h" +#include "mlir/Quantizer/Support/Statistics.h" +#include "mlir/Quantizer/Support/UniformConstraints.h" +#include "mlir/StandardOps/Ops.h" + +using namespace mlir; +using namespace mlir::quantizer; +using namespace mlir::fxpmath; +using namespace mlir::quant; +using namespace std::placeholders; + +namespace { + +struct FxpMathTargetConfigImpl : public FxpMathTargetConfig { + FxpMathTargetConfigImpl(SolverContext &context) + : FxpMathTargetConfig(context) { + Builder b(&context.getMlirContext()); + IntegerType i8Type = b.getIntegerType(8); + IntegerType i16Type = b.getIntegerType(16); + IntegerType i32Type = b.getIntegerType(32); + + q8 = addCandidateType( + AnyQuantizedType::get(QuantizationFlags::Signed, i8Type, nullptr, + std::numeric_limits::min(), + std::numeric_limits::max()), + CandidateQuantizedType::Scheme::UniformPerLayer); + q16 = addCandidateType( + AnyQuantizedType::get(QuantizationFlags::Signed, i16Type, nullptr, + std::numeric_limits::min(), + std::numeric_limits::max()), + CandidateQuantizedType::Scheme::UniformPerLayer); + q32ExplicitFixedPoint = addCandidateType( + AnyQuantizedType::get(QuantizationFlags::Signed, i32Type, nullptr, + std::numeric_limits::min(), + std::numeric_limits::max()), + CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale); + + // Op handlers. + addOpHandler( + std::bind(&FxpMathTargetConfigImpl::handleConstant, this, _1, _2)); + addOpHandler( + std::bind(&FxpMathTargetConfigImpl::handleTerminal, this, _1, _2)); + addOpHandler( + std::bind(&FxpMathTargetConfigImpl::handleStats, this, _1, _2)); + + // FxpMathOps. + addOpHandler( + std::bind(&FxpMathTargetConfigImpl::handleAdd, this, _1, _2)); + addOpHandler( + std::bind(&FxpMathTargetConfigImpl::handleMul, this, _1, _2)); + addOpHandler( + std::bind(&FxpMathTargetConfigImpl::handleMatMul, this, _1, _2)); + addOpHandler( + std::bind(&FxpMathTargetConfigImpl::handleMatMulBias, this, _1, _2)); + + // Require stats ops. + addRequireStatsOp(); + addRequireStatsOp(); + addRequireStatsOp(); + addRequireStatsOp(); + addRequireStatsOp(); + addRequireStatsOp(); + } + + bool isHandledType(Type t) const final { + if (t.isa()) + return true; + auto shapedType = t.dyn_cast(); + return (shapedType && shapedType.getElementType().isa() && + (t.isa() || t.isa())); + } + + void finalizeAnchors(CAGSlice &cag) const override { + cag.enumerateImpliedConnections( + [&](CAGAnchorNode *from, CAGAnchorNode *to) { + UniformConstraintsBuilder(cag).coupleAnchors(from, to); + }); + } + + void addValueIdentityOpByName(StringRef opName) override { + addOpHandlerByName( + opName, + std::bind(&FxpMathTargetConfigImpl::handleValueIdentity, this, _1, _2)); + } + + void handleValueIdentity(Operation *op, CAGSlice &cag) const { + assert(op->getNumResults() == 1); + if (!isHandledType(op->getResult(0)->getType())) + return; + + auto resultNode = cag.getResultAnchor(op, 0); + resultNode->setTypeTransformRule( + CAGAnchorNode::TypeTransformRule::DirectStorage); + + for (unsigned opIdx = 0, e = op->getNumOperands(); opIdx < e; ++opIdx) { + if (!isHandledType(op->getOperand(opIdx)->getType())) + continue; + auto operandNode = cag.getOperandAnchor(op, opIdx); + operandNode->setTypeTransformRule( + CAGAnchorNode::TypeTransformRule::DirectStorage); + UniformConstraintsBuilder(cag).coupleAnchors(operandNode, resultNode); + } + } + + void handleConstant(Operation *op, CAGSlice &cag) const { + if (!isHandledType(op->getResult(0)->getType())) + return; + + auto resultNode = cag.getResultAnchor(op, 0); + resultNode->setTypeTransformRule( + CAGAnchorNode::TypeTransformRule::ExpressedOnly); + Attribute valueAttr; + if (!matchPattern(op, m_Constant(&valueAttr))) { + return; + } + + AttributeTensorStatistics stats(valueAttr); + TensorAxisStatistics layerStats; + if (!stats.get(layerStats)) { + op->emitOpError("could not compute statistics"); + return; + } + + UniformConstraintsBuilder(cag).applyStats(resultNode, layerStats); + } + + void handleTerminal(Operation *op, CAGSlice &cag) const { + if (!isHandledType(op->getOperand(0)->getType())) + return; + auto operandNode = cag.getOperandAnchor(op, 0); + operandNode->setTypeTransformRule( + CAGAnchorNode::TypeTransformRule::ExpressedOnly); + } + + void handleStats(Operation *op, CAGSlice &cag) const { + if (!isHandledType(op->getResult(0)->getType())) + return; + + auto argNode = cag.getOperandAnchor(op, 0); + auto resultNode = cag.getResultAnchor(op, 0); + UniformConstraintsBuilder(cag).coupleAnchors(argNode, resultNode); + + TensorAxisStatistics layerStats; + auto statsOp = cast(op); + auto layerStatsAttr = statsOp.layerStats(); + layerStats.minValue = + layerStatsAttr.getValue({0}).cast().getValueAsDouble(); + layerStats.maxValue = + layerStatsAttr.getValue({1}).cast().getValueAsDouble(); + UniformConstraintsBuilder(cag).applyStats(resultNode, + std::move(layerStats)); + } + + void handleAdd(Operation *op, CAGSlice &cag) const { + if (!isHandledType(op->getResult(0)->getType())) + return; + + auto lhs = cag.getOperandAnchor(op, 0); + auto rhs = cag.getOperandAnchor(op, 1); + auto resultNode = cag.getResultAnchor(op, 0); + // Add supports 8/16 bit math. + llvm::SmallBitVector disableMask = + getCandidateTypeDisabledExceptMask({q8, q16}); + lhs->getUniformMetadata().disabledCandidateTypes = disableMask; + rhs->getUniformMetadata().disabledCandidateTypes = disableMask; + resultNode->getUniformMetadata().disabledCandidateTypes = disableMask; + // NOTE: We couple the add such that the scale/zeroPoint match between + // both args and the result. This is overly constrained in that it is + // possible to write efficient add kernels with a bit more freedom (i.e. + // zeroPoints can vary, scales can differ by a power of two, etc). + // However, fully coupled yields the simples solutions on the fast path. + // Further efficiency can be had by constraining the zeroPoint to 0, but + // there isn't a constraint for this yet (and there are tradeoffs). + UniformConstraintsBuilder(cag).coupleAnchors(lhs, resultNode); + UniformConstraintsBuilder(cag).coupleAnchors(rhs, resultNode); + addRealMathOptionalConstraints(op, resultNode, cag); + } + + void handleMul(Operation *op, CAGSlice &cag) const { + if (!isHandledType(op->getResult(0)->getType())) + return; + + auto lhs = cag.getOperandAnchor(op, 0); + auto rhs = cag.getOperandAnchor(op, 1); + auto resultNode = cag.getResultAnchor(op, 0); + // Mul supports 8/16 bit math. + llvm::SmallBitVector disableMask = + getCandidateTypeDisabledExceptMask({q8, q16}); + lhs->getUniformMetadata().disabledCandidateTypes = disableMask; + rhs->getUniformMetadata().disabledCandidateTypes = disableMask; + resultNode->getUniformMetadata().disabledCandidateTypes = disableMask; + addRealMathOptionalConstraints(op, resultNode, cag); + } + + void handleMatMul(Operation *op, CAGSlice &cag) const { + if (!isHandledType(op->getResult(0)->getType())) + return; + + auto lhs = cag.getOperandAnchor(op, 0); + auto rhs = cag.getOperandAnchor(op, 1); + auto resultNode = cag.getResultAnchor(op, 0); + // Mul supports 8/16 bit math. + llvm::SmallBitVector disableMask = + getCandidateTypeDisabledExceptMask({q8, q16}); + lhs->getUniformMetadata().disabledCandidateTypes = disableMask; + rhs->getUniformMetadata().disabledCandidateTypes = disableMask; + resultNode->getUniformMetadata().disabledCandidateTypes = disableMask; + addRealMathOptionalConstraints(op, resultNode, cag); + } + + void handleMatMulBias(Operation *op, CAGSlice &cag) const { + if (!isHandledType(op->getResult(0)->getType())) + return; + + auto lhs = cag.getOperandAnchor(op, 0); + auto rhs = cag.getOperandAnchor(op, 1); + auto bias = cag.getOperandAnchor(op, 2); + bias->getUniformMetadata().disabledCandidateTypes = + getCandidateTypeDisabledExceptMask({q32ExplicitFixedPoint}); + + auto resultNode = cag.getResultAnchor(op, 0); + UniformConstraintsBuilder(cag).propagateExplicitScale(resultNode, bias); + + // Mul supports 8/16 bit math. + llvm::SmallBitVector disableMask = + getCandidateTypeDisabledExceptMask({q8, q16}); + lhs->getUniformMetadata().disabledCandidateTypes = disableMask; + rhs->getUniformMetadata().disabledCandidateTypes = disableMask; + resultNode->getUniformMetadata().disabledCandidateTypes = disableMask; + addRealMathOptionalConstraints(op, resultNode, cag); + } + + void addRealMathOptionalConstraints(Operation *op, CAGAnchorNode *anchor, + CAGSlice &cag) const { + // TODO: It would be nice if these all extended some base trait instead + // of requiring name lookup. + auto clampMinAttr = op->getAttrOfType("clamp_min"); + auto clampMaxAttr = op->getAttrOfType("clamp_max"); + + if (clampMinAttr || clampMaxAttr) { + auto nan = APFloat::getQNaN(APFloat::IEEEdouble()); + auto clampMin = clampMinAttr ? clampMinAttr.getValue() : nan; + auto clampMax = clampMaxAttr ? clampMaxAttr.getValue() : nan; + UniformConstraintsBuilder(cag).clamp(anchor, clampMin, clampMax); + } + } + + unsigned q8; + unsigned q16; + unsigned q32ExplicitFixedPoint; +}; + +} // anonymous namespace + +std::unique_ptr +FxpMathTargetConfig::create(SolverContext &context) { + return llvm::make_unique(context); +} diff --git a/mlir/lib/Quantizer/Support/UniformConstraints.cpp b/mlir/lib/Quantizer/Support/UniformConstraints.cpp new file mode 100644 index 0000000..ab1ced1 --- /dev/null +++ b/mlir/lib/Quantizer/Support/UniformConstraints.cpp @@ -0,0 +1,269 @@ +//===- UniformConstraints.cpp - Constraints for uniform quant -------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/Quantizer/Support/UniformConstraints.h" + +#include "mlir/Dialect/QuantOps/QuantTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Quantizer/Support/Configuration.h" +#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" +#include "mlir/Quantizer/Support/Metadata.h" +#include "mlir/Quantizer/Support/Rules.h" +#include "mlir/Quantizer/Support/TypeUtils.h" +#include "mlir/Quantizer/Support/UniformSolvers.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::quantizer; +using namespace mlir::quant; + +namespace { + +struct ClusteredFacts { + ExpandingMinMaxFact requiredRange; + DiscreteScaleZeroPointFact explicitScaleZeroPoint; +}; + +} // end anonymous namespace + +static QuantizedType solveUniformType(SolverContext &solverContext, + const ClusteredFacts &clusteredFacts, + const CandidateQuantizedType &ct, + Type originalElementType, Location loc) { + switch (ct.scheme) { + default: + solverContext.getMlirContext().emitError( + loc, "unsupported scheme for uniform type conversion"); + return nullptr; + + case CandidateQuantizedType::Scheme::UniformPerLayer: { + if (!clusteredFacts.requiredRange.hasValue()) { + // TODO: Issue some kind of diagnostic. This is not an error. + return nullptr; + } + + uint64_t numLevels = ct.quantizedType.getStorageTypeMax() - + ct.quantizedType.getStorageTypeMin(); + UniformStorageParams params{numLevels, + ct.quantizedType.getStorageTypeMin()}; + UniformParamsFromMinMaxSolver solver( + params, clusteredFacts.requiredRange.getValue().first, + clusteredFacts.requiredRange.getValue().second); + if (!solver.compute()) { + solverContext.getMlirContext().emitWarning(loc) + << "unable to solve uniform type with " + << "UniformParamsFromMinMaxSolver"; + return nullptr; + } + + return UniformQuantizedType::getChecked( + ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(), + originalElementType, solver.getScale(), solver.getZp(), + ct.quantizedType.getStorageTypeMin(), + ct.quantizedType.getStorageTypeMax(), loc); + } + case CandidateQuantizedType::Scheme::UniformExplicitFixedPointScale: { + if (!clusteredFacts.explicitScaleZeroPoint.hasValue()) { + solverContext.getMlirContext().emitRemark(loc) + << "unable to solve uniform type with UniformExplicitFixedPointScale " + << "(no explicitScaleZeroPoint)"; + return nullptr; + } + + const auto &scaleZp = clusteredFacts.explicitScaleZeroPoint.getValue(); + assert(scaleZp.value && "optional value not set on fact"); + + if (scaleZp.conflict) { + solverContext.getMlirContext().emitWarning(loc) + << "conflicting explicit scale/zeroPoint on node cluster: " + << "an arbitrary scale/zeroPoint will be used"; + } + + return UniformQuantizedType::getChecked( + ct.quantizedType.getFlags(), ct.quantizedType.getStorageType(), + originalElementType, + scaleZp.value->first, // scale + 0, // zeroPoint (fixed point solutions only for this scheme) + ct.quantizedType.getStorageTypeMin(), + ct.quantizedType.getStorageTypeMax(), loc); + + return nullptr; + } + } +} + +namespace { + +class PropagateExplicitScale : public CAGConstraintNode { +public: + PropagateExplicitScale() + : CAGConstraintNode(Kind::UniformPropagateExplicitScale) {} + static bool classof(const CAGNode *n) { + return n->getKind() == Kind::Constraint || + n->getKind() == Kind::UniformPropagateExplicitScale; + } + +private: + void printLabel(llvm::raw_ostream &os) const override { + os << "PropagateExplicitScale"; + } + void propagate(SolverContext &solverContext, + const TargetConfiguration &config) { + DiscreteScaleZeroPointFact scaleZp; + + // Get scale/zp from all parents. + for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) { + auto parentAnchor = llvm::cast(*it); + auto selectedType = parentAnchor->getUniformMetadata().selectedType; + if (auto uqType = selectedType.dyn_cast_or_null()) { + scaleZp.assertValue( + CAGUniformMetadata::SalienceRequired, + std::make_pair(uqType.getScale(), static_cast(0))); + } + } + + // Propagate to children. + if (scaleZp.hasValue()) { + for (auto it = begin(), e = end(); it != e; ++it) { + auto childAnchor = llvm::cast(*it); + if (modified(childAnchor->getUniformMetadata() + .explicitScaleZeroPoint.mergeFrom(scaleZp))) { + childAnchor->markDirty(); + } + } + } + } +}; + +/// A constraint node which will solve uniform quantization for all parents +/// of the constraint, assuming that they are coupled. +class SolveUniformConstraintNode : public CAGConstraintNode { +public: + SolveUniformConstraintNode() + : CAGConstraintNode(Kind::SolveUniformConstraint) { + markDirty(); + } + static bool classof(const CAGNode *n) { + return n->getKind() == Kind::Constraint || + n->getKind() == Kind::SolveUniformConstraint; + } + +private: + void printLabel(llvm::raw_ostream &os) const override { + os << "SolveUniform"; + } + + void propagate(SolverContext &solverContext, + const TargetConfiguration &config) { + // First determine the required min/max range and type constraints. + Location fusedLoc = UnknownLoc::get(&solverContext.getMlirContext()); + llvm::SmallBitVector enabledCandidateTypesMask( + config.getAllCandidateTypesMask()); + ClusteredFacts clusteredFacts; + Type originalElementType; + for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) { + auto parentAnchor = llvm::cast(*it); + auto metadata = parentAnchor->getUniformMetadata(); + // TODO: Possibly use a location that fuses all involved parents. + fusedLoc = parentAnchor->getOp()->getLoc(); + + // Shared element type. + auto parentOriginalElementType = + getElementOrPrimitiveType(parentAnchor->getOriginalType()); + if (!originalElementType) { + originalElementType = parentOriginalElementType; + } else { + if (originalElementType != parentOriginalElementType) { + parentAnchor->getOp()->emitError() + << "cannot compute uniform type: parent element types mismatch"; + return; + } + } + // Range. + clusteredFacts.requiredRange.mergeFrom(metadata.requiredRange); + + // Explicit scale and zero point. + clusteredFacts.explicitScaleZeroPoint.mergeFrom( + metadata.explicitScaleZeroPoint); + + // Shared candidate types. + enabledCandidateTypesMask.reset(metadata.disabledCandidateTypes); + } + + // Find the first enabled candidate type. + const CandidateQuantizedType *bestCandidateType = nullptr; + for (auto &ct : config.getCandidateTypes()) { + if (enabledCandidateTypesMask.test(ct.ordinal)) { + bestCandidateType = &ct; + break; + } + } + + if (!bestCandidateType || !originalElementType) { + solverContext.getMlirContext().emitRemark(fusedLoc) + << "not solving uniform type (no viable candidate type)"; + return; + } + + // Solve for the type. + QuantizedType selectedType = + solveUniformType(solverContext, clusteredFacts, *bestCandidateType, + originalElementType, fusedLoc); + + // Apply it to all parents. + for (auto it = incoming_begin(), e = incoming_end(); it != e; ++it) { + auto parentAnchor = llvm::cast(*it); + auto &metadata = parentAnchor->getUniformMetadata(); + if (metadata.selectedType != selectedType) { + metadata.selectedType = selectedType; + // And mark all children of the parent dirty (except us). + for (auto child : *parentAnchor) { + if (child != this) { + child->markDirty(); + } + } + } + } + } +}; + +} // end anonymous namespace + +void UniformConstraintsBuilder::coupleAnchors(CAGAnchorNode *a, + CAGAnchorNode *b) { + slice.addClusteredConstraint({a, b}); +} + +void UniformConstraintsBuilder::applyStats(CAGAnchorNode *a, + TensorAxisStatistics stats) { + a->getUniformMetadata().requiredRange.assertValue( + CAGUniformMetadata::SalienceDefault, {stats.minValue, stats.maxValue}); +} + +void UniformConstraintsBuilder::clamp(CAGAnchorNode *a, APFloat minValue, + APFloat maxValue) { + a->getUniformMetadata().requiredRange.assertValue( + CAGUniformMetadata::SalienceDefault, + {minValue.convertToDouble(), maxValue.convertToDouble()}); +} + +void UniformConstraintsBuilder::propagateExplicitScale(CAGAnchorNode *from, + CAGAnchorNode *to) { + slice.addUnidirectionalConstraint(from, {to}); +} -- 2.7.4