Add a printer flag to use local scope when printing IR.
authorRiver Riddle <riverriddle@google.com>
Tue, 12 Nov 2019 17:36:40 +0000 (09:36 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 12 Nov 2019 17:37:11 +0000 (09:37 -0800)
This causes the AsmPrinter to use a local value numbering when printing the IR, allowing for the printer to be used safely in a local context, e.g. to ensure thread-safety when printing the IR. This means that the IR printing instrumentation can also be used during multi-threading when module-scope is disabled. Operation::dump and DiagnosticArgument(Operation*) are also updated to always print local scope, as this is the most common use case when debugging.

PiperOrigin-RevId: 279988203

mlir/g3doc/WritingAPass.md
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Diagnostics.cpp
mlir/lib/Pass/IRPrinting.cpp

index 0123709..bd9ef45 100644 (file)
@@ -624,9 +624,6 @@ $ mlir-opt foo.mlir -cse -canonicalize -lower-to-llvm -pass-timing
 
 #### IR Printing
 
-Note: The IR Printing instrumentation should only be used when multi-threading
-is disabled(`-disable-pass-threading`)
-
 When debugging it is often useful to dump the IR at various stages of a pass
 pipeline. This is where the IR printing instrumentation comes into play. This
 instrumentation allows for conditionally printing the IR before and after pass
@@ -641,7 +638,7 @@ this instrumentation:
     *   Print the IR before every pass in the pipeline.
 
 ```shell
-$ mlir-opt foo.mlir -disable-pass-threading -cse -print-ir-before=cse
+$ mlir-opt foo.mlir -cse -print-ir-before=cse
 
 *** IR Dump Before CSE ***
 func @simple_constant() -> (i32, i32) {
@@ -657,7 +654,7 @@ func @simple_constant() -> (i32, i32) {
     *   Print the IR after every pass in the pipeline.
 
 ```shell
-$ mlir-opt foo.mlir -disable-pass-threading -cse -print-ir-after=cse
+$ mlir-opt foo.mlir -cse -print-ir-after=cse
 
 *** IR Dump After CSE ***
 func @simple_constant() -> (i32, i32) {
@@ -669,6 +666,8 @@ func @simple_constant() -> (i32, i32) {
 *   `print-ir-module-scope`
     *   Always print the top-level module operation, regardless of pass type or
         operation nesting level.
+    *   Note: Printing at module scope should only be used when multi-threading
+        is disabled(`-disable-pass-threading`)
 
 ```shell
 $ mlir-opt foo.mlir -disable-pass-threading -cse -print-ir-after=cse -print-ir-module-scope
index 66237e9..7e6ba8c 100644 (file)
@@ -475,6 +475,12 @@ public:
   /// Always print operations in the generic form.
   OpPrintingFlags &printGenericOpForm();
 
+  /// Use local scope when printing the operation. This allows for using the
+  /// printer in a more localized and thread-safe setting, but may not
+  /// necessarily be identical to what the IR will look like when dumping
+  /// the full module.
+  OpPrintingFlags &useLocalScope();
+
   /// Return if the given ElementsAttr should be elided.
   bool shouldElideElementsAttr(ElementsAttr attr) const;
 
@@ -487,6 +493,9 @@ public:
   /// Return if operations should be printed in the generic form.
   bool shouldPrintGenericOpForm() const;
 
+  /// Return if the printer should use local scope when dumping the IR.
+  bool shouldUseLocalScope() const;
+
 private:
   /// Elide large elements attributes if the number of elements is larger than
   /// the upper limit.
@@ -498,6 +507,9 @@ private:
 
   /// Print operations in the generic form.
   bool printGenericOpFormFlag : 1;
+
+  /// Print operations with numberings local to the current operation.
+  bool printLocalScope : 1;
 };
 
 } // end namespace mlir
index 6f77de0..20c49eb 100644 (file)
@@ -92,7 +92,7 @@ OpPrintingFlags::OpPrintingFlags()
               : Optional<int64_t>()),
       printDebugInfoFlag(printDebugInfoOpt),
       printDebugInfoPrettyFormFlag(printPrettyDebugInfoOpt),
-      printGenericOpFormFlag(printGenericOpFormOpt) {}
+      printGenericOpFormFlag(printGenericOpFormOpt), printLocalScope(false) {}
 
 /// Enable the elision of large elements attributes, by printing a '...'
 /// instead of the element data, when the number of elements is greater than
@@ -118,6 +118,14 @@ OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
   return *this;
 }
 
+/// Use local scope when printing the operation. This allows for using the
+/// printer in a more localized and thread-safe setting, but may not necessarily
+/// be identical of what the IR will look like when dumping the full module.
+OpPrintingFlags &OpPrintingFlags::useLocalScope() {
+  printLocalScope = true;
+  return *this;
+}
+
 /// Return if the given ElementsAttr should be elided.
 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
   return elementsAttrElementLimit.hasValue() &&
@@ -139,6 +147,9 @@ bool OpPrintingFlags::shouldPrintGenericOpForm() const {
   return printGenericOpFormFlag;
 }
 
+/// Return if the printer should use local scope when dumping the IR.
+bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
+
 //===----------------------------------------------------------------------===//
 // ModuleState
 //===----------------------------------------------------------------------===//
@@ -1516,8 +1527,10 @@ private:
 
 OperationPrinter::OperationPrinter(Operation *op, ModulePrinter &other)
     : ModulePrinter(other) {
+  llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
   if (op->getNumResults() != 0)
     numberValueID(op->getResult(0));
+
   for (auto &region : op->getRegions())
     numberValuesInRegion(region);
 }
@@ -1725,7 +1738,7 @@ void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo,
 
   auto it = valueIDs.find(lookupValue);
   if (it == valueIDs.end()) {
-    stream << "<<INVALID SSA VALUE>>";
+    stream << "<<UNKNOWN SSA VALUE>>";
     return;
   }
 
@@ -1943,9 +1956,10 @@ void Value::dump() {
 }
 
 void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
-  // Handle top-level operations.
-  if (!getParent()) {
-    ModulePrinter modulePrinter(os, flags);
+  // Handle top-level operations or local printing.
+  if (!getParent() || flags.shouldUseLocalScope()) {
+    ModuleState state(getContext());
+    ModulePrinter modulePrinter(os, flags, &state);
     OperationPrinter(this, modulePrinter).print(this);
     return;
   }
@@ -1966,7 +1980,7 @@ void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
 }
 
 void Operation::dump() {
-  print(llvm::errs());
+  print(llvm::errs(), OpPrintingFlags().useLocalScope());
   llvm::errs() << "\n";
 }
 
index 2e15438..b89b44d 100644 (file)
@@ -74,7 +74,7 @@ void DiagnosticArgument::print(raw_ostream &os) const {
     os << getAsInteger();
     break;
   case DiagnosticArgumentKind::Operation:
-    os << getAsOperation();
+    getAsOperation().print(os, OpPrintingFlags().useLocalScope());
     break;
   case DiagnosticArgumentKind::String:
     os << getAsString();
index 032a797..7cf32f8 100644 (file)
@@ -72,7 +72,7 @@ static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
 
   // Otherwise, check to see if we are not printing at module scope.
   if (!printModuleScope)
-    return op->print(out << "\n", flags);
+    return op->print(out << "\n", flags.useLocalScope());
 
   // Otherwise, we are printing at module scope.
   out << " ('" << op->getName() << "' operation";