Add a new class, OpPrintingFlags, to enable programmatic control of Operation::print...
authorRiver Riddle <riverriddle@google.com>
Mon, 7 Oct 2019 20:54:16 +0000 (13:54 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 7 Oct 2019 20:54:49 +0000 (13:54 -0700)
This allows for controlling the behavior of the AsmPrinter programmatically, instead of relying exclusively on cl::opt flags. This will also allow for more fine-tuned control of printing behavior per callsite, instead of being applied globally.

PiperOrigin-RevId: 273368361

mlir/include/mlir/IR/Module.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/AsmPrinter.cpp

index e019572..cf09494 100644 (file)
@@ -65,7 +65,7 @@ public:
   Optional<StringRef> getName();
 
   /// Print the this module in the custom top-level form.
-  void print(raw_ostream &os);
+  void print(raw_ostream &os, OpPrintingFlags flags = llvm::None);
   void dump();
 
   //===--------------------------------------------------------------------===//
index dd82e7b..c500e73 100644 (file)
@@ -105,7 +105,9 @@ public:
   MLIRContext *getContext() { return getOperation()->getContext(); }
 
   /// Print the operation to the given stream.
-  void print(raw_ostream &os) { state->print(os); }
+  void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
+    state->print(os, flags);
+  }
 
   /// Dump this operation.
   void dump() { state->dump(); }
index 5444d6c..ff23f6a 100644 (file)
@@ -199,7 +199,7 @@ public:
   /// take O(N) where N is the number of operations within the parent block.
   bool isBeforeInBlock(Operation *other);
 
-  void print(raw_ostream &os);
+  void print(raw_ostream &os, OpPrintingFlags flags = llvm::None);
   void dump();
 
   //===--------------------------------------------------------------------===//
index 70d5476..5567af7 100644 (file)
@@ -452,6 +452,40 @@ private:
   }
 };
 } // end namespace detail
