From 79265887ff1c109396e0fa8156b8f4129487ffc1 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 15 May 2019 15:04:20 -0700 Subject: [PATCH] Upstreaming Quantizer tool (part 2). This adds some additional core types and utilities, notably the constraint analysis graph (CAG) structures, associated metadata and configuration policy object base class. The CAG is not particularly memory efficient as it stands now. I had started some work to turn it into a form that could be better managed by a bump pointer allocator but abandoned that for now in favor of having something that does semantically what I was going for as a starting point. -- PiperOrigin-RevId: 248413133 --- .../include/mlir/Quantizer/Support/Configuration.h | 157 +++++++++ .../Quantizer/Support/ConstraintAnalysisGraph.h | 374 +++++++++++++++++++++ .../Support/ConstraintAnalysisGraphTraits.h | 58 ++++ mlir/include/mlir/Quantizer/Support/Metadata.h | 110 ++++++ mlir/include/mlir/Quantizer/Support/TypeUtils.h | 40 +++ mlir/lib/Quantizer/Support/Configuration.cpp | 49 +++ .../Quantizer/Support/ConstraintAnalysisGraph.cpp | 181 ++++++++++ mlir/lib/Quantizer/Support/Metadata.cpp | 42 +++ mlir/lib/Quantizer/Support/TypeUtils.cpp | 31 ++ 9 files changed, 1042 insertions(+) create mode 100644 mlir/include/mlir/Quantizer/Support/Configuration.h create mode 100644 mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h create mode 100644 mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h create mode 100644 mlir/include/mlir/Quantizer/Support/Metadata.h create mode 100644 mlir/include/mlir/Quantizer/Support/TypeUtils.h create mode 100644 mlir/lib/Quantizer/Support/Configuration.cpp create mode 100644 mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp create mode 100644 mlir/lib/Quantizer/Support/Metadata.cpp create mode 100644 mlir/lib/Quantizer/Support/TypeUtils.cpp diff --git a/mlir/include/mlir/Quantizer/Support/Configuration.h b/mlir/include/mlir/Quantizer/Support/Configuration.h 
new file mode 100644 index 0000000..2b67712 --- /dev/null +++ b/mlir/include/mlir/Quantizer/Support/Configuration.h @@ -0,0 +1,157 @@ +//===- Configuration.h - Configuration object base classes ------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// The quantizer is relatively agnostic to source and target dialects, with +// the specifics represented by configuration policy objects derived from +// classes in this file. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H +#define MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H + +#include + +#include "mlir/Dialect/QuantOps/QuantTypes.h" +#include "mlir/IR/Identifier.h" +#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" +#include "mlir/Quantizer/Support/Metadata.h" +#include "mlir/Quantizer/Support/Rules.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/StringSet.h" + +namespace mlir { +class Operation; + +namespace quantizer { + +class CAGSlice; + +/// Defines quantization configuration for the target. +/// The settings here depend on a variety of details about the deployment +/// environment, although, where we have control over such things, we do +/// try to standardize as much as possible. 
+/// +/// Non-const methods are used to set up the configuration. It is expected that +/// const instances/references are used post-build. +class TargetConfiguration { +public: + static constexpr size_t MaxSchemeIndex = 31; + using OpHandlerFn = std::function; + + TargetConfiguration(SolverContext &context); + virtual ~TargetConfiguration() = default; + + /// Adds a candidate type, returning its ordinal. + unsigned addCandidateType(quant::AnyQuantizedType quantizedType, + CandidateQuantizedType::Scheme scheme) { + unsigned ordinal = candidateTypes.size(); + assert(allCandidateTypesMask.size() == ordinal); + CandidateQuantizedType ct{ordinal, quantizedType, scheme}; + candidateTypes.push_back(ct); + allCandidateTypesMask.push_back(true); + return ordinal; + } + + /// Gets a candidate type by its ordinal. + const CandidateQuantizedType &getCandidateType(unsigned index) const { + assert(index < candidateTypes.size()); + return candidateTypes[index]; + } + + llvm::ArrayRef getCandidateTypes() const { + return candidateTypes; + } + + /// Gets a mask of all known candidate types by ordinal. + llvm::SmallBitVector getAllCandidateTypesMask() const { + return allCandidateTypesMask; + } + + /// Gets a disabled mask: every candidate type set except the given ordinals. + llvm::SmallBitVector getCandidateTypeDisabledExceptMask( + llvm::ArrayRef exceptOrdinals) const { + llvm::SmallBitVector disabled(allCandidateTypesMask); + for (unsigned ordinal : exceptOrdinals) { + disabled.reset(ordinal); + } + return disabled; + } + + /// Adds an op handler. + template + void addOpHandler(OpHandlerFn fn) { + addOpHandlerByName(OpTy::getOperationName(), fn); + } + + /// Adds an operation which requires statistics at its result nodes for + /// best quantization performance. Note that the opName StringRef is + /// expected to come from getOperationName() and be static. 
+ template + void addRequireStatsOp() { + addRequireStatsOpByName(OpTy::getOperationName()); + } + + /// Returns whether the given op's name was added via addRequireStatsOp(). + bool isRequireStatsOp(Operation *op) const; + + /// Adds an op which does not mutate its values but may mutate its shape + /// or combine its operands in an arbitrary way. + /// Such ops are expected to have the same types for operands and results + /// and must be capable of operating on storage types. + template + void addValueIdentityOp() { + addValueIdentityOpByName(OpTy::getOperationName()); + } + + /// Handles the operation if a handler is defined for it. + void handleOp(Operation *op, CAGSlice &cag) const; + + /// Finalizes the CAG after all anchors have been added. + virtual void finalizeAnchors(CAGSlice &cag) const {} + + /// Whether an operand or result type is subject to analysis by this config. + virtual bool isHandledType(Type t) const = 0; + +protected: + virtual void addValueIdentityOpByName(StringRef opName) = 0; + void addOpHandlerByName(StringRef name, OpHandlerFn fn); + +private: + void addRequireStatsOpByName(StringRef opName); + + SolverContext &context; + + /// Vector of all candidate type constraints, indexed by ordinal. + std::vector candidateTypes; + + /// A SmallBitVector with bits set for all known candidate types. + llvm::SmallBitVector allCandidateTypesMask; + + /// Map of all op handlers, keyed by operation name. + llvm::StringMap opHandlers; + + /// Names of operations which should have their results annotated with + /// statistics. 
+ llvm::StringSet<> requireStatsOpNames; +}; + +} // namespace quantizer +} // namespace mlir + +#endif // MLIR_QUANTIZER_SUPPORT_CONFIGURATION_H diff --git a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h new file mode 100644 index 0000000..8f2a0e5 --- /dev/null +++ b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraph.h @@ -0,0 +1,374 @@ +//===- ConstraintAnalysisGraph.h - Graph types for constraints --*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file provides graph-based data structures for representing anchors +// and constraints between them. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H +#define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H + +#include +#include + +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Types.h" +#include "mlir/Quantizer/Support/Metadata.h" +#include "llvm/ADT/DenseMap.h" + +namespace mlir { +namespace quantizer { + +class CAGNode; +class CAGSlice; +class TargetConfiguration; + +/// A node in the Constraint Analysis Graph. 
+/// Nodes are either anchors (representing results and operands) or constraints. +/// Anchor nodes are connected to other anchor nodes via constraints. +/// Nodes exist within graph slices, which are typically analyses attached to +/// the function or module. Slices can contain other slices, which mirrors +/// the nesting of analyses. +/// +/// Nodes have directed relationships which propagate successor-ward when dirty. +/// Relationships can be bi-directional, in which case, the constraint's +/// propagation mechanism must ensure convergence. +class CAGNode { +public: + enum class Kind { + /// Anchors. + Anchor, + OperandAnchor, + ResultAnchor, + LastAnchor = ResultAnchor, + + /// Constraints. + Constraint, + SolveUniformConstraint, + UniformPropagateExplicitScale, + LastConstraint = UniformPropagateExplicitScale, + }; + + // Vector and iterator over nodes. + using node_vector = llvm::SmallVector; + using iterator = node_vector::iterator; + using const_iterator = node_vector::const_iterator; + + virtual ~CAGNode() = default; + + Kind getKind() const { return kind; } + + /// Unique id of the node within the slice (-1 until added to one). + int getNodeId() const { return nodeId; } + + /// Whether the node is dirty, requiring one or more calls to propagate(). + bool isDirty() const { return dirty; } + void markDirty() { dirty = true; } + void clearDirty() { dirty = false; } + + /// Iterators over this node's child (outgoing) nodes. + const_iterator begin() const { return outgoing.begin(); } + const_iterator end() const { return outgoing.end(); } + iterator begin() { return outgoing.begin(); } + iterator end() { return outgoing.end(); } + + /// Iterators over this node's parent (incoming) nodes. 
+ const_iterator incoming_begin() const { return incoming.begin(); } + const_iterator incoming_end() const { return incoming.end(); } + iterator incoming_begin() { return incoming.begin(); } + iterator incoming_end() { return incoming.end(); } + + virtual void propagate(SolverContext &solverContext, + const TargetConfiguration &config) {} + + /// Prints the node label, suitable for one-line display. + virtual void printLabel(llvm::raw_ostream &os) const; + + template + void findChildrenOfKind(llvm::SmallVectorImpl &found) { + for (CAGNode *child : *this) { + T *ofKind = llvm::dyn_cast(child); + if (ofKind) { + found.push_back(ofKind); + } + } + } + + /// Replaces this node by rerouting any parent nodes to have otherNode + /// as a child. + void replaceIncoming(CAGNode *otherNode); + + /// Adds an outgoing connection from this node to toNode (and the + /// corresponding incoming connection on toNode) if not already present. + void addOutgoing(CAGNode *toNode); + + /// Whether this node is an orphan (has no incoming or outgoing connections). + bool isOrphan() const { return incoming.empty() && outgoing.empty(); } + +protected: + CAGNode(Kind kind) : kind(kind) {} + +private: + Kind kind; + int nodeId = -1; + node_vector outgoing; + node_vector incoming; + bool dirty = false; + + friend class CAGSlice; +}; + +/// Anchor nodes represent points in the source IR where we may choose to +/// introduce a type transition. These include operands, results, arguments, +/// returns, etc. +class CAGAnchorNode : public CAGNode { +public: + enum class TypeTransformRule { + /// The owning op directly supports all transformed types. In practice, + /// this means that the op supports QuantizedType for this anchor. + Direct, + + /// The type of this anchor should be set to the QuantizedType storage + /// type. This will only be valid if constraints are such that all + /// inputs/outputs converge to the same storage type (i.e. coupled). 
+ DirectStorage, + + /// The anchor must only be typed based on the expressed type. 
This is + /// used for ops that do not natively support quantization, and suitable + /// casts will be inserted. + ExpressedOnly, + }; + + /// Metadata for solving uniform quantization params. + CAGUniformMetadata &getUniformMetadata() { return uniformMetadata; } + const CAGUniformMetadata &getUniformMetadata() const { + return uniformMetadata; + } + + virtual Operation *getOp() const = 0; + virtual Value *getValue() const = 0; + + static bool classof(const CAGNode *n) { + return n->getKind() >= Kind::Anchor && n->getKind() <= Kind::LastAnchor; + } + + void propagate(SolverContext &solverContext, + const TargetConfiguration &config) override; + + void printLabel(llvm::raw_ostream &os) const override; + + /// Returns the original type as transformed by the metadata's selectedType, + /// or a null Type if no selectedType is set. + Type getTransformedType(); + + TypeTransformRule getTypeTransformRule() const { return typeTransformRule; } + + void setTypeTransformRule(TypeTransformRule r) { typeTransformRule = r; } + + /// Gets the Type that was defined for this anchor at the time of + /// construction. + Type getOriginalType() const { return originalType; } + +protected: + CAGAnchorNode(Kind kind, Type originalType) + : CAGNode(kind), originalType(originalType) {} + +private: + CAGUniformMetadata uniformMetadata; + Type originalType; + TypeTransformRule typeTransformRule = TypeTransformRule::Direct; +}; + +/// An anchor tied to a specific operand. +/// Since operand anchors can be rewritten so that the operand refers to +/// a new result, they are maintained by reference (to the op and index). 
+class CAGOperandAnchor : public CAGAnchorNode { +public: + CAGOperandAnchor(Operation *op, unsigned operandIdx); + + Operation *getOp() const final { return op; } + unsigned getOperandIdx() const { return operandIdx; } + + static bool classof(const CAGNode *n) { + return n->getKind() == Kind::OperandAnchor; + } + + Value *getValue() const final { return op->getOperand(operandIdx); } + + void printLabel(llvm::raw_ostream &os) const override; + +private: + Operation *op; + unsigned operandIdx; +}; + +/// An anchor tied to a specific result. +/// Since a result is already anchored to its defining op, result anchors refer +/// directly to the underlying Value*. +class CAGResultAnchor : public CAGAnchorNode { +public: + CAGResultAnchor(Operation *op, unsigned resultIdx); + + static bool classof(const CAGNode *n) { + return n->getKind() == Kind::ResultAnchor; + } + + Operation *getOp() const final { return resultValue->getDefiningOp(); } + Value *getValue() const final { return resultValue; } + + void printLabel(llvm::raw_ostream &os) const override; + +private: + Value *resultValue; +}; + +/// Base class for constraint nodes. +class CAGConstraintNode : public CAGNode { +public: + CAGConstraintNode(Kind kind) : CAGNode(kind) {} + + static bool classof(const CAGNode *n) { + return n->getKind() >= Kind::Constraint && + n->getKind() <= Kind::LastConstraint; + } +}; + +/// A slice of a CAG (which may be the whole graph). +class CAGSlice { +public: + CAGSlice(SolverContext &context); + ~CAGSlice(); + + using node_vector = std::vector; + using iterator = node_vector::iterator; + using const_iterator = node_vector::const_iterator; + + iterator begin() { return allNodes.begin(); } + iterator end() { return allNodes.end(); } + const_iterator begin() const { return allNodes.begin(); } + const_iterator end() const { return allNodes.end(); } + + /// Gets the anchor node for an op operand, creating it on first request. 
+ CAGOperandAnchor *getOperandAnchor(Operation *op, unsigned operandIdx); + + /// Gets the anchor node for an op result, creating it on first request. + CAGResultAnchor *getResultAnchor(Operation *op, unsigned resultIdx); + + /// Adds a new constraint node of type T, connecting each of the given + /// anchors to it (anchor -> constraint). + template + T *addUniqueConstraint(llvm::ArrayRef anchors, + Args... args) { + static_assert(std::is_convertible(), + "T must be a CAGConstraintNode"); + T *constraintNode = addNode(llvm::make_unique(args...)); + for (auto *anchor : anchors) + anchor->addOutgoing(constraintNode); + return constraintNode; + } + + /// Adds a unidirectional constraint from a node to an array of target nodes. + template + T *addUnidirectionalConstraint(CAGAnchorNode *fromAnchor, + llvm::ArrayRef toAnchors, + Args... args) { + static_assert(std::is_convertible(), + "T must be a CAGConstraintNode"); + T *constraintNode = addNode(llvm::make_unique(args...)); + fromAnchor->addOutgoing(constraintNode); + for (auto *toAnchor : toAnchors) { + constraintNode->addOutgoing(toAnchor); + } + return constraintNode; + } + + template + T *addClusteredConstraint(llvm::ArrayRef anchors) { + static_assert(std::is_convertible(), + "T must be a CAGConstraintNode"); + llvm::SmallVector cluster; + for (auto *anchor : anchors) { + anchor->findChildrenOfKind(cluster); + } + + T *constraintNode; + if (cluster.empty()) { + // Create new. + constraintNode = addNode(llvm::make_unique()); + } else { + // Merge existing. + constraintNode = cluster[0]; + for (size_t i = 1, e = cluster.size(); i < e; ++i) { + cluster[i]->replaceIncoming(constraintNode); + } + } + for (auto *anchor : anchors) { + anchor->addOutgoing(constraintNode); + } + return constraintNode; + } + + /// Enumerates all implied connections in the slice. + /// An implied connection is any two nodes that physically refer to the + /// same value in the IR, such as result->operand. 
+ /// Typically this will be modeled with some kind of strong or weak + /// identity constraint such that types propagate. + /// This is usually called when the slice has been fully constructed in + /// order to add final constraints. + /// It is legal for the callback to modify the graph by adding constraints. + void enumerateImpliedConnections( + std::function callback); + + /// Performs one round of propagation, returning the number of nodes + /// propagated. If it returns > 0, then additional propagate() rounds are + /// required. + unsigned propagate(const TargetConfiguration &config); + +private: + /// Adds a node to the graph. + /// The node should be a subclass of CAGNode. + /// Returns the raw pointer to the node, which the slice now owns. + template + T *addNode(std::unique_ptr node) { + node->nodeId = allNodes.size(); + T *unownedNode = node.release(); + allNodes.push_back(unownedNode); + return unownedNode; + } + + SolverContext &context; + std::vector allNodes; + llvm::DenseMap, CAGOperandAnchor *> + operandAnchors; + llvm::DenseMap, CAGResultAnchor *> + resultAnchors; +}; + +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const CAGNode &node) { + node.printLabel(os); + return os; +} + +} // namespace quantizer +} // namespace mlir + +#endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPH_H diff --git a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h new file mode 100644 index 0000000..7a907b0 --- /dev/null +++ b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h @@ -0,0 +1,58 @@ +//===- ConstraintAnalysisGraphTraits.h - Traits for CAGs --------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Provides llvm::GraphTraits specializations for constraint analysis graphs. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H +#define MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H + +#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" +#include "llvm/ADT/GraphTraits.h" + +namespace llvm { + +template <> +struct llvm::GraphTraits { + using NodeRef = const mlir::quantizer::CAGNode *; + + static NodeRef getEntryNode(NodeRef node) { return node; } + + // Successors: the range exposed by the node's begin()/end(). 
+ using ChildIteratorType = mlir::quantizer::CAGNode::const_iterator; + static ChildIteratorType child_begin(NodeRef node) { return node->begin(); } + static ChildIteratorType child_end(NodeRef node) { return node->end(); } +}; + +template <> +struct llvm::GraphTraits + : public llvm::GraphTraits { + using nodes_iterator = mlir::quantizer::CAGSlice::const_iterator; + static mlir::quantizer::CAGSlice::const_iterator + nodes_begin(const mlir::quantizer::CAGSlice *G) { + return G->begin(); + } + static mlir::quantizer::CAGSlice::const_iterator + nodes_end(const mlir::quantizer::CAGSlice *G) { + return G->end(); + } +}; + +} // end namespace llvm + +#endif // MLIR_QUANTIZER_SUPPORT_CONSTRAINTANALYSISGRAPHTRAITS_H diff --git a/mlir/include/mlir/Quantizer/Support/Metadata.h b/mlir/include/mlir/Quantizer/Support/Metadata.h new file mode 100644 index 0000000..a2ed681 --- /dev/null +++ b/mlir/include/mlir/Quantizer/Support/Metadata.h @@ -0,0 +1,110 @@ +//===- Metadata.h - Top level types and metadata ----------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains top level types needed to construct constraint graphs, +// including context/allocator support and concrete metadata structs for +// different quantization schemes (which must be attached to anchor nodes). 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_QUANTIZER_SUPPORT_METADATA_H +#define MLIR_QUANTIZER_SUPPORT_METADATA_H + +#include + +#include "mlir/Dialect/QuantOps/QuantTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Quantizer/Support/Rules.h" +#include "llvm/ADT/SmallBitVector.h" + +namespace mlir { +namespace quantizer { + +class SolverContext { +public: + SolverContext(MLIRContext &mlirContext) : mlirContext(mlirContext) {} + + MLIRContext &getMlirContext() { return mlirContext; } + + llvm::BumpPtrAllocator &getAllocator() { return allocator; } + + // Optional path to write a debug DOT file for the CAG. + StringRef getDebugCAGDotPath() const { return debugCAGDotPath; } + void setDebugCAGDotPath(StringRef p) { debugCAGDotPath = p; } + +private: + MLIRContext &mlirContext; + llvm::BumpPtrAllocator allocator; + std::string debugCAGDotPath; +}; + +/// Candidate for a quantized type conversion. +struct CandidateQuantizedType { + // Note that scheme encodes more than just the target type: it also encodes + // additional constraints. + enum class Scheme { + // Uses aggregate range information for all nodes in the cluster to + // solve for uniform scale and zero point. + UniformPerLayer, + // Uses aggregate per-axis range information for all nodes in the cluster + // to solve for per-axis uniform scale and zero point. + UniformPerAxisFixedPoint, + // Uses the |explicitScaleZeroPoint| to set the scale (and zero point = 0) + // for the uniform type. This typically overrides all other constraints + // and is used for wide accumulator types (e.g. i32 bias vectors). + UniformExplicitFixedPointScale, + }; + unsigned ordinal; + quant::AnyQuantizedType quantizedType; + Scheme scheme; +}; + +struct CAGUniformMetadata { + /// Default salience for facts that are derived from data either statically + /// discovered in the computation or observed from an outside source. 
+ static constexpr int SalienceDefault = 0; + + /// Salience level for facts derived from overrides provided explicitly + /// (numerically below SalienceRequired). + static constexpr int SalienceForced = 100; + + /// Salience for facts derived from constraints in how the math is + /// expressed which must be satisfied. + static constexpr int SalienceRequired = 200; + + /// The range that the scheme must represent in order to accommodate the + /// underlying data. + ExpandingMinMaxFact requiredRange; + + /// Bit vector of candidate type ordinals that are disabled. + llvm::SmallBitVector disabledCandidateTypes; + + /// If set, then a solution has converged for the given per-layer scheme. + quant::QuantizedType selectedType; + + /// Optional scale and zero point to be used by types which solve via the + /// UniformExplicitFixedPointScale scheme. + DiscreteScaleZeroPointFact explicitScaleZeroPoint; + + /// Prints a summary of the metadata suitable for display in a graph label. + void printSummary(llvm::raw_ostream &os) const; +}; + +} // end namespace quantizer +} // end namespace mlir + +#endif // MLIR_QUANTIZER_SUPPORT_METADATA_H diff --git a/mlir/include/mlir/Quantizer/Support/TypeUtils.h b/mlir/include/mlir/Quantizer/Support/TypeUtils.h new file mode 100644 index 0000000..074f8b9 --- /dev/null +++ b/mlir/include/mlir/Quantizer/Support/TypeUtils.h @@ -0,0 +1,40 @@ +//===- TypeUtils.h - Helper function for manipulating types -----*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file defines various helper functions for manipulating types. The +// process of quantizing typically involves a number of type manipulations +// that are not very common elsewhere, and it is best to name them and define +// them here versus inline in the rest of the tool. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_QUANTIZER_SUPPORT_TYPEUTILS_H +#define MLIR_QUANTIZER_SUPPORT_TYPEUTILS_H + +#include "mlir/IR/Types.h" + +namespace mlir { +namespace quantizer { + +/// Given an arbitrary container or primitive type, returns the element type, +/// where the element type is just the type for non-containers. +Type getElementOrPrimitiveType(Type t); + +} // namespace quantizer +} // namespace mlir + +#endif // MLIR_QUANTIZER_SUPPORT_TYPEUTILS_H diff --git a/mlir/lib/Quantizer/Support/Configuration.cpp b/mlir/lib/Quantizer/Support/Configuration.cpp new file mode 100644 index 0000000..0efded0 --- /dev/null +++ b/mlir/lib/Quantizer/Support/Configuration.cpp @@ -0,0 +1,49 @@ +//===- Configuration.cpp - Configuration object base classes --------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= + +#include "mlir/Quantizer/Support/Configuration.h" + +#include + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/MLIRContext.h" + +using namespace mlir; +using namespace mlir::quantizer; + +TargetConfiguration::TargetConfiguration(SolverContext &context) + : context(context) {} + +void TargetConfiguration::addOpHandlerByName(StringRef name, OpHandlerFn fn) { + opHandlers[name] = fn; +} + +void TargetConfiguration::addRequireStatsOpByName(StringRef opName) { + requireStatsOpNames.insert(opName); +} + +bool TargetConfiguration::isRequireStatsOp(Operation *op) const { + return requireStatsOpNames.find(op->getName().getStringRef()) != + requireStatsOpNames.end(); +} + +void TargetConfiguration::handleOp(Operation *op, CAGSlice &cag) const { + auto foundIt = opHandlers.find(op->getName().getStringRef()); + if (foundIt != opHandlers.end()) + foundIt->second(op, cag); +} diff --git a/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp new file mode 100644 index 0000000..b4d48b7 --- /dev/null +++ b/mlir/lib/Quantizer/Support/ConstraintAnalysisGraph.cpp @@ -0,0 +1,181 @@ +//===- ConstraintAnalysisGraph.cpp - Graph types for constraints ----------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// =============================================================================
+
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+
+/// Redirects every edge that currently points at this node so that it points
+/// at `otherNode` instead, recording each redirected parent in
+/// `otherNode->incoming`, and then clears this node's incoming list.
+/// Does nothing when `otherNode` is this node.
+void CAGNode::replaceIncoming(CAGNode *otherNode) {
+  if (this == otherNode)
+    return;
+  for (CAGNode *parentNode : incoming) {
+    // `it` is a reference so the parent's outgoing slot is rewritten in place.
+    for (CAGNode *&it : parentNode->outgoing) {
+      if (it == this) {
+        it = otherNode;
+        otherNode->incoming.push_back(parentNode);
+      }
+    }
+  }
+  incoming.clear();
+}
+
+/// Adds an edge from this node to `toNode` and records this node in
+/// `toNode->incoming`. Does nothing if `toNode` is already in `outgoing`.
+void CAGNode::addOutgoing(CAGNode *toNode) {
+  if (!llvm::is_contained(outgoing, toNode)) {
+    outgoing.push_back(toNode);
+    toNode->incoming.push_back(this);
+  }
+}
+
+/// Creates an anchor for operand `operandIdx` of `op`; the anchor's type is
+/// taken from that operand's value.
+CAGOperandAnchor::CAGOperandAnchor(Operation *op, unsigned operandIdx)
+    : CAGAnchorNode(Kind::OperandAnchor, op->getOperand(operandIdx)->getType()),
+      op(op), operandIdx(operandIdx) {}
+
+/// Creates an anchor for result `resultIdx` of `op`; the anchor's type is
+/// taken from that result's value, which is stored in `resultValue`.
+CAGResultAnchor::CAGResultAnchor(Operation *op, unsigned resultIdx)
+    : CAGAnchorNode(Kind::ResultAnchor, op->getResult(resultIdx)->getType()),
+      resultValue(op->getResult(resultIdx)) {}
+
+CAGSlice::CAGSlice(SolverContext &context) : context(context) {}
+
+/// Deletes every node pointer stored in `allNodes`.
+CAGSlice::~CAGSlice() { llvm::DeleteContainerPointers(allNodes); }
+
+/// Returns the anchor for operand `operandIdx` of `op`, creating and
+/// registering it on first request. Repeated calls with the same
+/// (op, operandIdx) return the same anchor.
+CAGOperandAnchor *CAGSlice::getOperandAnchor(Operation *op,
+                                             unsigned operandIdx) {
+  assert(operandIdx < op->getNumOperands() && "illegal operand index");
+
+  // Dedup.
+  auto key = std::make_pair(op, operandIdx);
+  auto foundIt = operandAnchors.find(key);
+  if (foundIt != operandAnchors.end()) {
+    return foundIt->second;
+  }
+
+  // Create. The raw pointer is stored in allNodes, which the destructor
+  // deletes; the node id is the node's index in allNodes.
+  auto anchor = llvm::make_unique<CAGOperandAnchor>(op, operandIdx);
+  auto *unowned = anchor.release();
+  unowned->nodeId = allNodes.size();
+  allNodes.push_back(unowned);
+  operandAnchors.insert(std::make_pair(key, unowned));
+  return unowned;
+}
+
+/// Returns the anchor for result `resultIdx` of `op`, creating and
+/// registering it on first request. Repeated calls with the same
+/// (op, resultIdx) return the same anchor.
+CAGResultAnchor *CAGSlice::getResultAnchor(Operation *op, unsigned resultIdx) {
+  assert(resultIdx < op->getNumResults() && "illegal result index");
+
+  // Dedup.
+  auto key = std::make_pair(op, resultIdx);
+  auto foundIt = resultAnchors.find(key);
+  if (foundIt != resultAnchors.end()) {
+    return foundIt->second;
+  }
+
+  // Create. The raw pointer is stored in allNodes, which the destructor
+  // deletes; the node id is the node's index in allNodes.
+  auto anchor = llvm::make_unique<CAGResultAnchor>(op, resultIdx);
+  auto *unowned = anchor.release();
+  unowned->nodeId = allNodes.size();
+  allNodes.push_back(unowned);
+  resultAnchors.insert(std::make_pair(key, unowned));
+  return unowned;
+}
+
+/// Invokes `callback` once for each (result anchor, operand anchor) pair in
+/// which the result's value is used as that operand and both anchors already
+/// exist in this slice.
+// NOTE(review): the template arguments of std::function and std::vector were
+// lost when the patch text was extracted. The types below are inferred from
+// how the callback and the pairs are used -- confirm against the declaration
+// in ConstraintAnalysisGraph.h.
+void CAGSlice::enumerateImpliedConnections(
+    std::function<void(CAGAnchorNode *from, CAGAnchorNode *to)> callback) {
+  // Discover peer identity pairs (i.e. implied edges from Result->Operand and
+  // Arg->Call). Use an intermediate vector so that the callback can modify.
+  std::vector<std::pair<CAGAnchorNode *, CAGAnchorNode *>> impliedPairs;
+  for (auto &resultAnchorPair : resultAnchors) {
+    CAGResultAnchor *resultAnchor = resultAnchorPair.second;
+    Value *resultValue = resultAnchor->getValue();
+    for (auto &use : resultValue->getUses()) {
+      Operation *operandOp = use.getOwner();
+      unsigned operandIdx = use.getOperandNumber();
+      auto foundIt = operandAnchors.find(std::make_pair(operandOp, operandIdx));
+      if (foundIt != operandAnchors.end()) {
+        impliedPairs.push_back(std::make_pair(resultAnchor, foundIt->second));
+      }
+    }
+  }
+
+  // Callback for each pair.
+  for (auto &impliedPair : impliedPairs) {
+    callback(impliedPair.first, impliedPair.second);
+  }
+}
+
+/// Runs one propagation pass: every node that is dirty on entry has its dirty
+/// flag cleared and its propagate() called. Returns the number of nodes
+/// processed, which is 0 when no node was dirty. The dirty set is collected
+/// before any node is processed, so a node that only becomes dirty during
+/// this pass is left for a later call.
+unsigned CAGSlice::propagate(const TargetConfiguration &config) {
+  std::vector<CAGNode *> dirtyNodes;
+  dirtyNodes.reserve(allNodes.size());
+  // Note that because iteration happens in nodeId order, there is no need
+  // to sort in order to make deterministic. If the selection method changes,
+  // a sort should be explicitly done.
+  for (CAGNode *child : *this) {
+    if (child->isDirty()) {
+      dirtyNodes.push_back(child);
+    }
+  }
+
+  if (dirtyNodes.empty()) {
+    return 0;
+  }
+  for (auto dirtyNode : dirtyNodes) {
+    dirtyNode->clearDirty();
+    dirtyNode->propagate(context, config);
+  }
+
+  return dirtyNodes.size();
+}
+
+/// Marks every node reached by iterating this anchor as dirty. Neither
+/// `solverContext` nor `config` is used here.
+void CAGAnchorNode::propagate(SolverContext &solverContext,
+                              const TargetConfiguration &config) {
+  for (CAGNode *child : *this) {
+    child->markDirty();
+  }
+}
+
+/// Returns the result of applying the selected type's castFromExpressedType()
+/// to this anchor's original type, or a null Type when no type has been
+/// selected in the uniform metadata.
+Type CAGAnchorNode::getTransformedType() {
+  if (!getUniformMetadata().selectedType) {
+    return nullptr;
+  }
+  return getUniformMetadata().selectedType.castFromExpressedType(
+      getOriginalType());
+}
+
+/// Prints "Node<ADDR>", where ADDR is this node's address.
+void CAGNode::printLabel(llvm::raw_ostream &os) const {
+  os << "Node<" << static_cast<const void *>(this) << ">";
+}
+
+/// Prints the summary of this anchor's uniform metadata.
+void CAGAnchorNode::printLabel(llvm::raw_ostream &os) const {
+  getUniformMetadata().printSummary(os);
+}
+
+/// Prints "Operand<OPNAME,IDX>" followed by the anchor's metadata summary.
+void CAGOperandAnchor::printLabel(llvm::raw_ostream &os) const {
+  os << "Operand<";
+  op->getName().print(os);
+  os << "," << operandIdx;
+  os << ">";
+  CAGAnchorNode::printLabel(os);
+}
+
+/// Prints "Result<OPNAME>" followed by the anchor's metadata summary.
+void CAGResultAnchor::printLabel(llvm::raw_ostream &os) const {
+  os << "Result<";
+  getOp()->getName().print(os);
+  os << ">";
+  CAGAnchorNode::printLabel(os);
+}
diff --git a/mlir/lib/Quantizer/Support/Metadata.cpp b/mlir/lib/Quantizer/Support/Metadata.cpp
new file mode 100644
index 0000000..3661f52
--- /dev/null
+++ b/mlir/lib/Quantizer/Support/Metadata.cpp
@@ -0,0 +1,42 @@
+//===- Metadata.cpp - Top level types and metadata ------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantizer/Support/Metadata.h"
+
+#include "mlir/IR/MLIRContext.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+
+/// Writes a multi-line summary of this metadata to `os`. Each section starts
+/// with a newline and is emitted only when it has content:
+///   - "[first,second]" when `requiredRange` holds a value,
+///   - "![...]" listing the set bits of `disabledCandidateTypes`, comma
+///     separated, when any bit is set,
+///   - `selectedType` when it is non-null.
+/// Nothing is written when all three are empty.
+void CAGUniformMetadata::printSummary(llvm::raw_ostream &os) const {
+  if (requiredRange.hasValue()) {
+    os << "\n[" << requiredRange.getValue().first << ","
+       << requiredRange.getValue().second << "]";
+  }
+
+  if (disabledCandidateTypes.any()) {
+    os << "\n![";
+    mlir::interleaveComma(disabledCandidateTypes.set_bits(), os);
+    os << "]";
+  }
+
+  if (selectedType) {
+    os << "\n" << selectedType;
+  }
+}
diff --git a/mlir/lib/Quantizer/Support/TypeUtils.cpp b/mlir/lib/Quantizer/Support/TypeUtils.cpp
new file mode 100644
index 0000000..444322e
--- /dev/null
+++ b/mlir/lib/Quantizer/Support/TypeUtils.cpp
@@ -0,0 +1,31 @@
+//===- TypeUtils.cpp - Helper function for manipulating types -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Quantizer/Support/TypeUtils.h"
+
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+
+/// Returns the element type of `t` when `t` is a container type with an
+/// element type; otherwise returns `t` unchanged.
+// NOTE(review): the dyn_cast template argument was lost when the patch text
+// was extracted. VectorOrTensorType is inferred from the `vtType` variable
+// name and the StandardTypes.h include -- confirm against the upstream commit.
+Type mlir::quantizer::getElementOrPrimitiveType(Type t) {
+  if (auto vtType = t.dyn_cast<VectorOrTensorType>())
+    return vtType.getElementType();
+  return t;
+}
-- 
2.7.4