Merge ext/cfg/ml function printing logic in the AsmPrinter (shrinking it
authorChris Lattner <clattner@google.com>
Fri, 28 Dec 2018 19:41:56 +0000 (11:41 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:43:29 +0000 (14:43 -0700)
by about 100 LOC), without changing any existing behavior.

This is step 20/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 227155000

mlir/lib/IR/AsmPrinter.cpp

index 19943573bc344fafce023e79e4956fa6a6bbf3f1..2ff7220f8eee6e6e3793b497ea34a40a9a9785ba 100644 (file)
@@ -276,9 +276,6 @@ public:
   void printAttribute(Attribute attr);
   void printType(Type type);
   void print(const Function *fn);
-  void printExt(const Function *fn);
-  void printCFG(const Function *fn);
-  void printML(const Function *fn);
 
   void printAffineMap(AffineMap map);
   void printAffineExpr(AffineExpr expr);
@@ -289,8 +286,6 @@ protected:
   raw_ostream &os;
   ModuleState &state;
 
-  void printFunctionSignature(const Function *fn);
-  void printFunctionAttributes(const Function *fn);
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                              ArrayRef<const char *> elidedAttrs = {});
   void printFunctionResultType(FunctionType type);
@@ -312,18 +307,6 @@ protected:
 };
 } // end anonymous namespace
 
-// Prints function with initialized module state.
-void ModulePrinter::print(const Function *fn) {
-  switch (fn->getKind()) {
-  case Function::Kind::ExtFunc:
-    return printExt(fn);
-  case Function::Kind::CFGFunc:
-    return printCFG(fn);
-  case Function::Kind::MLFunc:
-    return printML(fn);
-  }
-}
-
 // Prints affine map identifier.
 void ModulePrinter::printAffineMapId(int affineMapId) const {
   os << "#map" << affineMapId;
@@ -872,24 +855,6 @@ void ModulePrinter::printFunctionResultType(FunctionType type) {
   }
 }
 
-void ModulePrinter::printFunctionAttributes(const Function *fn) {
-  auto attrs = fn->getAttrs();
-  if (attrs.empty())
-    return;
-  os << "\n  attributes ";
-  printOptionalAttrDict(attrs);
-}
-
-void ModulePrinter::printFunctionSignature(const Function *fn) {
-  auto type = fn->getType();
-
-  os << "@" << fn->getName() << '(';
-  interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
-  os << ')';
-
-  printFunctionResultType(type);
-}
-
 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                           ArrayRef<const char *> elidedAttrs) {
   // If there are no attributes, then there is nothing to be done.
@@ -929,20 +894,27 @@ void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
   os << '}';
 }
 