+
+/// Set of flags used to control the behavior of the various IR print methods
+/// (e.g. Operation::Print).
+class OpPrintingFlags {
+public:
+  OpPrintingFlags();
+  OpPrintingFlags(llvm::NoneType) : OpPrintingFlags() {}
+
+  /// Enable printing of debug information. If 'prettyForm' is set to true,
+  /// debug information is printed in a more readable 'pretty' form. Note: The
+  /// IR generated with 'prettyForm' is not parsable.
+  OpPrintingFlags &enableDebugInfo(bool prettyForm = false);
+
+  /// Always print operations in the generic form.
+  OpPrintingFlags &printGenericOpForm();
+
+  /// Return if debug information should be printed.
+  bool shouldPrintDebugInfo() const;
+
+  /// Return if debug information should be printed in the pretty form.
+  bool shouldPrintDebugInfoPrettyForm() const;
+
+  /// Return if operations should be printed in the generic form.
+  bool shouldPrintGenericOpForm() const;
+
+private:
+  /// Print debug information.
+  bool printDebugInfoFlag : 1;
+  bool printDebugInfoPrettyFormFlag : 1;
+
+  /// Print operations in the generic form.
+  bool printGenericOpFormFlag : 1;
+};
+
 } // end namespace mlir
 
 namespace llvm {
index ce79db0..a1cd863 100644 (file)
@@ -56,17 +56,15 @@ void OperationName::dump() const { print(llvm::errs()); }
 OpAsmPrinter::~OpAsmPrinter() {}
 
 //===----------------------------------------------------------------------===//
-// ModuleState
+// OpPrintingFlags
 //===----------------------------------------------------------------------===//
 
-// TODO(riverriddle) Rethink this flag when we have a pass that can remove debug
-// info or when we have a system for printer flags.
 static llvm::cl::opt<bool>
-    shouldPrintDebugInfoOpt("mlir-print-debuginfo",
-                            llvm::cl::desc("Print debug info in MLIR output"),
-                            llvm::cl::init(false));
+    printDebugInfoOpt("mlir-print-debuginfo",
+                      llvm::cl::desc("Print debug info in MLIR output"),
+                      llvm::cl::init(false));
 
-static llvm::cl::opt<bool> printPrettyDebugInfo(
+static llvm::cl::opt<bool> printPrettyDebugInfoOpt(
     "mlir-pretty-debuginfo",
     llvm::cl::desc("Print pretty debug info in MLIR output"),
     llvm::cl::init(false));
@@ -74,9 +72,48 @@ static llvm::cl::opt<bool> printPrettyDebugInfo(
 // Use the generic op output form in the operation printer even if the custom
 // form is defined.
 static llvm::cl::opt<bool>
-    printGenericOpForm("mlir-print-op-generic",
-                       llvm::cl::desc("Print the generic op form"),
-                       llvm::cl::init(false), llvm::cl::Hidden);
+    printGenericOpFormOpt("mlir-print-op-generic",
+                          llvm::cl::desc("Print the generic op form"),
+                          llvm::cl::init(false), llvm::cl::Hidden);
+
+/// Initialize the printing flags with default supplied by the cl::opts above.
+OpPrintingFlags::OpPrintingFlags()
+    : printDebugInfoFlag(printDebugInfoOpt),
+      printDebugInfoPrettyFormFlag(printPrettyDebugInfoOpt),
+      printGenericOpFormFlag(printGenericOpFormOpt) {}
+
+/// Enable printing of debug information. If 'prettyForm' is set to true,
+/// debug information is printed in a more readable 'pretty' form.
+OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
+  printDebugInfoFlag = true;
+  printDebugInfoPrettyFormFlag = prettyForm;
+  return *this;
+}
+
+/// Always print operations in the generic form.
+OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
+  printGenericOpFormFlag = true;
+  return *this;
+}
+
+/// Return if debug information should be printed.
+bool OpPrintingFlags::shouldPrintDebugInfo() const {
+  return printDebugInfoFlag;
+}
+
+/// Return if debug information should be printed in the pretty form.
+bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
+  return printDebugInfoPrettyFormFlag;
+}
+
+/// Return if operations should be printed in the generic form.
+bool OpPrintingFlags::shouldPrintGenericOpForm() const {
+  return printGenericOpFormFlag;
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleState
+//===----------------------------------------------------------------------===//
 
 namespace {
 /// A special index constant used for non-kind attribute aliases.
@@ -322,10 +359,12 @@ void ModuleState::initialize(Operation *op) {
 namespace {
 class ModulePrinter {
 public:
-  ModulePrinter(raw_ostream &os, ModuleState *state = nullptr)
-      : os(os), state(state) {}
+  ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
+                ModuleState *state = nullptr)
+      : os(os), printerFlags(flags), state(state) {}
   explicit ModulePrinter(ModulePrinter &printer)
-      : os(printer.os), state(printer.state) {}
+      : os(printer.os), printerFlags(printer.printerFlags),
+        state(printer.state) {}
 
   template <typename Container, typename UnaryFunctor>
   inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
@@ -370,6 +409,9 @@ protected:
   /// The output stream for the printer.
   raw_ostream &os;
 
+  /// A set of flags to control the printer's behavior.
+  OpPrintingFlags printerFlags;
+
   /// An optional printer state for the module.
   ModuleState *state;
 };
@@ -377,7 +419,7 @@ protected:
 
 void ModulePrinter::printTrailingLocation(Location loc) {
   // Check to see if we are printing debug information.
-  if (!shouldPrintDebugInfoOpt)
+  if (!printerFlags.shouldPrintDebugInfo())
     return;
 
   os << " ";
@@ -499,7 +541,7 @@ static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
 }
 
 void ModulePrinter::printLocation(LocationAttr loc) {
-  if (printPrettyDebugInfo) {
+  if (printerFlags.shouldPrintDebugInfoPrettyForm()) {
     printLocationInternal(loc, /*pretty=*/true);
   } else {
     os << "loc(";
@@ -1597,7 +1639,7 @@ void OperationPrinter::printOperation(Operation *op) {
 
   // TODO(riverriddle): FuncOp cannot be round-tripped currently, as
   // FunctionType cannot be used in a TypeAttr.
-  if (printGenericOpForm && !isa<FuncOp>(op))
+  if (printerFlags.shouldPrintGenericOpForm() && !isa<FuncOp>(op))
     return printGenericOp(op);
 
   // Check to see if this is a known operation.  If so, use the registered
@@ -1755,10 +1797,10 @@ void Value::dump() {
   llvm::errs() << "\n";
 }
 
-void Operation::print(raw_ostream &os) {
+void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
   // Handle top-level operations.
   if (!getParent()) {
-    ModulePrinter modulePrinter(os);
+    ModulePrinter modulePrinter(os, flags);
     OperationPrinter(this, modulePrinter).print(this);
     return;
   }
@@ -1774,7 +1816,7 @@ void Operation::print(raw_ostream &os) {
     region = nextRegion;
 
   ModuleState state(getContext());
-  ModulePrinter modulePrinter(os, &state);
+  ModulePrinter modulePrinter(os, flags, &state);
   OperationPrinter(region, modulePrinter).print(this);
 }
 
@@ -1795,7 +1837,7 @@ void Block::print(raw_ostream &os) {
     region = nextRegion;
 
   ModuleState state(region->getContext());
-  ModulePrinter modulePrinter(os, &state);
+  ModulePrinter modulePrinter(os, /*flags=*/llvm::None, &state);
   OperationPrinter(region, modulePrinter).print(this);
 }
 
@@ -1817,10 +1859,10 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
   OperationPrinter(region, modulePrinter).printBlockName(this);
 }
 
-void ModuleOp::print(raw_ostream &os) {
+void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
   ModuleState state(getContext());
   state.initialize(*this);
-  ModulePrinter(os, &state).print(*this);
+  ModulePrinter(os, flags, &state).print(*this);
 }
 
 void ModuleOp::dump() { print(llvm::errs()); }