namespace llvm {
template <>
-struct llvm::GraphTraits<const mlir::quantizer::CAGNode *> {
+struct GraphTraits<const mlir::quantizer::CAGNode *> {
using NodeRef = const mlir::quantizer::CAGNode *;
static NodeRef getEntryNode(NodeRef node) { return node; }
};
template <>
-struct llvm::GraphTraits<const mlir::quantizer::CAGSlice *>
+struct GraphTraits<const mlir::quantizer::CAGSlice *>
: public llvm::GraphTraits<const mlir::quantizer::CAGNode *> {
using nodes_iterator = mlir::quantizer::CAGSlice::const_iterator;
static mlir::quantizer::CAGSlice::const_iterator
double mean = 0;
double variance = 0;
+ TensorAxisStatistics() {}
+ TensorAxisStatistics(int64_t sampleSize, double minValue, double maxValue,
+ double mean, double variance)
+ : sampleSize(sampleSize), minValue(minValue), maxValue(maxValue),
+ mean(mean), variance(variance) {}
void clear() { *this = TensorAxisStatistics(); }
};
/// Factory methods for adding CAG constraints of various kinds suitable
/// for solving for uniform quantization.
class UniformConstraintsBuilder {
- public:
+public:
UniformConstraintsBuilder(CAGSlice &slice) : slice(slice) {}
/// Adds a coupling constraint between two nodes, effectively treating
/// as the originating node.
void propagateExplicitScale(CAGAnchorNode *from, CAGAnchorNode *to);
- private:
+private:
CAGSlice &slice;
};
-} // namespace quantizer
-} // namespace mlir
+} // namespace quantizer
+} // namespace mlir
-#endif // MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
+#endif // MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
--- /dev/null
+//===- Passes.h - Quantizer passes -----------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines entry points to create passes to perform various kinds
+// of quantization related transforms.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_TRANSFORMS_PASSES_H
+#define MLIR_QUANTIZER_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace quantizer {
+
+class SolverContext;
+class TargetConfiguration;
+
+/// Creates a pass that infers quantized types based on metadata discovered
+/// in the computation.
+ModulePassBase *
+createInferQuantizedTypesPass(SolverContext &solverContext,
+ const TargetConfiguration &config);
+
+/// Creates a pass which removes any instrumentation and hint ops which have
+/// no effect on final runtime.
+FunctionPassBase *createRemoveInstrumentationPass();
+
+/// Adds default (dummy) statistics to ops that can benefit from runtime stats.
+/// Meant for testing.
+FunctionPassBase *createAddDefaultStatsPass();
+
+} // namespace quantizer
+} // namespace mlir
+
+#endif // MLIR_QUANTIZER_TRANSFORMS_PASSES_H
add_subdirectory(Linalg)
add_subdirectory(Parser)
add_subdirectory(Pass)
+add_subdirectory(Quantizer)
add_subdirectory(StandardOps)
add_subdirectory(Support)
add_subdirectory(TableGen)
--- /dev/null
+# Support.
+add_llvm_library(MLIRQuantizerSupport
+ Support/Configuration.cpp
+ Support/ConstraintAnalysisGraph.cpp
+ Support/Metadata.cpp
+ Support/Statistics.cpp
+ Support/TypeUtils.cpp
+ Support/UniformConstraints.cpp
+ Support/UniformSolvers.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ )
+add_dependencies(MLIRQuantizerSupport
+ MLIRIR
+ MLIRQuantOps
+ MLIRSupport
+ MLIRStandardOps)
+
+# Configurations.
+add_llvm_library(MLIRQuantizerFxpMathConfig
+ Configurations/FxpMathConfig.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ )
+add_dependencies(MLIRQuantizerFxpMathConfig
+ MLIRQuantizerSupport)
+
+# Transforms.
+add_llvm_library(MLIRQuantizerTransforms
+ Transforms/AddDefaultStatsTestPass.cpp
+ Transforms/InferQuantizedTypesPass.cpp
+ Transforms/RemoveInstrumentationPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ )
+add_dependencies(MLIRQuantizerTransforms
+ MLIRQuantizerFxpMathConfig
+ MLIRQuantizerSupport
+ MLIRPass)
+target_link_libraries(MLIRQuantizerTransforms
+ MLIRQuantizerFxpMathConfig
+ MLIRQuantizerSupport
+ MLIRPass)
bool AttributeTensorStatistics::get(TensorAxisStatistics &stats) const {
if (FloatAttr floatAttr = attr.dyn_cast<FloatAttr>()) {
double value = floatAttr.getValueAsDouble();
- stats = TensorAxisStatistics{1, value, value, value, 0};
+ stats = TensorAxisStatistics(1, value, value, value, 0);
return true;
} else if (auto eltAttr = attr.dyn_cast<ElementsAttr>()) {
return getElementsStatistics(eltAttr, stats);
--- /dev/null
+//===- AddDefaultStatsTestPass.cpp - Testing pass to add default stats ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a testing pass to add default statistics nodes to every
+// quantization eligible op. Useful for unit testing.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace {
+
+class AddDefaultStatsPass : public FunctionPass<AddDefaultStatsPass> {
+public:
+ AddDefaultStatsPass() = default;
+ AddDefaultStatsPass(SolverContext &solverContext,
+ const TargetConfiguration &config)
+ : explicitSolverContext(&solverContext), explicitConfig(&config) {}
+
+ void runOnFunction() override;
+ void runWithConfig(SolverContext &solverContext,
+ const TargetConfiguration &config);
+
+private:
+ SolverContext *explicitSolverContext = nullptr;
+ const TargetConfiguration *explicitConfig = nullptr;
+};
+
+} // end anonymous namespace
+
+void AddDefaultStatsPass::runOnFunction() {
+ if (explicitSolverContext && explicitConfig) {
+ // If explicitly constructed with a config and context.
+ runWithConfig(*explicitSolverContext, *explicitConfig);
+ return;
+ }
+ // For global pass registration, use defaults.
+ SolverContext solverContext(*getFunction().getContext());
+ auto config = FxpMathTargetConfig::create(solverContext);
+ runWithConfig(solverContext, *config);
+}
+
+void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
+ const TargetConfiguration &config) {
+ auto &func = getFunction();
+
+ // Insert stats for each argument.
+ for (auto *arg : func.getArguments()) {
+ if (!config.isHandledType(arg->getType()))
+ continue;
+ FuncBuilder b(func);
+ APFloat minValue(-1.0f);
+ APFloat maxValue(1.0f);
+ ElementsAttr layerStats = DenseFPElementsAttr::get(
+ b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
+ auto statsOp =
+ b.create<StatisticsOp>(func.getLoc(), arg, layerStats, nullptr);
+ arg->replaceAllUsesWith(statsOp);
+
+ // StatsOp contained a use to 'arg' so make sure to reset it after replacing
+ // all of the uses of 'arg'.
+ statsOp.getOperation()->replaceUsesOfWith(statsOp, arg);
+ }
+
+ // Walk the ops and insert stats.
+ func.walk([&](Operation *op) {
+ if (!config.isRequireStatsOp(op)) {
+ return;
+ }
+ assert(op->getNumResults() == 1);
+
+ auto originalResult = op->getResult(0);
+ if (!config.isHandledType(originalResult->getType()))
+ return;
+
+ FuncBuilder b(op->getBlock(), ++op->getIterator());
+
+ APFloat minValue(-1.0f);
+ APFloat maxValue(1.0f);
+ ElementsAttr layerStats = DenseFPElementsAttr::get(
+ b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
+ auto statsOp = b.create<StatisticsOp>(op->getLoc(), op->getResult(0),
+ layerStats, nullptr);
+ originalResult->replaceAllUsesWith(statsOp);
+
+ // StatsOp contained a use to 'op' so make sure to reset it after replacing
+ // all of the uses of 'op'.
+ statsOp.getOperation()->replaceUsesOfWith(statsOp, originalResult);
+ });
+}
+
+FunctionPassBase *mlir::quantizer::createAddDefaultStatsPass() {
+ return new AddDefaultStatsPass();
+}
+
+static PassRegistration<AddDefaultStatsPass> pass(
+ "quantizer-add-default-stats-test",
+ "Adds default (dummy) statistics to all ops that can benefit from "
+ "runtime statistics. This is meant to help in early stage bootstrapping.");
--- /dev/null
+//===- InferQuantizedTypesPass.cpp - Infers quantized types ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the primary pass for instantiating a CAG, running it to
+// convergence on a module to determine eligible quantized type transforms, and
+// applying those transforms to the IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/DOTGraphTraits.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace llvm {
+
+template <>
+struct DOTGraphTraits<const CAGSlice *>
+ : public DOTGraphTraits<const CAGNode *> {
+ DOTGraphTraits(bool isSimple = false)
+ : DOTGraphTraits<const CAGNode *>(isSimple) {}
+
+ std::string getNodeLabel(const CAGNode *node, const CAGSlice *graph) {
+ std::string s;
+ llvm::raw_string_ostream out(s);
+ node->printLabel(out);
+ return out.str();
+ }
+
+ static std::string getGraphProperties(const CAGSlice *) {
+ return "rankdir=LR;";
+ }
+
+ static bool isNodeHidden(const CAGNode *node) {
+ // Filter constraint nodes with no incoming or outgoing connections.
+ // These orphans are often created as part of graph merging operations.
+ return llvm::isa<CAGConstraintNode>(node) && node->isOrphan();
+ }
+
+ std::string getNodeAttributes(const CAGNode *node, const CAGSlice *graph) {
+ switch (node->getKind()) {
+ default:
+ return std::string();
+ case CAGNode::Kind::OperandAnchor:
+ return "shape=record,color=yellow,style=filled";
+ case CAGNode::Kind::ResultAnchor:
+ return "shape=record,color=lightblue,style=filled";
+ case CAGNode::Kind::Constraint:
+ return "shape=record,style=dotted";
+ }
+ }
+};
+
+} // end namespace llvm
+
+namespace {
+
+class InferQuantizedTypesPass : public ModulePass<InferQuantizedTypesPass> {
+public:
+ InferQuantizedTypesPass() = default;
+ InferQuantizedTypesPass(SolverContext &solverContext,
+ const TargetConfiguration &config)
+ : explicitSolverContext(&solverContext), explicitConfig(&config) {}
+ void runOnModule() override;
+ void runWithConfig(SolverContext &solverContext,
+ const TargetConfiguration &config);
+
+ void transformOperandType(CAGOperandAnchor *anchor, Type newType);
+ void transformResultType(CAGResultAnchor *anchor, Type newType);
+
+private:
+ SolverContext *explicitSolverContext = nullptr;
+ const TargetConfiguration *explicitConfig = nullptr;
+};
+
+} // end anonymous namespace
+
+/// Maximum number of propagation rounds to run to converge the CAG before
+/// signalling an error.
+static const int kMaximumPropagationRounds = 1000;
+
+static LogicalResult validateTypeConversion(Type newType, Type origType,
+ Operation *op) {
+ if (!newType) {
+ return op->emitOpError() << "unsupported type conversion from " << newType;
+ }
+ return success();
+}
+
+void InferQuantizedTypesPass::runOnModule() {
+ if (explicitSolverContext && explicitConfig) {
+ // If explicitly constructed with a config and context.
+ runWithConfig(*explicitSolverContext, *explicitConfig);
+ return;
+ }
+
+ // For global pass registration, use defaults.
+ SolverContext solverContext(*getModule().getContext());
+ auto config = FxpMathTargetConfig::create(solverContext);
+ runWithConfig(solverContext, *config);
+}
+
+void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
+ const TargetConfiguration &config) {
+ CAGSlice cag(solverContext);
+ for (auto &f : getModule()) {
+ f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
+ }
+ config.finalizeAnchors(cag);
+
+ // Propagate.
+ int propRound;
+ for (propRound = kMaximumPropagationRounds; propRound > 0; --propRound) {
+ auto propCount = cag.propagate(config);
+ if (propCount == 0)
+ break;
+ }
+ if (propRound == 0) {
+ getContext().emitError(
+ UnknownLoc::get(&getContext()),
+ "exceeded maximum number of solver iterations (infinite loop?)");
+ return;
+ }
+
+ // TODO: Only dump the GraphViz if a flag is set and move to a utility.
+ // GraphViz.
+ if (!solverContext.getDebugCAGDotPath().empty()) {
+ auto actFileName =
+ llvm::WriteGraph(const_cast<const CAGSlice *>(&cag), "CAG",
+ /*ShortNames=*/false,
+ /*Title=*/"CAG",
+ /*Filename=*/solverContext.getDebugCAGDotPath());
+ llvm::errs() << "Wrote graphviz file: " << actFileName << "\n";
+ }
+
+ // Start transforming the types in order of anchor type (results, then
+ // operands).
+ // Apply result types.
+ for (auto *node : cag) {
+ auto anchorNode = llvm::dyn_cast<CAGResultAnchor>(node);
+ if (!anchorNode)
+ continue;
+ if (Type newType = anchorNode->getTransformedType())
+ transformResultType(anchorNode, newType);
+ }
+
+ // Apply operand types.
+ for (auto *node : cag) {
+ auto anchorNode = llvm::dyn_cast<CAGOperandAnchor>(node);
+ if (!anchorNode)
+ continue;
+ if (Type newType = anchorNode->getTransformedType())
+ transformOperandType(anchorNode, newType);
+ }
+}
+
+void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor,
+ Type newType) {
+ Value *inputValue = anchor->getValue();
+ Operation *op = anchor->getOp();
+ FuncBuilder b(op->getBlock(), Block::iterator(op));
+
+ SmallVector<Value *, 1> removeValuesIfDead;
+
+ // Because we've already run the result transforms at this phase, it is
+ // very likely that inputValue points to a dcast op whose input matches
+ // our type. We detect that situation and route around just to save some
+ // bulk in the IR.
+ Value *newTypedInputValue = inputValue;
+ auto inputDcastOp =
+ dyn_cast_or_null<DequantizeCastOp>(inputValue->getDefiningOp());
+ if (inputDcastOp && inputDcastOp.arg()->getType() == newType) {
+ // Can just use the dcast's input value.
+ newTypedInputValue = inputDcastOp.arg();
+ removeValuesIfDead.push_back(inputDcastOp);
+ } else {
+ // Need to synthesize a qcast.
+ newTypedInputValue =
+ b.create<QuantizeCastOp>(op->getLoc(), newType, inputValue);
+ }
+
+ switch (anchor->getTypeTransformRule()) {
+ default:
+ op->emitOpError("unsupported type transform rule");
+ break;
+ case CAGAnchorNode::TypeTransformRule::Direct:
+ anchor->getOp()->setOperand(anchor->getOperandIdx(), newTypedInputValue);
+ break;
+
+ case CAGAnchorNode::TypeTransformRule::DirectStorage: {
+ Type storageType = QuantizedType::castToStorageType(newType);
+ if (failed(validateTypeConversion(storageType, newType, op)))
+ return;
+ anchor->getOp()->setOperand(
+ anchor->getOperandIdx(),
+ b.create<StorageCastOp>(op->getLoc(), storageType, newTypedInputValue));
+ break;
+ }
+
+ case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
+ // Leave the anchor as-is and just cast in/out after it.
+ anchor->getOp()->setOperand(
+ anchor->getOperandIdx(),
+ b.create<DequantizeCastOp>(op->getLoc(), anchor->getOriginalType(),
+ newTypedInputValue));
+ break;
+ }
+
+ for (Value *removeValueIfDead : removeValuesIfDead) {
+ if (removeValueIfDead->use_empty()) {
+ removeValueIfDead->getDefiningOp()->erase();
+ }
+ }
+}
+
+void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor,
+ Type newType) {
+ Value *origResultValue = anchor->getValue();
+ Operation *op = origResultValue->getDefiningOp();
+ FuncBuilder b(op->getBlock(), ++Block::iterator(op));
+
+ Value *replacedResultValue = nullptr;
+ Value *newResultValue = nullptr;
+ switch (anchor->getTypeTransformRule()) {
+ default:
+ op->emitOpError("unsupported type transform rule");
+ return;
+ case CAGAnchorNode::TypeTransformRule::Direct:
+ origResultValue->setType(newType);
+ replacedResultValue = newResultValue = b.create<DequantizeCastOp>(
+ op->getLoc(), anchor->getOriginalType(), origResultValue);
+ break;
+
+ case CAGAnchorNode::TypeTransformRule::DirectStorage: {
+ Type storageType = QuantizedType::castToStorageType(newType);
+ if (failed(validateTypeConversion(storageType, newType, op)))
+ return;
+ origResultValue->setType(storageType);
+ replacedResultValue =
+ b.create<StorageCastOp>(op->getLoc(), newType, origResultValue);
+ newResultValue = b.create<DequantizeCastOp>(
+ op->getLoc(), anchor->getOriginalType(), replacedResultValue);
+ break;
+ }
+
+ case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
+ // Leave the anchor as-is and just cast in/out after it.
+ replacedResultValue =
+ b.create<QuantizeCastOp>(op->getLoc(), newType, origResultValue);
+ newResultValue = b.create<DequantizeCastOp>(
+ op->getLoc(), anchor->getOriginalType(), replacedResultValue);
+ break;
+ }
+
+ if (replacedResultValue) {
+ // Transform:
+ // origResultValue --> replaceResultValue -> newResultValue
+ // \-> [original uses]
+ // To:
+ // origResultValue -> replaceResultValue ->
+ // newResultValue -> [original uses]
+ // Note that replaceResultValue may equal newResultValue or there may
+ // be operands between the two.
+ origResultValue->replaceAllUsesWith(newResultValue);
+ replacedResultValue->getDefiningOp()->replaceUsesOfWith(newResultValue,
+ origResultValue);
+ }
+}
+
+ModulePassBase *mlir::quantizer::createInferQuantizedTypesPass(
+ SolverContext &solverContext, const TargetConfiguration &config) {
+ return new InferQuantizedTypesPass(solverContext, config);
+}
+
+static PassRegistration<InferQuantizedTypesPass>
+ pass("quantizer-infer-quantized-types",
+ "Infers quantized types for a module");
--- /dev/null
+//===- RemoveInstrumentationPass.cpp - Removes instrumentation ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a pass to remove any instrumentation ops. It is often one
+// of the final steps when performing quantization and is run after any
+// decisions requiring instrumentation have been made.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace {
+
+class RemoveInstrumentationPass
+ : public FunctionPass<RemoveInstrumentationPass> {
+ void runOnFunction() override;
+};
+
+template <typename OpTy>
+class RemoveIdentityOpRewrite : public RewritePattern {
+public:
+ RemoveIdentityOpRewrite(MLIRContext *context)
+ : RewritePattern(OpTy::getOperationName(), 1, context) {}
+
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ assert(op->getNumOperands() == 1);
+ assert(op->getNumResults() == 1);
+
+ rewriter.replaceOp(op, op->getOperand(0));
+ return matchSuccess();
+ }
+};
+
+} // end anonymous namespace
+
+void RemoveInstrumentationPass::runOnFunction() {
+ OwningRewritePatternList patterns;
+ auto &func = getFunction();
+ auto *context = &getContext();
+ patterns.push_back(
+ llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));
+ patterns.push_back(
+ llvm::make_unique<RemoveIdentityOpRewrite<StatisticsRefOp>>(context));
+ patterns.push_back(
+ llvm::make_unique<RemoveIdentityOpRewrite<CoupledRefOp>>(context));
+ applyPatternsGreedily(func, std::move(patterns));
+}
+
+FunctionPassBase *mlir::quantizer::createRemoveInstrumentationPass() {
+ return new RemoveInstrumentationPass();
+}
+
+static PassRegistration<RemoveInstrumentationPass>
+ pass("quantizer-remove-instrumentation",
+ "Removes instrumentation and hints which have no effect on final "
+ "execution");
--- /dev/null
+// RUN: mlir-opt %s -quantizer-infer-quantized-types -quant-convert-const -quantizer-remove-instrumentation -canonicalize -split-input-file | FileCheck %s
+
+// ----
+// A matmul without fused clamp or bias.
+// CHECK-LABEL: @matmul
+// CHECK: %cst = constant dense<tensor<3x5xi8>
+// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
+// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
+// CHECK-NEXT: %2 = "fxpmath.real_matmul"(%0, %1) : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.0629921259842528:-1>>
+// CHECK-NEXT: %3 = "quant.dcast"(%2) : (tensor<300x5x!quant.uniform<i8:f32, 0.0629921259842528:-1>>) -> tensor<300x5xf32>
+func @matmul(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
+ %0 = "quant.stats"(%arg0) {layerStats: dense<tensor<2xf32>, [-6.123e+00, 3.45e+00]>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
+ %cst = constant {name: "constant.35"} dense<tensor<3x5xf32>, [[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]>
+ %1 = "fxpmath.real_matmul"(%0, %cst) : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
+ %2 = "quant.stats"(%1) {layerStats: dense<tensor<2xf32>, [-8.000000e+00, 8.000000e+00]>} : (tensor<300x5xf32>) -> tensor<300x5xf32>
+ return %2 : tensor<300x5xf32>
+}
+
+// ----
+// A matmul with fused clamp which serves as statistics for the result.
+// CHECK-LABEL: @matmul_clamp
+// CHECK: %cst = constant dense<tensor<3x5xi8>
+// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
+// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
+// CHECK-NEXT: %2 = "fxpmath.real_matmul"(%0, %1) {clamp_max: 6.100000e+00 : f64, clamp_min: -1.225000e+01 : f64} : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>
+// CHECK-NEXT: %3 = "quant.dcast"(%2) : (tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>) -> tensor<300x5xf32>
+func @matmul_clamp(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
+ %0 = "quant.stats"(%arg0) {layerStats: dense<tensor<2xf32>, [-6.123e+00, 3.45e+00]>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
+ %cst = constant {name: "constant.35"} dense<tensor<3x5xf32>, [[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]>
+ %1 = "fxpmath.real_matmul"(%0, %cst) {clamp_max: 6.10, clamp_min: -12.25} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
+ return %1 : tensor<300x5xf32>
+}
+
+// ----
+// A matmul with bias and clamp.
+// CHECK-LABEL: @matmul_add_clamp
+// CHECK: %cst = constant dense<tensor<3x5xi8>
+// CHECK-NEXT: %cst_0 = constant dense<tensor<5xi32>, [14, 28, 42, 56, 69]>
+// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
+// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
+// CHECK-NEXT: %2 = "quant.scast"(%cst_0) : (tensor<5xi32>) -> tensor<5x!quant.uniform<i32:f32, 0.072058823529412216>>
+// CHECK-NEXT: %3 = "fxpmath.real_matmul_bias"(%0, %1, %2) {clamp_max: 6.100000e+00 : f64, clamp_min: -1.225000e+01 : f64} : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>, tensor<5x!quant.uniform<i32:f32, 0.072058823529412216>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>
+// CHECK-NEXT: %4 = "quant.dcast"(%3) : (tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>) -> tensor<300x5xf32>
+func @matmul_add_clamp(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
+ %0 = "quant.stats"(%arg0) {layerStats: dense<tensor<2xf32>, [-6.123e+00, 3.45e+00]>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
+ %cst = constant {name: "constant.35"} dense<tensor<3x5xf32>, [[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]>
+ %cst_0 = constant {name: "constant.37"} dense<tensor<5xf32>, [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]>
+ %1 = "fxpmath.real_matmul_bias"(%0, %cst, %cst_0) {clamp_max: 6.10, clamp_min: -12.25} : (tensor<300x3xf32>, tensor<3x5xf32>, tensor<5xf32>) -> tensor<300x5xf32>
+ return %1 : tensor<300x5xf32>
+}
+
--- /dev/null
+// RUN: mlir-opt %s -quantizer-remove-instrumentation -verify -split-input-file | FileCheck %s
+
+// -----
+// CHECK-LABEL: remove_ops
+func @remove_ops(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+ %0 = "quant.stats"(%arg0) {
+ layerStats: dense<tensor<2xf32>, [-1.0, 1.0]>
+ } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+ %1 = "quant.coupled_ref"(%0) { coupledKey: "foobar" } :
+ (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+ %2 = "quant.stats_ref"(%1) { statsKey: "foobar" } :
+ (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+ // CHECK: return %arg0 : tensor<8x4x3xf32>
+ return %2 : tensor<8x4x3xf32>
+}
MLIRNVVMIR
MLIRParser
MLIRPass
+ MLIRQuantizerTransforms
MLIRQuantOps
MLIRStandardOps
MLIRTransforms