-void ModulePrinter::printExt(const Function *fn) {
-  os << "extfunc ";
-  printFunctionSignature(fn);
-  printFunctionAttributes(fn);
-  os << '\n';
-}
-
 namespace {
 
 // FunctionPrinter contains common functionality for printing
 // CFG and ML functions.
 class FunctionPrinter : public ModulePrinter, private OpAsmPrinter {
 public:
-  FunctionPrinter(const ModulePrinter &other) : ModulePrinter(other) {}
+  FunctionPrinter(const Function *function, const ModulePrinter &other);
+
+  // Prints the function as a whole.
+  void print();
+
+  // Print the function signature.
+  void printMLFunctionSignature();
+  void printOtherFunctionSignature();
+
+  // Methods to print statements.
+  void print(const Statement *stmt);
+  void print(const OperationInst *inst);
+  void print(const ForStmt *stmt);
+  void print(const IfStmt *stmt);
+  void print(const StmtBlock *block);
 
   void printOperation(const OperationInst *op);
   void printDefaultOp(const OperationInst *op);
@@ -963,9 +935,6 @@ public:
   void printFunctionReference(const Function *func) {
     return ModulePrinter::printFunctionReference(func);
   }
-  void printFunctionAttributes(const Function *func) {
-    return ModulePrinter::printFunctionAttributes(func);
-  }
   void printOperand(const Value *value) { printValueID(value); }
 
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
@@ -975,428 +944,207 @@ public:
 
   enum { nameSentinel = ~0U };
 
-protected:
-  void numberValueID(const Value *value) {
-    assert(!valueIDs.count(value) && "Value numbered multiple times");
-
-    SmallString<32> specialNameBuffer;
-    llvm::raw_svector_ostream specialName(specialNameBuffer);
-
-    // Give constant integers special names.
-    if (auto *op = value->getDefiningInst()) {
-      if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
-        // i1 constants get special names.
-        if (intOp->getType().isInteger(1)) {
-          specialName << (intOp->getValue() ? "true" : "false");
-        } else {
-          specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
-        }
-      } else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
-        specialName << 'c' << intOp->getValue();
-      } else if (auto constant = op->dyn_cast<ConstantOp>()) {
-        if (constant->getValue().isa<FunctionAttr>())
-          specialName << 'f';
-        else
-          specialName << "cst";
-      }
-    }
-
-    if (specialNameBuffer.empty()) {
-      switch (value->getKind()) {
-      case Value::Kind::BlockArgument:
-        // If this is an argument to the function, give it an 'arg' name.
-        if (auto *block = cast<BlockArgument>(value)->getOwner())
-          if (auto *fn = block->getFunction())
-            if (&fn->getBlockList().front() == block) {
-              specialName << "arg" << nextArgumentID++;
-              break;
-            }
-        // Otherwise number it normally.
-        valueIDs[value] = nextValueID++;
-        return;
-      case Value::Kind::InstResult:
-        // This is an uninteresting result, give it a boring number and be
-        // done with it.
-        valueIDs[value] = nextValueID++;
-        return;
-      case Value::Kind::ForStmt:
-        specialName << 'i' << nextLoopID++;
-        break;
-      }
-    }
-
-    // Ok, this value had an interesting name.  Remember it with a sentinel.
-    valueIDs[value] = nameSentinel;
-
-    // Remember that we've used this name, checking to see if we had a conflict.
-    auto insertRes = usedNames.insert(specialName.str());
-    if (insertRes.second) {
-      // If this is the first use of the name, then we're successful!
-      valueNames[value] = insertRes.first->first();
-      return;
-    }
+  void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); }
 
-    // Otherwise, we had a conflict - probe until we find a unique name.  This
-    // is guaranteed to terminate (and usually in a single iteration) because it
-    // generates new names by incrementing nextConflictID.
-    while (1) {
-      std::string probeName =
-          specialName.str().str() + "_" + llvm::utostr(nextConflictID++);
-      insertRes = usedNames.insert(probeName);
-      if (insertRes.second) {
-        // If this is the first use of the name, then we're successful!
-        valueNames[value] = insertRes.first->first();
-        return;
-      }
-    }
+  unsigned getBBID(const BasicBlock *block) {
+    auto it = basicBlockIDs.find(block);
+    assert(it != basicBlockIDs.end() && "Block not in this function?");
+    return it->second;
   }
 
-  void printValueID(const Value *value, bool printResultNo = true) const {
-    int resultNo = -1;
-    auto lookupValue = value;
-
-    // If this is a reference to the result of a multi-result instruction or
-    // statement, print out the # identifier and make sure to map our lookup
-    // to the first result of the instruction.
-    if (auto *result = dyn_cast<InstResult>(value)) {
-      if (result->getOwner()->getNumResults() != 1) {
-        resultNo = result->getResultNumber();
-        lookupValue = result->getOwner()->getResult(0);
-      }
-    } else if (auto *result = dyn_cast<InstResult>(value)) {
-      if (result->getOwner()->getNumResults() != 1) {
-        resultNo = result->getResultNumber();
-        lookupValue = result->getOwner()->getResult(0);
-      }
-    }
+  void printSuccessorAndUseList(const OperationInst *term,
+                                unsigned index) override;
 
-    auto it = valueIDs.find(lookupValue);
-    if (it == valueIDs.end()) {
-      os << "<<INVALID SSA VALUE>>";
-      return;
-    }
+  // Print if and loop bounds.
+  void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims);
+  void printBound(AffineBound bound, const char *prefix);
 
-    os << '%';
-    if (it->second != nameSentinel) {
-      os << it->second;
-    } else {
-      auto nameIt = valueNames.find(lookupValue);
-      assert(nameIt != valueNames.end() && "Didn't have a name entry?");
-      os << nameIt->second;
-    }
+  // Number of spaces used for indenting nested statements.
+  const static unsigned indentWidth = 2;
 
-    if (resultNo != -1 && printResultNo)
-      os << '#' << resultNo;
-  }
+protected:
+  void numberValueID(const Value *value);
+  void numberValuesInBlock(const StmtBlock &block);
+  void printValueID(const Value *value, bool printResultNo = true) const;
 
 private:
+  const Function *function;
+
   /// This is the value ID for each SSA value in the current function.  If this
   /// returns ~0, then the valueID has an entry in valueNames.
   DenseMap<const Value *, unsigned> valueIDs;
   DenseMap<const Value *, StringRef> valueNames;
 
+  /// This is the block ID for each  block in the current function.
+  DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
+
   /// This keeps track of all of the non-numeric names that are in flight,
   /// allowing us to check for duplicates.
   llvm::StringSet<> usedNames;
 
+  // This is the current indentation level for nested structures.
+  unsigned currentIndent = 0;
+
   /// This is the next value ID to assign in numbering.
   unsigned nextValueID = 0;
   /// This is the ID to assign to the next induction variable.
   unsigned nextLoopID = 0;
   /// This is the next ID to assign to a Function argument.
   unsigned nextArgumentID = 0;
-
   /// This is the next ID to assign when a name conflict is detected.
   unsigned nextConflictID = 0;
+  /// This is the next block ID to assign in numbering.
+  unsigned nextBlockID = 0;
 };
 } // end anonymous namespace
 
-void FunctionPrinter::printOperation(const OperationInst *op) {
-  if (op->getNumResults()) {
-    printValueID(op->getResult(0), /*printResultNo=*/false);
-    os << " = ";
-  }
-
-  // Check to see if this is a known operation.  If so, use the registered
-  // custom printer hook.
-  if (auto *opInfo = op->getAbstractOperation()) {
-    opInfo->printAssembly(op, this);
-    return;
-  }
+FunctionPrinter::FunctionPrinter(const Function *function,
+                                 const ModulePrinter &other)
+    : ModulePrinter(other), function(function) {
 
-  // Otherwise use the standard verbose printing approach.
-  printDefaultOp(op);
-}
-
-void FunctionPrinter::printDefaultOp(const OperationInst *op) {
-  os << '"';
-  printEscapedString(op->getName().getStringRef(), os);
-  os << "\"(";
-
-  interleaveComma(op->getOperands(),
-                  [&](const Value *value) { printValueID(value); });
-
-  os << ')';
-  auto attrs = op->getAttrs();
-  printOptionalAttrDict(attrs);
-
-  // Print the type signature of the operation.
-  os << " : (";
-  interleaveComma(op->getOperands(),
-                  [&](const Value *value) { printType(value->getType()); });
-  os << ") -> ";
-
-  if (op->getNumResults() == 1) {
-    printType(op->getResult(0)->getType());
-  } else {
-    os << '(';
-    interleaveComma(op->getResults(),
-                    [&](const Value *result) { printType(result->getType()); });
-    os << ')';
-  }
+  for (auto &block : *function)
+    numberValuesInBlock(block);
 }
 
-//===----------------------------------------------------------------------===//
-// CFG Function printing
-//===----------------------------------------------------------------------===//
-
-namespace {
-class CFGFunctionPrinter : public FunctionPrinter {
-public:
-  CFGFunctionPrinter(const Function *function, const ModulePrinter &other);
-
-  const Function *getFunction() const { return function; }
-
-  void print();
-  void print(const BasicBlock *block);
-
-  void print(const Instruction *inst);
-
-  void printSuccessorAndUseList(const OperationInst *term, unsigned index);
-
-  void printBBName(const BasicBlock *block) { os << "bb" << getBBID(block); }
-
-  unsigned getBBID(const BasicBlock *block) {
-    auto it = basicBlockIDs.find(block);
-    assert(it != basicBlockIDs.end() && "Block not in this function?");
-    return it->second;
-  }
-
-private:
-  const Function *function;
-  DenseMap<const BasicBlock *, unsigned> basicBlockIDs;
-
-  void numberValuesInBlock(const BasicBlock *block);
-
-  template <typename Range> void printBranchOperands(const Range &range);
-};
-} // end anonymous namespace
-
-CFGFunctionPrinter::CFGFunctionPrinter(const Function *function,
-                                       const ModulePrinter &other)
-    : FunctionPrinter(other), function(function) {
-  // Each basic block gets a unique ID per function.
-  unsigned blockID = 0;
-  for (auto &block : *function) {
-    basicBlockIDs[&block] = blockID++;
-    numberValuesInBlock(&block);
-  }
-}
+/// Number all of the SSA values in the specified block list.
+void FunctionPrinter::numberValuesInBlock(const StmtBlock &block) {
+  // Each block gets a unique ID, and all of the instructions within it get
+  // numbered as well.
+  basicBlockIDs[&block] = nextBlockID++;
 
-/// Number all of the SSA values in the specified basic block.
-void CFGFunctionPrinter::numberValuesInBlock(const BasicBlock *block) {
-  for (auto *arg : block->getArguments()) {
+  for (auto *arg : block.getArguments())
     numberValueID(arg);
-  }
-  for (auto &op : *block) {
+
+  for (auto &inst : block) {
     // We number instruction that have results, and we only number the first
     // result.
-    if (auto *opInst = dyn_cast<OperationInst>(&op))
+    switch (inst.getKind()) {
+    case Statement::Kind::OperationInst: {
+      auto *opInst = cast<OperationInst>(&inst);
       if (opInst->getNumResults() != 0)
         numberValueID(opInst->getResult(0));
+      break;
+    }
+    case Statement::Kind::For: {
+      auto *forInst = cast<ForStmt>(&inst);
+      // Number the induction variable.
+      numberValueID(forInst);
+      // Recursively number the stuff in the body.
+      numberValuesInBlock(*forInst->getBody());
+      break;
+    }
+    case Statement::Kind::If: {
+      auto *ifInst = cast<IfStmt>(&inst);
+      numberValuesInBlock(*ifInst->getThen());
+      if (auto *elseBlock = ifInst->getElse())
+        numberValuesInBlock(*elseBlock);
+    }
+    }
   }
-
-  // Terminators do not define values.
 }
 
-void CFGFunctionPrinter::print() {
-  os << "cfgfunc ";
-  printFunctionSignature(getFunction());
-  printFunctionAttributes(getFunction());
-  os << " {\n";
-
-  for (auto &block : *function)
-    print(&block);
-  os << "}\n\n";
-}
+void FunctionPrinter::numberValueID(const Value *value) {
+  assert(!valueIDs.count(value) && "Value numbered multiple times");
 
-void CFGFunctionPrinter::print(const BasicBlock *block) {
-  printBBName(block);
+  SmallString<32> specialNameBuffer;
+  llvm::raw_svector_ostream specialName(specialNameBuffer);
 
-  if (!block->args_empty()) {
-    os << '(';
-    interleaveComma(block->getArguments(), [&](const BlockArgument *arg) {
-      printValueID(arg);
-      os << ": ";
-      printType(arg->getType());
-    });
-    os << ')';
+  // Give constant integers special names.
+  if (auto *op = value->getDefiningInst()) {
+    if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
+      // i1 constants get special names.
+      if (intOp->getType().isInteger(1)) {
+        specialName << (intOp->getValue() ? "true" : "false");
+      } else {
+        specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
+      }
+    } else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
+      specialName << 'c' << intOp->getValue();
+    } else if (auto constant = op->dyn_cast<ConstantOp>()) {
+      if (constant->getValue().isa<FunctionAttr>())
+        specialName << 'f';
+      else
+        specialName << "cst";
+    }
   }
-  os << ':';
-
-  // Print out some context information about the predecessors of this block.
-  if (!block->getFunction()) {
-    os << "\t// block is not in a function!";
-  } else if (block->hasNoPredecessors()) {
-    // Don't print "no predecessors" for the entry block.
-    if (block != &block->getFunction()->front())
-      os << "\t// no predecessors";
-  } else if (auto *pred = block->getSinglePredecessor()) {
-    os << "\t// pred: ";
-    printBBName(pred);
-  } else {
-    // We want to print the predecessors in increasing numeric order, not in
-    // whatever order the use-list is in, so gather and sort them.
-    SmallVector<unsigned, 4> predIDs;
-    for (auto *pred : block->getPredecessors())
-      predIDs.push_back(getBBID(pred));
-    llvm::array_pod_sort(predIDs.begin(), predIDs.end());
 
-    os << "\t// " << predIDs.size() << " preds: ";
-
-    interleaveComma(predIDs, [&](unsigned predID) { os << "bb" << predID; });
+  if (specialNameBuffer.empty()) {
+    switch (value->getKind()) {
+    case Value::Kind::BlockArgument:
+      // If this is an argument to the function, give it an 'arg' name.
+      if (auto *block = cast<BlockArgument>(value)->getOwner())
+        if (auto *fn = block->getFunction())
+          if (&fn->getBlockList().front() == block) {
+            specialName << "arg" << nextArgumentID++;
+            break;
+          }
+      // Otherwise number it normally.
+      valueIDs[value] = nextValueID++;
+      return;
+    case Value::Kind::InstResult:
+      // This is an uninteresting result, give it a boring number and be
+      // done with it.
+      valueIDs[value] = nextValueID++;
+      return;
+    case Value::Kind::ForStmt:
+      specialName << 'i' << nextLoopID++;
+      break;
+    }
   }
-  os << '\n';
 
-  for (auto &inst : block->getStatements()) {
-    os << "  ";
-    print(&inst);
-    os << '\n';
-  }
-}
+  // Ok, this value had an interesting name.  Remember it with a sentinel.
+  valueIDs[value] = nameSentinel;
 
-void CFGFunctionPrinter::print(const Instruction *inst) {
-  if (!inst) {
-    os << "<<null instruction>>\n";
+  // Remember that we've used this name, checking to see if we had a conflict.
+  auto insertRes = usedNames.insert(specialName.str());
+  if (insertRes.second) {
+    // If this is the first use of the name, then we're successful!
+    valueNames[value] = insertRes.first->first();
     return;
   }
-  auto *opInst = dyn_cast<OperationInst>(inst);
-  assert(opInst && "IfStmt/ForStmt aren't supported in CFG functions yet");
-  printOperation(opInst);
-}
 
-// Print the operands from "container" to "os", followed by a colon and their
-// respective types, everything in parentheses.  Do nothing if the container is
-// empty.
-template <typename Range>
-void CFGFunctionPrinter::printBranchOperands(const Range &range) {
-  if (llvm::empty(range))
-    return;
-
-  os << '(';
-  interleaveComma(range,
-                  [this](const Value *operand) { printValueID(operand); });
-  os << " : ";
-  interleaveComma(
-      range, [this](const Value *operand) { printType(operand->getType()); });
-  os << ')';
-}
-
-void CFGFunctionPrinter::printSuccessorAndUseList(const OperationInst *term,
-                                                  unsigned index) {
-  printBBName(term->getSuccessor(index));
-  printBranchOperands(term->getSuccessorOperands(index));
-}
-
-void ModulePrinter::printCFG(const Function *fn) {
-  CFGFunctionPrinter(fn, *this).print();
-}
-
-//===----------------------------------------------------------------------===//
-// ML Function printing
-//===----------------------------------------------------------------------===//
-
-namespace {
-class MLFunctionPrinter : public FunctionPrinter {
-public:
-  MLFunctionPrinter(const Function *function, const ModulePrinter &other);
-
-  const Function *getFunction() const { return function; }
-
-  // Prints ML function.
-  void print();
-
-  // Prints ML function signature.
-  void printFunctionSignature();
-
-  // Methods to print ML function statements.
-  void print(const Statement *stmt);
-  void print(const OperationInst *stmt);
-  void print(const ForStmt *stmt);
-  void print(const IfStmt *stmt);
-  void print(const StmtBlock *block);
-  void printSuccessorAndUseList(const OperationInst *term, unsigned index) {
-    assert(false && "MLFunctions do not have terminators with successors.");
+  // Otherwise, we had a conflict - probe until we find a unique name.  This
+  // is guaranteed to terminate (and usually in a single iteration) because it
+  // generates new names by incrementing nextConflictID.
+  while (1) {
+    std::string probeName =
+        specialName.str().str() + "_" + llvm::utostr(nextConflictID++);
+    insertRes = usedNames.insert(probeName);
+    if (insertRes.second) {
+      // If this is the first use of the name, then we're successful!
+      valueNames[value] = insertRes.first->first();
+      return;
+    }
   }
-
-  // Print loop bounds.
-  void printDimAndSymbolList(ArrayRef<InstOperand> ops, unsigned numDims);
-  void printBound(AffineBound bound, const char *prefix);
-
-  // Number of spaces used for indenting nested statements.
-  const static unsigned indentWidth = 2;
-
-private:
-  void numberValues();
-
-  const Function *function;
-  int numSpaces;
-};
-} // end anonymous namespace
-
-MLFunctionPrinter::MLFunctionPrinter(const Function *function,
-                                     const ModulePrinter &other)
-    : FunctionPrinter(other), function(function), numSpaces(0) {
-  assert(function && "Cannot print nullptr function");
-  numberValues();
 }
 
-/// Number all of the SSA values in this ML function.
-void MLFunctionPrinter::numberValues() {
-  // Numbers ML function arguments.
-  for (auto *arg : function->getArguments())
-    numberValueID(arg);
-
-  // Walks ML function statements and numbers for statements and
-  // the first result of the operation statements.
-  struct NumberValuesPass : public StmtWalker<NumberValuesPass> {
-    NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {}
-    void visitOperationInst(OperationInst *stmt) {
-      if (stmt->getNumResults() != 0)
-        printer->numberValueID(stmt->getResult(0));
-    }
-    void visitForStmt(ForStmt *stmt) { printer->numberValueID(stmt); }
-    MLFunctionPrinter *printer;
-  };
+void FunctionPrinter::print() {
+  // TODO(clattner): merge the syntax of functions.
+  if (function->isML())
+    printMLFunctionSignature();
+  else
+    printOtherFunctionSignature();
 
-  NumberValuesPass pass(this);
-  // TODO: it'd be cleaner to have constant visitor instead of using const_cast.
-  pass.walk(const_cast<Function *>(function));
-}
+  // Print out function attributes, if present.
+  auto attrs = function->getAttrs();
+  if (!attrs.empty()) {
+    os << "\n  attributes ";
+    printOptionalAttrDict(attrs);
+  }
 
-void MLFunctionPrinter::print() {
-  os << "mlfunc ";
-  printFunctionSignature();
-  printFunctionAttributes(getFunction());
-  os << " {\n";
-  print(function->getBody());
-  os << "}\n\n";
+  if (!function->empty()) {
+    os << " {\n";
+    for (const auto &block : *function)
+      print(&block);
+    os << "}\n";
+  }
+  os << '\n';
 }
 
-void MLFunctionPrinter::printFunctionSignature() {
+void FunctionPrinter::printMLFunctionSignature() {
   auto type = function->getType();
 
-  os << "@" << function->getName() << '(';
+  os << "mlfunc @" << function->getName() << '(';
 
   for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
     if (i > 0)
@@ -1406,20 +1154,78 @@ void MLFunctionPrinter::printFunctionSignature() {
     os << " : ";
     printType(arg->getType());
   }
-  os << ")";
+  os << ')';
+  printFunctionResultType(type);
+}
+
+// This prints the signature for CFG and External functions.
+void FunctionPrinter::printOtherFunctionSignature() {
+  auto type = function->getType();
+
+  if (function->isCFG())
+    os << "cfgfunc ";
+  else
+    os << "extfunc ";
+
+  os << '@' << function->getName() << '(';
+  interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
+  os << ')';
+
   printFunctionResultType(type);
 }
 
-void MLFunctionPrinter::print(const StmtBlock *block) {
-  numSpaces += indentWidth;
+void FunctionPrinter::print(const StmtBlock *block) {
+  // Print the block label and argument list, unless we are in an ML function.
+  if (!block->getFunction()->isML()) {
+    os.indent(currentIndent);
+    printBBName(block);
+
+    // Print the argument list if non-empty.
+    if (!block->args_empty()) {
+      os << '(';
+      interleaveComma(block->getArguments(), [&](const BlockArgument *arg) {
+        printValueID(arg);
+        os << ": ";
+        printType(arg->getType());
+      });
+      os << ')';
+    }
+    os << ':';
+
+    // Print out some context information about the predecessors of this block.
+    if (!block->getFunction()) {
+      os << "\t// block is not in a function!";
+    } else if (block->hasNoPredecessors()) {
+      // Don't print "no predecessors" for the entry block.
+      if (block != &block->getFunction()->front())
+        os << "\t// no predecessors";
+    } else if (auto *pred = block->getSinglePredecessor()) {
+      os << "\t// pred: ";
+      printBBName(pred);
+    } else {
+      // We want to print the predecessors in increasing numeric order, not in
+      // whatever order the use-list is in, so gather and sort them.
+      SmallVector<unsigned, 4> predIDs;
+      for (auto *pred : block->getPredecessors())
+        predIDs.push_back(getBBID(pred));
+      llvm::array_pod_sort(predIDs.begin(), predIDs.end());
+
+      os << "\t// " << predIDs.size() << " preds: ";
+
+      interleaveComma(predIDs, [&](unsigned predID) { os << "bb" << predID; });
+    }
+    os << '\n';
+  }
+
+  currentIndent += indentWidth;
   for (auto &stmt : block->getStatements()) {
     print(&stmt);
-    os << "\n";
+    os << '\n';
   }
-  numSpaces -= indentWidth;
+  currentIndent -= indentWidth;
 }
 
-void MLFunctionPrinter::print(const Statement *stmt) {
+void FunctionPrinter::print(const Statement *stmt) {
   switch (stmt->getKind()) {
   case Statement::Kind::OperationInst:
     return print(cast<OperationInst>(stmt));
@@ -1430,13 +1236,13 @@ void MLFunctionPrinter::print(const Statement *stmt) {
   }
 }
 
-void MLFunctionPrinter::print(const OperationInst *stmt) {
-  os.indent(numSpaces);
-  printOperation(stmt);
+void FunctionPrinter::print(const OperationInst *inst) {
+  os.indent(currentIndent);
+  printOperation(inst);
 }
 
-void MLFunctionPrinter::print(const ForStmt *stmt) {
-  os.indent(numSpaces) << "for ";
+void FunctionPrinter::print(const ForStmt *stmt) {
+  os.indent(currentIndent) << "for ";
   printOperand(stmt);
   os << " = ";
   printBound(stmt->getLowerBound(), "max");
@@ -1448,11 +1254,129 @@ void MLFunctionPrinter::print(const ForStmt *stmt) {
 
   os << " {\n";
   print(stmt->getBody());
-  os.indent(numSpaces) << "}";
+  os.indent(currentIndent) << "}";
+}
+
+void FunctionPrinter::print(const IfStmt *stmt) {
+  os.indent(currentIndent) << "if ";
+  IntegerSet set = stmt->getIntegerSet();
+  printIntegerSetReference(set);
+  printDimAndSymbolList(stmt->getInstOperands(), set.getNumDims());
+  os << " {\n";
+  print(stmt->getThen());
+  os.indent(currentIndent) << "}";
+  if (stmt->hasElse()) {
+    os << " else {\n";
+    print(stmt->getElse());
+    os.indent(currentIndent) << "}";
+  }
+}
+
+void FunctionPrinter::printValueID(const Value *value,
+                                   bool printResultNo) const {
+  int resultNo = -1;
+  auto lookupValue = value;
+
+  // If this is a reference to the result of a multi-result instruction or
+  // statement, print out the # identifier and make sure to map our lookup
+  // to the first result of the instruction.
+  if (auto *result = dyn_cast<InstResult>(value)) {
+    if (result->getOwner()->getNumResults() != 1) {
+      resultNo = result->getResultNumber();
+      lookupValue = result->getOwner()->getResult(0);
+    }
+  } else if (auto *result = dyn_cast<InstResult>(value)) {
+    if (result->getOwner()->getNumResults() != 1) {
+      resultNo = result->getResultNumber();
+      lookupValue = result->getOwner()->getResult(0);
+    }
+  }
+
+  auto it = valueIDs.find(lookupValue);
+  if (it == valueIDs.end()) {
+    os << "<<INVALID SSA VALUE>>";
+    return;
+  }
+
+  os << '%';
+  if (it->second != nameSentinel) {
+    os << it->second;
+  } else {
+    auto nameIt = valueNames.find(lookupValue);
+    assert(nameIt != valueNames.end() && "Didn't have a name entry?");
+    os << nameIt->second;
+  }
+
+  if (resultNo != -1 && printResultNo)
+    os << '#' << resultNo;
+}
+
+void FunctionPrinter::printOperation(const OperationInst *op) {
+  if (op->getNumResults()) {
+    printValueID(op->getResult(0), /*printResultNo=*/false);
+    os << " = ";
+  }
+
+  // Check to see if this is a known operation.  If so, use the registered
+  // custom printer hook.
+  if (auto *opInfo = op->getAbstractOperation()) {
+    opInfo->printAssembly(op, this);
+    return;
+  }
+
+  // Otherwise use the standard verbose printing approach.
+  printDefaultOp(op);
+}
+
+void FunctionPrinter::printDefaultOp(const OperationInst *op) {
+  os << '"';
+  printEscapedString(op->getName().getStringRef(), os);
+  os << "\"(";
+
+  interleaveComma(op->getOperands(),
+                  [&](const Value *value) { printValueID(value); });
+
+  os << ')';
+  auto attrs = op->getAttrs();
+  printOptionalAttrDict(attrs);
+
+  // Print the type signature of the operation.
+  os << " : (";
+  interleaveComma(op->getOperands(),
+                  [&](const Value *value) { printType(value->getType()); });
+  os << ") -> ";
+
+  if (op->getNumResults() == 1) {
+    printType(op->getResult(0)->getType());
+  } else {
+    os << '(';
+    interleaveComma(op->getResults(),
+                    [&](const Value *result) { printType(result->getType()); });
+    os << ')';
+  }
+}
+
+void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term,
+                                               unsigned index) {
+  printBBName(term->getSuccessor(index));
+
+  auto succOperands = term->getSuccessorOperands(index);
+
+  if (succOperands.begin() == succOperands.end())
+    return;
+
+  os << '(';
+  interleaveComma(succOperands,
+                  [this](const Value *operand) { printValueID(operand); });
+  os << " : ";
+  interleaveComma(succOperands, [this](const Value *operand) {
+    printType(operand->getType());
+  });
+  os << ')';
 }
 
-void MLFunctionPrinter::printDimAndSymbolList(ArrayRef<InstOperand> ops,
-                                              unsigned numDims) {
+void FunctionPrinter::printDimAndSymbolList(ArrayRef<InstOperand> ops,
+                                            unsigned numDims) {
   auto printComma = [&]() { os << ", "; };
   os << '(';
   interleave(
@@ -1469,7 +1393,7 @@ void MLFunctionPrinter::printDimAndSymbolList(ArrayRef<InstOperand> ops,
   }
 }
 
-void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
+void FunctionPrinter::printBound(AffineBound bound, const char *prefix) {
   AffineMap map = bound.getMap();
 
   // Check if this bound should be printed using short-hand notation.
@@ -1507,23 +1431,9 @@ void MLFunctionPrinter::printBound(AffineBound bound, const char *prefix) {
   printDimAndSymbolList(bound.getInstOperands(), map.getNumDims());
 }
 
-void MLFunctionPrinter::print(const IfStmt *stmt) {
-  os.indent(numSpaces) << "if ";
-  IntegerSet set = stmt->getIntegerSet();
-  printIntegerSetReference(set);
-  printDimAndSymbolList(stmt->getInstOperands(), set.getNumDims());
-  os << " {\n";
-  print(stmt->getThen());
-  os.indent(numSpaces) << "}";
-  if (stmt->hasElse()) {
-    os << " else {\n";
-    print(stmt->getElse());
-    os.indent(numSpaces) << "}";
-  }
-}
-
-void ModulePrinter::printML(const Function *fn) {
-  MLFunctionPrinter(fn, *this).print();
+// Prints function with initialized module state.
+void ModulePrinter::print(const Function *fn) {
+  FunctionPrinter(fn, *this).print();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1595,15 +1505,10 @@ void Instruction::print(raw_ostream &os) const {
     os << "<<UNLINKED INSTRUCTION>>\n";
     return;
   }
-  if (function->isCFG()) {
-    ModuleState state(function->getContext());
-    ModulePrinter modulePrinter(os, state);
-    CFGFunctionPrinter(function, modulePrinter).print(this);
-  } else {
-    ModuleState state(function->getContext());
-    ModulePrinter modulePrinter(os, state);
-    MLFunctionPrinter(function, modulePrinter).print(this);
-  }
+
+  ModuleState state(function->getContext());
+  ModulePrinter modulePrinter(os, state);
+  FunctionPrinter(function, modulePrinter).print(this);
 }
 
 void Instruction::dump() const {
@@ -1618,15 +1523,9 @@ void BasicBlock::print(raw_ostream &os) const {
     return;
   }
 
-  if (function->isCFG()) {
-    ModuleState state(function->getContext());
-    ModulePrinter modulePrinter(os, state);
-    CFGFunctionPrinter(function, modulePrinter).print(this);
-  } else {
-    ModuleState state(function->getContext());
-    ModulePrinter modulePrinter(os, state);
-    MLFunctionPrinter(function, modulePrinter).print(this);
-  }
+  ModuleState state(function->getContext());
+  ModulePrinter modulePrinter(os, state);
+  FunctionPrinter(function, modulePrinter).print(this);
 }
 
 void BasicBlock::dump() const { print(llvm::errs()); }
@@ -1639,7 +1538,7 @@ void StmtBlock::printAsOperand(raw_ostream &os, bool printType) {
   }
   ModuleState state(getFunction()->getContext());
   ModulePrinter modulePrinter(os, state);
-  CFGFunctionPrinter(getFunction(), modulePrinter).printBBName(this);
+  FunctionPrinter(getFunction(), modulePrinter).printBBName(this);
 }
 
 void Function::print(raw_ostream &os) const {