Upstream the Quantizer tool (part 4).
authorStella Laurenzo <laurenzo@google.com>
Tue, 21 May 2019 01:27:38 +0000 (18:27 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:53:12 +0000 (19:53 -0700)
    This adds the basic passes needed and ties them into mlir-opt. Also adds two specific unit tests that exercise them.

    Next step is a standalone quantizer tool and additional cleanup.
    Tested:
      ninja check-mlir

--

PiperOrigin-RevId: 249167690

13 files changed:
mlir/include/mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h
mlir/include/mlir/Quantizer/Support/Statistics.h
mlir/include/mlir/Quantizer/Support/UniformConstraints.h
mlir/include/mlir/Quantizer/Transforms/Passes.h [new file with mode: 0644]
mlir/lib/CMakeLists.txt
mlir/lib/Quantizer/CMakeLists.txt [new file with mode: 0644]
mlir/lib/Quantizer/Support/Statistics.cpp
mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp [new file with mode: 0644]
mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp [new file with mode: 0644]
mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp [new file with mode: 0644]
mlir/test/Quantizer/matmul.mlir [new file with mode: 0644]
mlir/test/Quantizer/remove-instrumentation.mlir [new file with mode: 0644]
mlir/tools/mlir-opt/CMakeLists.txt

index 7a907b0..7e2b61d 100644 (file)
@@ -28,7 +28,7 @@
 namespace llvm {
 
 template <>
-struct llvm::GraphTraits<const mlir::quantizer::CAGNode *> {
+struct GraphTraits<const mlir::quantizer::CAGNode *> {
   using NodeRef = const mlir::quantizer::CAGNode *;
 
   static NodeRef getEntryNode(NodeRef node) { return node; }
@@ -40,7 +40,7 @@ struct llvm::GraphTraits<const mlir::quantizer::CAGNode *> {
 };
 
 template <>
-struct llvm::GraphTraits<const mlir::quantizer::CAGSlice *>
+struct GraphTraits<const mlir::quantizer::CAGSlice *>
     : public llvm::GraphTraits<const mlir::quantizer::CAGNode *> {
   using nodes_iterator = mlir::quantizer::CAGSlice::const_iterator;
   static mlir::quantizer::CAGSlice::const_iterator
index 1c26ad0..d4641d6 100644 (file)
@@ -36,6 +36,11 @@ struct TensorAxisStatistics {
   double mean = 0;
   double variance = 0;
 
+  TensorAxisStatistics() {}
+  TensorAxisStatistics(int64_t sampleSize, double minValue, double maxValue,
+                       double mean, double variance)
+      : sampleSize(sampleSize), minValue(minValue), maxValue(maxValue),
+        mean(mean), variance(variance) {}
   void clear() { *this = TensorAxisStatistics(); }
 };
 
index beae3ed..90b5fe1 100644 (file)
@@ -34,7 +34,7 @@ class CAGSlice;
 /// Factory methods for adding CAG constraints of various kinds suitable
 /// for solving for uniform quantization.
 class UniformConstraintsBuilder {
- public:
+public:
   UniformConstraintsBuilder(CAGSlice &slice) : slice(slice) {}
 
   /// Adds a coupling constraint between two nodes, effectively treating
@@ -59,11 +59,11 @@ class UniformConstraintsBuilder {
   /// as the originating node.
   void propagateExplicitScale(CAGAnchorNode *from, CAGAnchorNode *to);
 
- private:
+private:
   CAGSlice &slice;
 };
 
-}  // namespace quantizer
-}  // namespace mlir
+} // namespace quantizer
+} // namespace mlir
 
-#endif  // MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
+#endif // MLIR_QUANTIZER_SUPPORT_UNIFORMCONSTRAINTS_H
diff --git a/mlir/include/mlir/Quantizer/Transforms/Passes.h b/mlir/include/mlir/Quantizer/Transforms/Passes.h
new file mode 100644 (file)
index 0000000..0d7b4cb
--- /dev/null
@@ -0,0 +1,51 @@
+//===- Passes.h - Quantizer passes  -----------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines entry points to create passes to perform various kinds
+// of quantization related transforms.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_QUANTIZER_TRANSFORMS_PASSES_H
+#define MLIR_QUANTIZER_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace quantizer {
+
+class SolverContext;
+class TargetConfiguration;
+
+/// Creates a pass that infers quantized types based on metadata discovered
+/// in the computation.
+ModulePassBase *
+createInferQuantizedTypesPass(SolverContext &solverContext,
+                              const TargetConfiguration &config);
+
+/// Creates a pass which removes any instrumentation and hint ops which have
+/// no effect on final runtime.
+FunctionPassBase *createRemoveInstrumentationPass();
+
+/// Adds default (dummy) statistics to ops that can benefit from runtime stats.
+/// Meant for testing.
+FunctionPassBase *createAddDefaultStatsPass();
+
+} // namespace quantizer
+} // namespace mlir
+
+#endif // MLIR_QUANTIZER_TRANSFORMS_PASSES_H
index 659cba9..995c886 100644 (file)
@@ -9,6 +9,7 @@ add_subdirectory(LLVMIR)
 add_subdirectory(Linalg)
 add_subdirectory(Parser)
 add_subdirectory(Pass)
