From 5065839da7d993419ecb297ef9dc43b87b70288b Mon Sep 17 00:00:00 2001
From: Stella Laurenzo
Date: Mon, 20 May 2019 18:27:38 -0700
Subject: [PATCH] Upstream the Quantizer tool (part 4).

This adds the basic passes needed and ties them into mlir-opt. Also adds
two specific unit tests that exercise them.

Next step is a standalone quantizer tool and additional cleanup.

Tested: ninja check-mlir

--
PiperOrigin-RevId: 249167690
---
 .../Support/ConstraintAnalysisGraphTraits.h        |   4 +-
 mlir/include/mlir/Quantizer/Support/Statistics.h   |   5 +
 .../mlir/Quantizer/Support/UniformConstraints.h    |  10 +-
 mlir/include/mlir/Quantizer/Transforms/Passes.h    |  51 ++++
 mlir/lib/CMakeLists.txt                            |   1 +
 mlir/lib/Quantizer/CMakeLists.txt                  |  43 +++
 mlir/lib/Quantizer/Support/Statistics.cpp          |   2 +-
 .../Transforms/AddDefaultStatsTestPass.cpp         | 128 +++++++++
 .../Transforms/InferQuantizedTypesPass.cpp         | 303 +++++++++++++++++++++
 .../Transforms/RemoveInstrumentationPass.cpp       |  79 ++++++
 mlir/test/Quantizer/matmul.mlir                    |  51 ++++
 mlir/test/Quantizer/remove-instrumentation.mlir    |  15 +
 mlir/tools/mlir-opt/CMakeLists.txt                 |   1 +
 13 files changed, 685 insertions(+), 8 deletions(-)
 create mode 100644 mlir/include/mlir/Quantizer/Transforms/Passes.h
 create mode 100644 mlir/lib/Quantizer/CMakeLists.txt
 create mode 100644 mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
 create mode 100644 mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
 create mode 100644 mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
 create mode 100644 mlir/test/Quantizer/matmul.mlir
 create mode 100644 mlir/test/Quantizer/remove-instrumentation.mlir

diff --git a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h
index 7a907b0..7e2b61d 100644
--- a/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h
+++ b/mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h
@@ -28,7 +28,7 @@
 namespace llvm {

 template <>
-struct llvm::GraphTraits<const mlir::quantizer::CAGNode *> {
+struct GraphTraits<const mlir::quantizer::CAGNode *> {
   using NodeRef = const mlir::quantizer::CAGNode *;

   static NodeRef getEntryNode(NodeRef node) { return node; }
@@ -40,7 +40,7 @@
 };

 template <>
-struct llvm::GraphTraits<const mlir::quantizer::CAGSlice *>
+struct GraphTraits<const mlir::quantizer::CAGSlice *>
     : public llvm::GraphTraits<const mlir::quantizer::CAGNode *> {
   using nodes_iterator = mlir::quantizer::CAGSlice::const_iterator;
   static mlir::quantizer::CAGSlice::const_iterator
diff --git a/mlir/include/mlir/Quantizer/Support/Statistics.h b/mlir/include/mlir/Quantizer/Support/Statistics.h
index 1c26ad0..d4641d6 100644
--- a/mlir/include/mlir/Quantizer/Support/Statistics.h
+++ b/mlir/include/mlir/Quantizer/Support/Statistics.h
@@ -36,6 +36,11 @@ struct TensorAxisStatistics {
   double mean = 0;
   double variance = 0;

+  TensorAxisStatistics() {}
+  TensorAxisStatistics(int64_t sampleSize, double minValue, double maxValue,
+                       double mean, double variance)
+      : sampleSize(sampleSize), minValue(minValue), maxValue(maxValue),
+        mean(mean), variance(variance) {}
   void clear() { *this = TensorAxisStatistics(); }
 };
diff --git a/mlir/include/mlir/Quantizer/Support/UniformConstraints.h b/mlir/include/mlir/Quantizer/Support/UniformConstraints.h
index beae3ed..90b5fe1 100644
--- a/mlir/include/mlir/Quantizer/Support/UniformConstraints.h
+++ b/mlir/include/mlir/Quantizer/Support/UniformConstraints.h
@@ -34,7 +34,7 @@ class CAGSlice;

 /// Factory methods for adding CAG constraints of various kinds suitable
 /// for solving for uniform quantization.
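+/// For example (illustrative usage only; `slice`, `anchorA`, and `anchorB`
+/// stand for a CAGSlice and two anchor nodes already produced by analysis):
+///   UniformConstraintsBuilder(slice).coupleAnchors(anchorA, anchorB);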
 class UniformConstraintsBuilder {
- public:
+public:
   UniformConstraintsBuilder(CAGSlice &slice) : slice(slice) {}

   /// Adds a coupling constraint between two nodes, effectively treating
@@ -59,11 +59,11 @@ class UniformConstraintsBuilder {
   /// as the originating node.
   void propagateExplicitScale(CAGAnchorNode *from, CAGAnchorNode *to);

- private:
+private:
   CAGSlice &slice;
 };

-}  // namespace quantizer
-}  // namespace mlir
+} // namespace quantizer
+} // namespace mlir

-#endif  // MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
+#endif // MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
diff --git a/mlir/include/mlir/Quantizer/Transforms/Passes.h b/mlir/include/mlir/Quantizer/Transforms/Passes.h
new file mode 100644
index 0000000..0d7b4cb
--- /dev/null
+++ b/mlir/include/mlir/Quantizer/Transforms/Passes.h
@@ -0,0 +1,51 @@
+//===- Passes.h - Quantizer passes -----------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines entry points to create passes to perform various kinds
+// of quantization related transforms.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_TRANSFORMS_PASSES_H
+#define MLIR_QUANTIZER_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace quantizer {
+
+class SolverContext;
+class TargetConfiguration;
+
+/// Creates a pass that infers quantized types based on metadata discovered
+/// in the computation.
+ModulePassBase *
+createInferQuantizedTypesPass(SolverContext &solverContext,
+                              const TargetConfiguration &config);
+
+/// Creates a pass which removes any instrumentation and hint ops which have
+/// no effect on final runtime.
+FunctionPassBase *createRemoveInstrumentationPass();
+
+/// Adds default (dummy) statistics to ops that can benefit from runtime stats.
+/// Meant for testing.
+FunctionPassBase *createAddDefaultStatsPass();
+
+} // namespace quantizer
+} // namespace mlir
+
+#endif // MLIR_QUANTIZER_TRANSFORMS_PASSES_H
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index 659cba9..995c886 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(LLVMIR)
 add_subdirectory(Linalg)
 add_subdirectory(Parser)
 add_subdirectory(Pass)
+add_subdirectory(Quantizer)
 add_subdirectory(StandardOps)
 add_subdirectory(Support)
 add_subdirectory(TableGen)
diff --git a/mlir/lib/Quantizer/CMakeLists.txt b/mlir/lib/Quantizer/CMakeLists.txt
new file mode 100644
index 0000000..4071263
--- /dev/null
+++ b/mlir/lib/Quantizer/CMakeLists.txt
@@ -0,0 +1,43 @@
+# Support.
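+# (Core CAG data structures, solvers, statistics, and type utilities shared
+# by the configuration and transform libraries below.)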
+add_llvm_library(MLIRQuantizerSupport
+  Support/Configuration.cpp
+  Support/ConstraintAnalysisGraph.cpp
+  Support/Metadata.cpp
+  Support/Statistics.cpp
+  Support/TypeUtils.cpp
+  Support/UniformConstraints.cpp
+  Support/UniformSolvers.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  )
+add_dependencies(MLIRQuantizerSupport
+  MLIRIR
+  MLIRQuantOps
+  MLIRSupport
+  MLIRStandardOps)
+
+# Configurations.
+add_llvm_library(MLIRQuantizerFxpMathConfig
+  Configurations/FxpMathConfig.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  )
+add_dependencies(MLIRQuantizerFxpMathConfig
+  MLIRQuantizerSupport)
+
+# Transforms.
+add_llvm_library(MLIRQuantizerTransforms
+  Transforms/AddDefaultStatsTestPass.cpp
+  Transforms/InferQuantizedTypesPass.cpp
+  Transforms/RemoveInstrumentationPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  )
+add_dependencies(MLIRQuantizerTransforms
+  MLIRQuantizerFxpMathConfig
+  MLIRQuantizerSupport
+  MLIRPass)
+target_link_libraries(MLIRQuantizerTransforms
+  MLIRQuantizerFxpMathConfig
+  MLIRQuantizerSupport
+  MLIRPass)
diff --git a/mlir/lib/Quantizer/Support/Statistics.cpp b/mlir/lib/Quantizer/Support/Statistics.cpp
index 3ec07e4..058d31f 100644
--- a/mlir/lib/Quantizer/Support/Statistics.cpp
+++ b/mlir/lib/Quantizer/Support/Statistics.cpp
@@ -88,7 +88,7 @@ static bool getElementsStatistics(ElementsAttr attr,
 bool AttributeTensorStatistics::get(TensorAxisStatistics &stats) const {
   if (FloatAttr floatAttr = attr.dyn_cast<FloatAttr>()) {
     double value = floatAttr.getValueAsDouble();
-    stats = TensorAxisStatistics{1, value, value, value, 0};
+    stats = TensorAxisStatistics(1, value, value, value, 0);
     return true;
   } else if (auto eltAttr = attr.dyn_cast<ElementsAttr>()) {
     return getElementsStatistics(eltAttr, stats);
diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
new file mode 100644
index 0000000..75c082f
--- /dev/null
+++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
@@ -0,0 +1,128 @@
+//===- AddDefaultStatsTestPass.cpp - Testing pass to add default stats ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a testing pass to add default statistics nodes to every
+// quantization eligible op. Useful for unit testing.
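+//
+// As a rough illustration (not verbatim pass output), each eligible value is
+// rerouted through a stats op carrying dummy [-1.0, 1.0] layer statistics:
+//   %stats = "quant.stats"(%arg0)
+//       {layerStats: dense<tensor<2xf32>, [-1.0, 1.0]>}
+//       : (tensor<300x3xf32>) -> tensor<300x3xf32>
+// with all prior uses of %arg0 updated to use %stats.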
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace {
+
+class AddDefaultStatsPass : public FunctionPass<AddDefaultStatsPass> {
+public:
+  AddDefaultStatsPass() = default;
+  AddDefaultStatsPass(SolverContext &solverContext,
+                      const TargetConfiguration &config)
+      : explicitSolverContext(&solverContext), explicitConfig(&config) {}
+
+  void runOnFunction() override;
+  void runWithConfig(SolverContext &solverContext,
+                     const TargetConfiguration &config);
+
+private:
+  SolverContext *explicitSolverContext = nullptr;
+  const TargetConfiguration *explicitConfig = nullptr;
+};
+
+} // end anonymous namespace
+
+void AddDefaultStatsPass::runOnFunction() {
+  if (explicitSolverContext && explicitConfig) {
+    // If explicitly constructed with a config and context.
+    runWithConfig(*explicitSolverContext, *explicitConfig);
+    return;
+  }
+  // For global pass registration, use defaults.
+  SolverContext solverContext(*getFunction().getContext());
+  auto config = FxpMathTargetConfig::create(solverContext);
+  runWithConfig(solverContext, *config);
+}
+
+void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
+                                        const TargetConfiguration &config) {
+  auto &func = getFunction();
+
+  // Insert stats for each argument.
+  for (auto *arg : func.getArguments()) {
+    if (!config.isHandledType(arg->getType()))
+      continue;
+    FuncBuilder b(func);
+    APFloat minValue(-1.0f);
+    APFloat maxValue(1.0f);
+    ElementsAttr layerStats = DenseFPElementsAttr::get(
+        b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
+    auto statsOp =
+        b.create<StatisticsOp>(func.getLoc(), arg, layerStats, nullptr);
+    arg->replaceAllUsesWith(statsOp);
+
+    // StatsOp contained a use of 'arg' so make sure to reset it after
+    // replacing all of the uses of 'arg'.
+    statsOp.getOperation()->replaceUsesOfWith(statsOp, arg);
+  }
+
+  // Walk the ops and insert stats.
+  func.walk([&](Operation *op) {
+    if (!config.isRequireStatsOp(op)) {
+      return;
+    }
+    assert(op->getNumResults() == 1);
+
+    auto originalResult = op->getResult(0);
+    if (!config.isHandledType(originalResult->getType()))
+      return;
+
+    FuncBuilder b(op->getBlock(), ++op->getIterator());
+
+    APFloat minValue(-1.0f);
+    APFloat maxValue(1.0f);
+    ElementsAttr layerStats = DenseFPElementsAttr::get(
+        b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
+    auto statsOp = b.create<StatisticsOp>(op->getLoc(), op->getResult(0),
+                                          layerStats, nullptr);
+    originalResult->replaceAllUsesWith(statsOp);
+
+    // StatsOp contained a use of 'op' so make sure to reset it after
+    // replacing all of the uses of 'op'.
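+    // (replaceAllUsesWith above also rewired the stats op's own operand to
+    // point at itself; the next statement restores that operand to the
+    // original value.)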
+    statsOp.getOperation()->replaceUsesOfWith(statsOp, originalResult);
+  });
+}
+
+FunctionPassBase *mlir::quantizer::createAddDefaultStatsPass() {
+  return new AddDefaultStatsPass();
+}
+
+static PassRegistration<AddDefaultStatsPass> pass(
+    "quantizer-add-default-stats-test",
+    "Adds default (dummy) statistics to all ops that can benefit from "
+    "runtime statistics. This is meant to help in early stage bootstrapping.");
diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
new file mode 100644
index 0000000..a2cfe72
--- /dev/null
+++ b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
@@ -0,0 +1,303 @@
+//===- InferQuantizedTypesPass.cpp - Infers quantized types ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the primary pass for instantiating a CAG, running it to
+// convergence on a module to determine eligible quantized type transforms, and
+// applying those transforms to the IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/DOTGraphTraits.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace llvm {
+
+template <>
+struct DOTGraphTraits<const CAGSlice *>
+    : public DOTGraphTraits<const CAGNode *> {
+  DOTGraphTraits(bool isSimple = false)
+      : DOTGraphTraits<const CAGNode *>(isSimple) {}
+
+  std::string getNodeLabel(const CAGNode *node, const CAGSlice *graph) {
+    std::string s;
+    llvm::raw_string_ostream out(s);
+    node->printLabel(out);
+    return out.str();
+  }
+
+  static std::string getGraphProperties(const CAGSlice *) {
+    return "rankdir=LR;";
+  }
+
+  static bool isNodeHidden(const CAGNode *node) {
+    // Filter constraint nodes with no incoming or outgoing connections.
+    // These orphans are often created as part of graph merging operations.
+    return llvm::isa<CAGConstraintNode>(node) && node->isOrphan();
+  }
+
+  std::string getNodeAttributes(const CAGNode *node, const CAGSlice *graph) {
+    switch (node->getKind()) {
+    default:
+      return std::string();
+    case CAGNode::Kind::OperandAnchor:
+      return "shape=record,color=yellow,style=filled";
+    case CAGNode::Kind::ResultAnchor:
+      return "shape=record,color=lightblue,style=filled";
+    case CAGNode::Kind::Constraint:
+      return "shape=record,style=dotted";
+    }
+  }
+};
+
+} // end namespace llvm
+
+namespace {
+
+class InferQuantizedTypesPass : public ModulePass<InferQuantizedTypesPass> {
+public:
+  InferQuantizedTypesPass() = default;
+  InferQuantizedTypesPass(SolverContext &solverContext,
+                          const TargetConfiguration &config)
+      : explicitSolverContext(&solverContext), explicitConfig(&config) {}
+  void runOnModule() override;
+  void runWithConfig(SolverContext &solverContext,
+                     const TargetConfiguration &config);
+
+  void transformOperandType(CAGOperandAnchor *anchor, Type newType);
+  void transformResultType(CAGResultAnchor *anchor, Type newType);
+
+private:
+  SolverContext *explicitSolverContext = nullptr;
+  const TargetConfiguration *explicitConfig = nullptr;
+};
+
+} // end anonymous namespace
+
+/// Maximum number of propagation rounds to run to converge the CAG before
+/// signalling an error.
+static const int kMaximumPropagationRounds = 1000;
+
+static LogicalResult validateTypeConversion(Type newType, Type origType,
+                                            Operation *op) {
+  if (!newType) {
+    return op->emitOpError() << "unsupported type conversion from "
+                             << origType;
+  }
+  return success();
+}
+
+void InferQuantizedTypesPass::runOnModule() {
+  if (explicitSolverContext && explicitConfig) {
+    // If explicitly constructed with a config and context.
+    runWithConfig(*explicitSolverContext, *explicitConfig);
+    return;
+  }
+
+  // For global pass registration, use defaults.
+  SolverContext solverContext(*getModule().getContext());
+  auto config = FxpMathTargetConfig::create(solverContext);
+  runWithConfig(solverContext, *config);
+}
+
+void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
+                                            const TargetConfiguration &config) {
+  CAGSlice cag(solverContext);
+  for (auto &f : getModule()) {
+    f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
+  }
+  config.finalizeAnchors(cag);
+
+  // Propagate.
+  int propRound;
+  for (propRound = kMaximumPropagationRounds; propRound > 0; --propRound) {
+    auto propCount = cag.propagate(config);
+    if (propCount == 0)
+      break;
+  }
+  if (propRound == 0) {
+    getContext().emitError(
+        UnknownLoc::get(&getContext()),
+        "exceeded maximum number of solver iterations (infinite loop?)");
+    return;
+  }
+
+  // TODO: Only dump the GraphViz if a flag is set and move to a utility.
+  // GraphViz.
+  if (!solverContext.getDebugCAGDotPath().empty()) {
+    auto actFileName =
+        llvm::WriteGraph(const_cast<CAGSlice *>(&cag), "CAG",
+                         /*ShortNames=*/false,
+                         /*Title=*/"CAG",
+                         /*Filename=*/solverContext.getDebugCAGDotPath());
+    llvm::errs() << "Wrote graphviz file: " << actFileName << "\n";
+  }
+
+  // Start transforming the types in order of anchor type (results, then
+  // operands).
+  // Apply result types.
+  for (auto *node : cag) {
+    auto anchorNode = llvm::dyn_cast<CAGResultAnchor>(node);
+    if (!anchorNode)
+      continue;
+    if (Type newType = anchorNode->getTransformedType())
+      transformResultType(anchorNode, newType);
+  }
+
+  // Apply operand types.
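+  // (Operand anchors are processed second so that transformOperandType can
+  // commonly reuse the dcast ops materialized by the result transforms above
+  // rather than synthesizing new qcasts.)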
+  for (auto *node : cag) {
+    auto anchorNode = llvm::dyn_cast<CAGOperandAnchor>(node);
+    if (!anchorNode)
+      continue;
+    if (Type newType = anchorNode->getTransformedType())
+      transformOperandType(anchorNode, newType);
+  }
+}
+
+void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor,
+                                                   Type newType) {
+  Value *inputValue = anchor->getValue();
+  Operation *op = anchor->getOp();
+  FuncBuilder b(op->getBlock(), Block::iterator(op));
+
+  SmallVector<Value *, 4> removeValuesIfDead;
+
+  // Because we've already run the result transforms at this phase, it is
+  // very likely that inputValue points to a dcast op whose input matches
+  // our type. We detect that situation and route around just to save some
+  // bulk in the IR.
+  Value *newTypedInputValue = inputValue;
+  auto inputDcastOp =
+      dyn_cast_or_null<DequantizeCastOp>(inputValue->getDefiningOp());
+  if (inputDcastOp && inputDcastOp.arg()->getType() == newType) {
+    // Can just use the dcast's input value.
+    newTypedInputValue = inputDcastOp.arg();
+    removeValuesIfDead.push_back(inputDcastOp);
+  } else {
+    // Need to synthesize a qcast.
+    newTypedInputValue =
+        b.create<QuantizeCastOp>(op->getLoc(), newType, inputValue);
+  }
+
+  switch (anchor->getTypeTransformRule()) {
+  default:
+    op->emitOpError("unsupported type transform rule");
+    break;
+  case CAGAnchorNode::TypeTransformRule::Direct:
+    anchor->getOp()->setOperand(anchor->getOperandIdx(), newTypedInputValue);
+    break;
+
+  case CAGAnchorNode::TypeTransformRule::DirectStorage: {
+    Type storageType = QuantizedType::castToStorageType(newType);
+    if (failed(validateTypeConversion(storageType, newType, op)))
+      return;
+    anchor->getOp()->setOperand(
+        anchor->getOperandIdx(),
+        b.create<StorageCastOp>(op->getLoc(), storageType,
+                                newTypedInputValue));
+    break;
+  }
+
+  case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
+    // Leave the anchor as-is and just cast in/out after it.
+    anchor->getOp()->setOperand(
+        anchor->getOperandIdx(),
+        b.create<DequantizeCastOp>(op->getLoc(), anchor->getOriginalType(),
+                                   newTypedInputValue));
+    break;
+  }
+
+  for (Value *removeValueIfDead : removeValuesIfDead) {
+    if (removeValueIfDead->use_empty()) {
+      removeValueIfDead->getDefiningOp()->erase();
+    }
+  }
+}
+
+void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor,
+                                                  Type newType) {
+  Value *origResultValue = anchor->getValue();
+  Operation *op = origResultValue->getDefiningOp();
+  FuncBuilder b(op->getBlock(), ++Block::iterator(op));
+
+  Value *replacedResultValue = nullptr;
+  Value *newResultValue = nullptr;
+  switch (anchor->getTypeTransformRule()) {
+  default:
+    op->emitOpError("unsupported type transform rule");
+    return;
+  case CAGAnchorNode::TypeTransformRule::Direct:
+    origResultValue->setType(newType);
+    replacedResultValue = newResultValue = b.create<DequantizeCastOp>(
+        op->getLoc(), anchor->getOriginalType(), origResultValue);
+    break;
+
+  case CAGAnchorNode::TypeTransformRule::DirectStorage: {
+    Type storageType = QuantizedType::castToStorageType(newType);
+    if (failed(validateTypeConversion(storageType, newType, op)))
+      return;
+    origResultValue->setType(storageType);
+    replacedResultValue =
+        b.create<StorageCastOp>(op->getLoc(), newType, origResultValue);
+    newResultValue = b.create<DequantizeCastOp>(
+        op->getLoc(), anchor->getOriginalType(), replacedResultValue);
+    break;
+  }
+
+  case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
+    // Leave the anchor as-is and just cast in/out after it.
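+    // (For a result this synthesizes a qcast to the solved quantized type
+    // immediately followed by a dcast back to the original expressed type;
+    // uses are rewired after the switch.)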
+    replacedResultValue =
+        b.create<QuantizeCastOp>(op->getLoc(), newType, origResultValue);
+    newResultValue = b.create<DequantizeCastOp>(
+        op->getLoc(), anchor->getOriginalType(), replacedResultValue);
+    break;
+  }
+
+  if (replacedResultValue) {
+    // Transform:
+    //   origResultValue --> replacedResultValue -> newResultValue
+    //                   \-> [original uses]
+    // To:
+    //   origResultValue -> replacedResultValue ->
+    //                      newResultValue -> [original uses]
+    // Note that replacedResultValue may equal newResultValue or there may
+    // be ops between the two.
+    origResultValue->replaceAllUsesWith(newResultValue);
+    replacedResultValue->getDefiningOp()->replaceUsesOfWith(newResultValue,
+                                                            origResultValue);
+  }
+}
+
+ModulePassBase *mlir::quantizer::createInferQuantizedTypesPass(
+    SolverContext &solverContext, const TargetConfiguration &config) {
+  return new InferQuantizedTypesPass(solverContext, config);
+}
+
+static PassRegistration<InferQuantizedTypesPass>
+    pass("quantizer-infer-quantized-types",
+         "Infers quantized types for a module");
diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
new file mode 100644
index 0000000..ed3b095
--- /dev/null
+++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
@@ -0,0 +1,79 @@
+//===- RemoveInstrumentationPass.cpp - Removes instrumentation ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a pass to remove any instrumentation ops. It is often one
+// of the final steps when performing quantization and is run after any
+// decisions requiring instrumentation have been made.
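+//
+// For example (illustrative IR, not verbatim), an identity-like hint such as
+//   %1 = "quant.stats"(%0) {layerStats: dense<tensor<2xf32>, [-1.0, 1.0]>}
+//       : (tensor<4xf32>) -> tensor<4xf32>
+// is removed by replacing all uses of %1 with %0.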
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace {
+
+class RemoveInstrumentationPass
+    : public FunctionPass<RemoveInstrumentationPass> {
+  void runOnFunction() override;
+};
+
+template <typename OpTy>
+class RemoveIdentityOpRewrite : public RewritePattern {
+public:
+  RemoveIdentityOpRewrite(MLIRContext *context)
+      : RewritePattern(OpTy::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    assert(op->getNumOperands() == 1);
+    assert(op->getNumResults() == 1);
+
+    rewriter.replaceOp(op, op->getOperand(0));
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace
+
+void RemoveInstrumentationPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto &func = getFunction();
+  auto *context = &getContext();
+  patterns.push_back(
+      llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));
+  patterns.push_back(
+      llvm::make_unique<RemoveIdentityOpRewrite<StatisticsRefOp>>(context));
+  patterns.push_back(
+      llvm::make_unique<RemoveIdentityOpRewrite<CoupledRefOp>>(context));
+  applyPatternsGreedily(func, std::move(patterns));
+}
+
+FunctionPassBase *mlir::quantizer::createRemoveInstrumentationPass() {
+  return new RemoveInstrumentationPass();
+}
+
+static PassRegistration<RemoveInstrumentationPass>
+    pass("quantizer-remove-instrumentation",
+         "Removes instrumentation and hints which have no effect on final "
+         "execution");
diff --git a/mlir/test/Quantizer/matmul.mlir b/mlir/test/Quantizer/matmul.mlir
new file mode 100644
index 0000000..c207162
--- /dev/null
+++ b/mlir/test/Quantizer/matmul.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -quantizer-infer-quantized-types -quant-convert-const -quantizer-remove-instrumentation -canonicalize -split-input-file | FileCheck %s
+
+// -----
+// A matmul without fused clamp or bias.
+// CHECK-LABEL: @matmul
+// CHECK: %cst = constant dense
+// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform>
+// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform>
+// CHECK-NEXT: %2 = "fxpmath.real_matmul"(%0, %1) : (tensor<300x3x!quant.uniform>, tensor<3x5x!quant.uniform>) -> tensor<300x5x!quant.uniform>
+// CHECK-NEXT: %3 = "quant.dcast"(%2) : (tensor<300x5x!quant.uniform>) -> tensor<300x5xf32>
+func @matmul(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
+  %0 = "quant.stats"(%arg0) {layerStats: dense<tensor<2xf32>, [-6.123e+00, 3.45e+00]>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
+  %cst = constant {name: "constant.35"} dense<tensor<3x5xf32>, [[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]>
+  %1 = "fxpmath.real_matmul"(%0, %cst) : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
+  %2 = "quant.stats"(%1) {layerStats: dense<tensor<2xf32>, [-8.000000e+00, 8.000000e+00]>} : (tensor<300x5xf32>) -> tensor<300x5xf32>
+  return %2 : tensor<300x5xf32>
+}
+
+// -----
+// A matmul with fused clamp which serves as statistics for the result.
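+// (The fused clamp_min/clamp_max attributes bound the result range, playing
+// the same role as the trailing quant.stats op in @matmul above.)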
+// CHECK-LABEL: @matmul_clamp
+// CHECK: %cst = constant dense
+// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform>
+// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform>
+// CHECK-NEXT: %2 = "fxpmath.real_matmul"(%0, %1) {clamp_max: 6.100000e+00 : f64, clamp_min: -1.225000e+01 : f64} : (tensor<300x3x!quant.uniform>, tensor<3x5x!quant.uniform>) -> tensor<300x5x!quant.uniform>
+// CHECK-NEXT: %3 = "quant.dcast"(%2) : (tensor<300x5x!quant.uniform>) -> tensor<300x5xf32>
+func @matmul_clamp(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
+  %0 = "quant.stats"(%arg0) {layerStats: dense<tensor<2xf32>, [-6.123e+00, 3.45e+00]>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
+  %cst = constant {name: "constant.35"} dense<tensor<3x5xf32>, [[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]>
+  %1 = "fxpmath.real_matmul"(%0, %cst) {clamp_max: 6.10, clamp_min: -12.25} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
+  return %1 : tensor<300x5xf32>
+}
+
+// -----
+// A matmul with bias and clamp.
+// CHECK-LABEL: @matmul_add_clamp
+// CHECK: %cst = constant dense
+// CHECK-NEXT: %cst_0 = constant dense<tensor<5xi32>, [14, 28, 42, 56, 69]>
+// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform>
+// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform>
+// CHECK-NEXT: %2 = "quant.scast"(%cst_0) : (tensor<5xi32>) -> tensor<5x!quant.uniform>
+// CHECK-NEXT: %3 = "fxpmath.real_matmul_bias"(%0, %1, %2) {clamp_max: 6.100000e+00 : f64, clamp_min: -1.225000e+01 : f64} : (tensor<300x3x!quant.uniform>, tensor<3x5x!quant.uniform>, tensor<5x!quant.uniform>) -> tensor<300x5x!quant.uniform>
+// CHECK-NEXT: %4 = "quant.dcast"(%3) : (tensor<300x5x!quant.uniform>) -> tensor<300x5xf32>
+func @matmul_add_clamp(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
+  %0 = "quant.stats"(%arg0) {layerStats: dense<tensor<2xf32>, [-6.123e+00, 3.45e+00]>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
+  %cst = constant {name: "constant.35"} dense<tensor<3x5xf32>, [[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]>
+  %cst_0 = constant {name: "constant.37"} dense<tensor<5xf32>, [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]>
+  %1 = "fxpmath.real_matmul_bias"(%0, %cst, %cst_0) {clamp_max: 6.10, clamp_min: -12.25} : (tensor<300x3xf32>, tensor<3x5xf32>, tensor<5xf32>) -> tensor<300x5xf32>
+  return %1 : tensor<300x5xf32>
+}
diff --git a/mlir/test/Quantizer/remove-instrumentation.mlir b/mlir/test/Quantizer/remove-instrumentation.mlir
new file mode 100644
index 0000000..659104b
--- /dev/null
+++ b/mlir/test/Quantizer/remove-instrumentation.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -quantizer-remove-instrumentation -verify -split-input-file | FileCheck %s
+
+// -----
+// CHECK-LABEL: remove_ops
+func @remove_ops(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+  %0 = "quant.stats"(%arg0) {
+    layerStats: dense<tensor<2xf32>, [-1.0, 1.0]>
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  %1 = "quant.coupled_ref"(%0) { coupledKey: "foobar" } :
+      (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  %2 = "quant.stats_ref"(%1) { statsKey: "foobar" } :
+      (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  // CHECK: return %arg0 : tensor<8x4x3xf32>
+  return %2 : tensor<8x4x3xf32>
+}
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 0bc5d4f..636d186 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -26,6 +26,7 @@ set(LIBS
   MLIRNVVMIR
   MLIRParser
   MLIRPass
+  MLIRQuantizerTransforms
   MLIRQuantOps
   MLIRStandardOps
   MLIRTransforms
--
2.7.4