+add_subdirectory(Quantizer)
 add_subdirectory(StandardOps)
 add_subdirectory(Support)
 add_subdirectory(TableGen)
diff --git a/mlir/lib/Quantizer/CMakeLists.txt b/mlir/lib/Quantizer/CMakeLists.txt
new file mode 100644 (file)
index 0000000..4071263
--- /dev/null
@@ -0,0 +1,43 @@
+# Support.
+add_llvm_library(MLIRQuantizerSupport
+  Support/Configuration.cpp
+  Support/ConstraintAnalysisGraph.cpp
+  Support/Metadata.cpp
+  Support/Statistics.cpp
+  Support/TypeUtils.cpp
+  Support/UniformConstraints.cpp
+  Support/UniformSolvers.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  )
+add_dependencies(MLIRQuantizerSupport
+                 MLIRIR
+                 MLIRQuantOps
+                 MLIRSupport
+                 MLIRStandardOps)
+
+# Configurations.
+add_llvm_library(MLIRQuantizerFxpMathConfig
+  Configurations/FxpMathConfig.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  )
+add_dependencies(MLIRQuantizerFxpMathConfig
+                 MLIRQuantizerSupport)
+
+# Transforms.
+add_llvm_library(MLIRQuantizerTransforms
+  Transforms/AddDefaultStatsTestPass.cpp
+  Transforms/InferQuantizedTypesPass.cpp
+  Transforms/RemoveInstrumentationPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  )
+add_dependencies(MLIRQuantizerTransforms
+  MLIRQuantizerFxpMathConfig
+  MLIRQuantizerSupport
+  MLIRPass)
+target_link_libraries(MLIRQuantizerTransforms
+  MLIRQuantizerFxpMathConfig
+  MLIRQuantizerSupport
+  MLIRPass)
index 3ec07e4..058d31f 100644 (file)
@@ -88,7 +88,7 @@ static bool getElementsStatistics(ElementsAttr attr,
 bool AttributeTensorStatistics::get(TensorAxisStatistics &stats) const {
   if (FloatAttr floatAttr = attr.dyn_cast<FloatAttr>()) {
     double value = floatAttr.getValueAsDouble();
-    stats = TensorAxisStatistics{1, value, value, value, 0};
+    stats = TensorAxisStatistics(1, value, value, value, 0);
     return true;
   } else if (auto eltAttr = attr.dyn_cast<ElementsAttr>()) {
     return getElementsStatistics(eltAttr, stats);
diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp
new file mode 100644 (file)
index 0000000..75c082f
--- /dev/null
@@ -0,0 +1,128 @@
+//===- AddDefaultStatsTestPass.cpp - Testing pass to add default stats ----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a testing pass to add default statistics nodes to every
+// quantization eligible op. Useful for unit testing.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace {
+
+class AddDefaultStatsPass : public FunctionPass<AddDefaultStatsPass> {
+public:
+  AddDefaultStatsPass() = default;
+  AddDefaultStatsPass(SolverContext &solverContext,
+                      const TargetConfiguration &config)
+      : explicitSolverContext(&solverContext), explicitConfig(&config) {}
+
+  void runOnFunction() override;
+  void runWithConfig(SolverContext &solverContext,
+                     const TargetConfiguration &config);
+
+private:
+  SolverContext *explicitSolverContext = nullptr;
+  const TargetConfiguration *explicitConfig = nullptr;
+};
+
+} // end anonymous namespace
+
+void AddDefaultStatsPass::runOnFunction() {
+  if (explicitSolverContext && explicitConfig) {
+    // If explicitly constructed with a config and context.
+    runWithConfig(*explicitSolverContext, *explicitConfig);
+    return;
+  }
+  // For global pass registration, use defaults.
+  SolverContext solverContext(*getFunction().getContext());
+  auto config = FxpMathTargetConfig::create(solverContext);
+  runWithConfig(solverContext, *config);
+}
+
+void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext,
+                                        const TargetConfiguration &config) {
+  auto &func = getFunction();
+
+  // Insert stats for each argument.
+  for (auto *arg : func.getArguments()) {
+    if (!config.isHandledType(arg->getType()))
+      continue;
+    FuncBuilder b(func);
+    APFloat minValue(-1.0f);
+    APFloat maxValue(1.0f);
+    ElementsAttr layerStats = DenseFPElementsAttr::get(
+        b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
+    auto statsOp =
+        b.create<StatisticsOp>(func.getLoc(), arg, layerStats, nullptr);
+    arg->replaceAllUsesWith(statsOp);
+
+    // StatsOp contained a use to 'arg' so make sure to reset it after replacing
+    // all of the uses of 'arg'.
+    statsOp.getOperation()->replaceUsesOfWith(statsOp, arg);
+  }
+
+  // Walk the ops and insert stats.
+  func.walk([&](Operation *op) {
+    if (!config.isRequireStatsOp(op)) {
+      return;
+    }
+    assert(op->getNumResults() == 1);
+
+    auto originalResult = op->getResult(0);
+    if (!config.isHandledType(originalResult->getType()))
+      return;
+
+    FuncBuilder b(op->getBlock(), ++op->getIterator());
+
+    APFloat minValue(-1.0f);
+    APFloat maxValue(1.0f);
+    ElementsAttr layerStats = DenseFPElementsAttr::get(
+        b.getTensorType({2}, b.getF32Type()), {minValue, maxValue});
+    auto statsOp = b.create<StatisticsOp>(op->getLoc(), op->getResult(0),
+                                          layerStats, nullptr);
+    originalResult->replaceAllUsesWith(statsOp);
+
+    // StatsOp contained a use to 'op' so make sure to reset it after replacing
+    // all of the uses of 'op'.
+    statsOp.getOperation()->replaceUsesOfWith(statsOp, originalResult);
+  });
+}
+
+FunctionPassBase *mlir::quantizer::createAddDefaultStatsPass() {
+  return new AddDefaultStatsPass();
+}
+
+static PassRegistration<AddDefaultStatsPass> pass(
+    "quantizer-add-default-stats-test",
+    "Adds default (dummy) statistics to all ops that can benefit from "
+    "runtime statistics. This is meant to help in early stage bootstrapping.");
diff --git a/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp b/mlir/lib/Quantizer/Transforms/InferQuantizedTypesPass.cpp
new file mode 100644 (file)
index 0000000..a2cfe72
--- /dev/null
@@ -0,0 +1,303 @@
+//===- InferQuantizedTypesPass.cpp - Infers quantized types ---------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the primary pass for instantiating a CAG, running it to
+// convergence on a module to determine eligible quantized type transforms, and
+// applying those transforms to the IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/Dialect/QuantOps/QuantTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Quantizer/Configurations/FxpMathConfig.h"
+#include "mlir/Quantizer/Support/Configuration.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h"
+#include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/Support/DOTGraphTraits.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace llvm {
+
+template <>
+struct DOTGraphTraits<const CAGSlice *>
+    : public DOTGraphTraits<const CAGNode *> {
+  DOTGraphTraits(bool isSimple = false)
+      : DOTGraphTraits<const CAGNode *>(isSimple) {}
+
+  std::string getNodeLabel(const CAGNode *node, const CAGSlice *graph) {
+    std::string s;
+    llvm::raw_string_ostream out(s);
+    node->printLabel(out);
+    return out.str();
+  }
+
+  static std::string getGraphProperties(const CAGSlice *) {
+    return "rankdir=LR;";
+  }
+
+  static bool isNodeHidden(const CAGNode *node) {
+    // Filter constraint nodes with no incoming or outgoing connections.
+    // These orphans are often created as part of graph merging operations.
+    return llvm::isa<CAGConstraintNode>(node) && node->isOrphan();
+  }
+
+  std::string getNodeAttributes(const CAGNode *node, const CAGSlice *graph) {
+    switch (node->getKind()) {
+    default:
+      return std::string();
+    case CAGNode::Kind::OperandAnchor:
+      return "shape=record,color=yellow,style=filled";
+    case CAGNode::Kind::ResultAnchor:
+      return "shape=record,color=lightblue,style=filled";
+    case CAGNode::Kind::Constraint:
+      return "shape=record,style=dotted";
+    }
+  }
+};
+
+} // end namespace llvm
+
+namespace {
+
+class InferQuantizedTypesPass : public ModulePass<InferQuantizedTypesPass> {
+public:
+  InferQuantizedTypesPass() = default;
+  InferQuantizedTypesPass(SolverContext &solverContext,
+                          const TargetConfiguration &config)
+      : explicitSolverContext(&solverContext), explicitConfig(&config) {}
+  void runOnModule() override;
+  void runWithConfig(SolverContext &solverContext,
+                     const TargetConfiguration &config);
+
+  void transformOperandType(CAGOperandAnchor *anchor, Type newType);
+  void transformResultType(CAGResultAnchor *anchor, Type newType);
+
+private:
+  SolverContext *explicitSolverContext = nullptr;
+  const TargetConfiguration *explicitConfig = nullptr;
+};
+
+} // end anonymous namespace
+
+/// Maximum number of propagation rounds to run to converge the CAG before
+/// signalling an error.
+static const int kMaximumPropagationRounds = 1000;
+
+static LogicalResult validateTypeConversion(Type newType, Type origType,
+                                            Operation *op) {
+  if (!newType) {
+    return op->emitOpError() << "unsupported type conversion from " << newType;
+  }
+  return success();
+}
+
+void InferQuantizedTypesPass::runOnModule() {
+  if (explicitSolverContext && explicitConfig) {
+    // If explicitly constructed with a config and context.
+    runWithConfig(*explicitSolverContext, *explicitConfig);
+    return;
+  }
+
+  // For global pass registration, use defaults.
+  SolverContext solverContext(*getModule().getContext());
+  auto config = FxpMathTargetConfig::create(solverContext);
+  runWithConfig(solverContext, *config);
+}
+
+void InferQuantizedTypesPass::runWithConfig(SolverContext &solverContext,
+                                            const TargetConfiguration &config) {
+  CAGSlice cag(solverContext);
+  for (auto &f : getModule()) {
+    f.walk([&cag, &config](Operation *op) { config.handleOp(op, cag); });
+  }
+  config.finalizeAnchors(cag);
+
+  // Propagate.
+  int propRound;
+  for (propRound = kMaximumPropagationRounds; propRound > 0; --propRound) {
+    auto propCount = cag.propagate(config);
+    if (propCount == 0)
+      break;
+  }
+  if (propRound == 0) {
+    getContext().emitError(
+        UnknownLoc::get(&getContext()),
+        "exceeded maximum number of solver iterations (infinite loop?)");
+    return;
+  }
+
+  // TODO: Only dump the GraphViz if a flag is set and move to a utility.
+  // GraphViz.
+  if (!solverContext.getDebugCAGDotPath().empty()) {
+    auto actFileName =
+        llvm::WriteGraph(const_cast<const CAGSlice *>(&cag), "CAG",
+                         /*ShortNames=*/false,
+                         /*Title=*/"CAG",
+                         /*Filename=*/solverContext.getDebugCAGDotPath());
+    llvm::errs() << "Wrote graphviz file: " << actFileName << "\n";
+  }
+
+  // Start transforming the types in order of anchor type (results, then
+  // operands).
+  // Apply result types.
+  for (auto *node : cag) {
+    auto anchorNode = llvm::dyn_cast<CAGResultAnchor>(node);
+    if (!anchorNode)
+      continue;
+    if (Type newType = anchorNode->getTransformedType())
+      transformResultType(anchorNode, newType);
+  }
+
+  // Apply operand types.
+  for (auto *node : cag) {
+    auto anchorNode = llvm::dyn_cast<CAGOperandAnchor>(node);
+    if (!anchorNode)
+      continue;
+    if (Type newType = anchorNode->getTransformedType())
+      transformOperandType(anchorNode, newType);
+  }
+}
+
+void InferQuantizedTypesPass::transformOperandType(CAGOperandAnchor *anchor,
+                                                   Type newType) {
+  Value *inputValue = anchor->getValue();
+  Operation *op = anchor->getOp();
+  FuncBuilder b(op->getBlock(), Block::iterator(op));
+
+  SmallVector<Value *, 1> removeValuesIfDead;
+
+  // Because we've already run the result transforms at this phase, it is
+  // very likely that inputValue points to a dcast op whose input matches
+  // our type. We detect that situation and route around just to save some
+  // bulk in the IR.
+  Value *newTypedInputValue = inputValue;
+  auto inputDcastOp =
+      dyn_cast_or_null<DequantizeCastOp>(inputValue->getDefiningOp());
+  if (inputDcastOp && inputDcastOp.arg()->getType() == newType) {
+    // Can just use the dcast's input value.
+    newTypedInputValue = inputDcastOp.arg();
+    removeValuesIfDead.push_back(inputDcastOp);
+  } else {
+    // Need to synthesize a qcast.
+    newTypedInputValue =
+        b.create<QuantizeCastOp>(op->getLoc(), newType, inputValue);
+  }
+
+  switch (anchor->getTypeTransformRule()) {
+  default:
+    op->emitOpError("unsupported type transform rule");
+    break;
+  case CAGAnchorNode::TypeTransformRule::Direct:
+    anchor->getOp()->setOperand(anchor->getOperandIdx(), newTypedInputValue);
+    break;
+
+  case CAGAnchorNode::TypeTransformRule::DirectStorage: {
+    Type storageType = QuantizedType::castToStorageType(newType);
+    if (failed(validateTypeConversion(storageType, newType, op)))
+      return;
+    anchor->getOp()->setOperand(
+        anchor->getOperandIdx(),
+        b.create<StorageCastOp>(op->getLoc(), storageType, newTypedInputValue));
+    break;
+  }
+
+  case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
+    // Leave the anchor as-is and just cast in/out after it.
+    anchor->getOp()->setOperand(
+        anchor->getOperandIdx(),
+        b.create<DequantizeCastOp>(op->getLoc(), anchor->getOriginalType(),
+                                   newTypedInputValue));
+    break;
+  }
+
+  for (Value *removeValueIfDead : removeValuesIfDead) {
+    if (removeValueIfDead->use_empty()) {
+      removeValueIfDead->getDefiningOp()->erase();
+    }
+  }
+}
+
+void InferQuantizedTypesPass::transformResultType(CAGResultAnchor *anchor,
+                                                  Type newType) {
+  Value *origResultValue = anchor->getValue();
+  Operation *op = origResultValue->getDefiningOp();
+  FuncBuilder b(op->getBlock(), ++Block::iterator(op));
+
+  Value *replacedResultValue = nullptr;
+  Value *newResultValue = nullptr;
+  switch (anchor->getTypeTransformRule()) {
+  default:
+    op->emitOpError("unsupported type transform rule");
+    return;
+  case CAGAnchorNode::TypeTransformRule::Direct:
+    origResultValue->setType(newType);
+    replacedResultValue = newResultValue = b.create<DequantizeCastOp>(
+        op->getLoc(), anchor->getOriginalType(), origResultValue);
+    break;
+
+  case CAGAnchorNode::TypeTransformRule::DirectStorage: {
+    Type storageType = QuantizedType::castToStorageType(newType);
+    if (failed(validateTypeConversion(storageType, newType, op)))
+      return;
+    origResultValue->setType(storageType);
+    replacedResultValue =
+        b.create<StorageCastOp>(op->getLoc(), newType, origResultValue);
+    newResultValue = b.create<DequantizeCastOp>(
+        op->getLoc(), anchor->getOriginalType(), replacedResultValue);
+    break;
+  }
+
+  case CAGAnchorNode::TypeTransformRule::ExpressedOnly:
+    // Leave the anchor as-is and just cast in/out after it.
+    replacedResultValue =
+        b.create<QuantizeCastOp>(op->getLoc(), newType, origResultValue);
+    newResultValue = b.create<DequantizeCastOp>(
+        op->getLoc(), anchor->getOriginalType(), replacedResultValue);
+    break;
+  }
+
+  if (replacedResultValue) {
+    // Transform:
+    //   origResultValue -->  replaceResultValue -> newResultValue
+    //                   \->  [original uses]
+    // To:
+    //   origResultValue -> replaceResultValue ->
+    //                      newResultValue -> [original uses]
+    // Note that replaceResultValue may equal newResultValue or there may
+    // be operands between the two.
+    origResultValue->replaceAllUsesWith(newResultValue);
+    replacedResultValue->getDefiningOp()->replaceUsesOfWith(newResultValue,
+                                                            origResultValue);
+  }
+}
+
+ModulePassBase *mlir::quantizer::createInferQuantizedTypesPass(
+    SolverContext &solverContext, const TargetConfiguration &config) {
+  return new InferQuantizedTypesPass(solverContext, config);
+}
+
+static PassRegistration<InferQuantizedTypesPass>
+    pass("quantizer-infer-quantized-types",
+         "Infers quantized types for a module");
diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp
new file mode 100644 (file)
index 0000000..ed3b095
--- /dev/null
@@ -0,0 +1,79 @@
+//===- RemoveInstrumentationPass.cpp - Removes instrumentation ------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a pass to remove any instrumentation ops. It is often one
+// of the final steps when performing quantization and is run after any
+// decisions requiring instrumentation have been made.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/QuantOps/QuantOps.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Quantizer/Transforms/Passes.h"
+
+using namespace mlir;
+using namespace mlir::quantizer;
+using namespace mlir::quant;
+
+namespace {
+
+class RemoveInstrumentationPass
+    : public FunctionPass<RemoveInstrumentationPass> {
+  void runOnFunction() override;
+};
+
+template <typename OpTy>
+class RemoveIdentityOpRewrite : public RewritePattern {
+public:
+  RemoveIdentityOpRewrite(MLIRContext *context)
+      : RewritePattern(OpTy::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    assert(op->getNumOperands() == 1);
+    assert(op->getNumResults() == 1);
+
+    rewriter.replaceOp(op, op->getOperand(0));
+    return matchSuccess();
+  }
+};
+
+} // end anonymous namespace
+
+void RemoveInstrumentationPass::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto &func = getFunction();
+  auto *context = &getContext();
+  patterns.push_back(
+      llvm::make_unique<RemoveIdentityOpRewrite<StatisticsOp>>(context));
+  patterns.push_back(
+      llvm::make_unique<RemoveIdentityOpRewrite<StatisticsRefOp>>(context));
+  patterns.push_back(
+      llvm::make_unique<RemoveIdentityOpRewrite<CoupledRefOp>>(context));
+  applyPatternsGreedily(func, std::move(patterns));
+}
+
+FunctionPassBase *mlir::quantizer::createRemoveInstrumentationPass() {
+  return new RemoveInstrumentationPass();
+}
+
+static PassRegistration<RemoveInstrumentationPass>
+    pass("quantizer-remove-instrumentation",
+         "Removes instrumentation and hints which have no effect on final "
+         "execution");
diff --git a/mlir/test/Quantizer/matmul.mlir b/mlir/test/Quantizer/matmul.mlir
new file mode 100644 (file)
index 0000000..c207162
--- /dev/null
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -quantizer-infer-quantized-types -quant-convert-const -quantizer-remove-instrumentation -canonicalize -split-input-file | FileCheck %s
+
+// ----
+// A matmul without fused clamp or bias.
+// CHECK-LABEL: @matmul
+// CHECK: %cst = constant dense<tensor<3x5xi8>
+// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
+// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
+// CHECK-NEXT: %2 = "fxpmath.real_matmul"(%0, %1) : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.0629921259842528:-1>>
+// CHECK-NEXT: %3 = "quant.dcast"(%2) : (tensor<300x5x!quant.uniform<i8:f32, 0.0629921259842528:-1>>) -> tensor<300x5xf32>
+func @matmul(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
+  %0 = "quant.stats"(%arg0) {layerStats: dense<tensor<2xf32>, [-6.123e+00, 3.45e+00]>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
+  %cst = constant  {name: "constant.35"} dense<tensor<3x5xf32>, [[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]>
+  %1 = "fxpmath.real_matmul"(%0, %cst) : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
+  %2 = "quant.stats"(%1) {layerStats: dense<tensor<2xf32>, [-8.000000e+00, 8.000000e+00]>} : (tensor<300x5xf32>) -> tensor<300x5xf32>
+  return %2 : tensor<300x5xf32>
+}
+
+// ----
+// A matmul with fused clamp which serves as statistics for the result.
+// CHECK-LABEL: @matmul_clamp
+// CHECK: %cst = constant dense<tensor<3x5xi8>
+// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
+// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
+// CHECK-NEXT: %2 = "fxpmath.real_matmul"(%0, %1) {clamp_max: 6.100000e+00 : f64, clamp_min: -1.225000e+01 : f64} : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>
+// CHECK-NEXT: %3 = "quant.dcast"(%2) : (tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>) -> tensor<300x5xf32>
+func @matmul_clamp(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
+  %0 = "quant.stats"(%arg0) {layerStats: dense<tensor<2xf32>, [-6.123e+00, 3.45e+00]>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
+  %cst = constant  {name: "constant.35"} dense<tensor<3x5xf32>, [[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]>
+  %1 = "fxpmath.real_matmul"(%0, %cst) {clamp_max: 6.10, clamp_min: -12.25} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
+  return %1 : tensor<300x5xf32>
+}
+
+// ----
+// A matmul with bias and clamp.
+// CHECK-LABEL: @matmul_add_clamp
+// CHECK: %cst = constant dense<tensor<3x5xi8>
+// CHECK-NEXT: %cst_0 = constant dense<tensor<5xi32>, [14, 28, 42, 56, 69]>
+// CHECK-NEXT: %0 = "quant.qcast"(%arg0) : (tensor<300x3xf32>) -> tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>
+// CHECK-NEXT: %1 = "quant.scast"(%cst) : (tensor<3x5xi8>) -> tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>
+// CHECK-NEXT: %2 = "quant.scast"(%cst_0) : (tensor<5xi32>) -> tensor<5x!quant.uniform<i32:f32, 0.072058823529412216>>
+// CHECK-NEXT: %3 = "fxpmath.real_matmul_bias"(%0, %1, %2) {clamp_max: 6.100000e+00 : f64, clamp_min: -1.225000e+01 : f64} : (tensor<300x3x!quant.uniform<i8:f32, 0.037564418067230126:35>>, tensor<3x5x!quant.uniform<i8:f32, 0.0062823070315864236:-1>>, tensor<5x!quant.uniform<i32:f32, 0.072058823529412216>>) -> tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>
+// CHECK-NEXT: %4 = "quant.dcast"(%3) : (tensor<300x5x!quant.uniform<i8:f32, 0.072058823529412216:42>>) -> tensor<300x5xf32>
+func @matmul_add_clamp(%arg0: tensor<300x3xf32>) -> tensor<300x5xf32> {
+  %0 = "quant.stats"(%arg0) {layerStats: dense<tensor<2xf32>, [-6.123e+00, 3.45e+00]>} : (tensor<300x3xf32>) -> tensor<300x3xf32>
+  %cst = constant  {name: "constant.35"} dense<tensor<3x5xf32>, [[-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]>
+  %cst_0 = constant  {name: "constant.37"} dense<tensor<5xf32>, [1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00]>
+  %1 = "fxpmath.real_matmul_bias"(%0, %cst, %cst_0) {clamp_max: 6.10, clamp_min: -12.25} : (tensor<300x3xf32>, tensor<3x5xf32>, tensor<5xf32>) -> tensor<300x5xf32>
+  return %1 : tensor<300x5xf32>
+}
+
diff --git a/mlir/test/Quantizer/remove-instrumentation.mlir b/mlir/test/Quantizer/remove-instrumentation.mlir
new file mode 100644 (file)
index 0000000..659104b
--- /dev/null
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -quantizer-remove-instrumentation -verify -split-input-file | FileCheck %s
+
+// -----
+// CHECK-LABEL: remove_ops
+func @remove_ops(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+  %0 = "quant.stats"(%arg0) {
+    layerStats: dense<tensor<2xf32>, [-1.0, 1.0]>
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  %1 = "quant.coupled_ref"(%0) { coupledKey: "foobar" } :
+      (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  %2 = "quant.stats_ref"(%1) { statsKey: "foobar" } :
+      (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  // CHECK: return %arg0 : tensor<8x4x3xf32>
+  return %2 : tensor<8x4x3xf32>
+}
index 0bc5d4f..636d186 100644 (file)
@@ -26,6 +26,7 @@ set(LIBS
   MLIRNVVMIR
   MLIRParser
   MLIRPass
+  MLIRQuantizerTransforms
   MLIRQuantOps
   MLIRStandardOps
   MLIRTransforms