Add a flag to the AsmPrinter for eliding large ElementsAttrs.
authorRiver Riddle <riverriddle@google.com>
Tue, 8 Oct 2019 00:18:54 +0000 (17:18 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 8 Oct 2019 00:19:20 +0000 (17:19 -0700)
Some modules may have extremely large ElementsAttrs, which makes debugging involving IR dumping extremely slow and painful. This change adds a flag that will elide ElementsAttrs with a "large"(as defined by the user) number of elements by printing "..." instead of the element data.

PiperOrigin-RevId: 273413100

mlir/include/mlir/IR/OperationSupport.h
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/pretty-attributes.mlir [new file with mode: 0644]

index 5567af7..c932541 100644 (file)
@@ -460,6 +460,13 @@ public:
   OpPrintingFlags();
   OpPrintingFlags(llvm::NoneType) : OpPrintingFlags() {}
 
+  /// Enable the elision of large elements attributes, by printing a '...'
+  /// instead of the element data. Note: The IR generated with this option is
+  /// not parsable. `largeElementLimit` is used to configure what is considered
+  /// to be a "large" ElementsAttr by providing an upper limit to the number of
+  /// elements.
+  OpPrintingFlags &elideLargeElementsAttrs(int64_t largeElementLimit = 16);
+
   /// Enable printing of debug information. If 'prettyForm' is set to true,
   /// debug information is printed in a more readable 'pretty' form. Note: The
   /// IR generated with 'prettyForm' is not parsable.
@@ -468,6 +475,9 @@ public:
   /// Always print operations in the generic form.
   OpPrintingFlags &printGenericOpForm();
 
+  /// Return if the given ElementsAttr should be elided.
+  bool shouldElideElementsAttr(ElementsAttr attr) const;
+
   /// Return if debug information should be printed.
   bool shouldPrintDebugInfo() const;
 
@@ -478,6 +488,10 @@ public:
   bool shouldPrintGenericOpForm() const;
 
 private:
+  /// Elide large elements attributes if the number of elements is larger than
+  /// the upper limit.
+  llvm::Optional<int64_t> elementsAttrElementLimit;
+
   /// Print debug information.
   bool printDebugInfoFlag : 1;
   bool printDebugInfoPrettyFormFlag : 1;
index a1cd863..ea58f6c 100644 (file)
@@ -59,6 +59,12 @@ OpAsmPrinter::~OpAsmPrinter() {}
 // OpPrintingFlags
 //===----------------------------------------------------------------------===//
 
+static llvm::cl::opt<unsigned> elideElementsAttrIfLarger(
+    "mlir-elide-elementsattrs-if-larger",
+    llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
+                   "more elements than the given upper limit"),
+    llvm::cl::init(16));
+
 static llvm::cl::opt<bool>
     printDebugInfoOpt("mlir-print-debuginfo",
                       llvm::cl::desc("Print debug info in MLIR output"),
@@ -78,10 +84,24 @@ static llvm::cl::opt<bool>
 
 /// Initialize the printing flags with default supplied by the cl::opts above.
 OpPrintingFlags::OpPrintingFlags()
-    : printDebugInfoFlag(printDebugInfoOpt),
+    : elementsAttrElementLimit(
+          elideElementsAttrIfLarger.getNumOccurrences()
+              ? Optional<int64_t>(elideElementsAttrIfLarger)
+              : Optional<int64_t>()),
+      printDebugInfoFlag(printDebugInfoOpt),
       printDebugInfoPrettyFormFlag(printPrettyDebugInfoOpt),
       printGenericOpFormFlag(printGenericOpFormOpt) {}
 
+/// Enable the elision of large elements attributes, by printing a '...'
+/// instead of the element data, when the number of elements is greater than
+/// `largeElementLimit`. Note: The IR generated with this option is not
+/// parsable.
+OpPrintingFlags &
+OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
+  elementsAttrElementLimit = largeElementLimit;
+  return *this;
+}
+
 /// Enable printing of debug information. If 'prettyForm' is set to true,
 /// debug information is printed in a more readable 'pretty' form.
 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
@@ -96,6 +116,12 @@ OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
   return *this;
 }
 
+/// Return if the given ElementsAttr should be elided.
+bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
+  return elementsAttrElementLimit.hasValue() &&
+         *elementsAttrElementLimit < int64_t(attr.getNumElements());
+}
+
 /// Return if debug information should be printed.
 bool OpPrintingFlags::shouldPrintDebugInfo() const {
   return printDebugInfoFlag;
@@ -742,7 +768,14 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
   case StandardAttributes::OpaqueElements: {
     auto eltsAttr = attr.cast<OpaqueElementsAttr>();
     os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
-    os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
+    os << '"' << "0x";
+
+    // Check for large ElementsAttr elision.
+    if (printerFlags.shouldElideElementsAttr(eltsAttr))
+      os << "...";
+    else
+      os << llvm::toHex(eltsAttr.getValue());
+    os << "\">";
     break;
   }
   case StandardAttributes::DenseElements: {
@@ -814,6 +847,13 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
     return;
   }
 
+  // Check for large elements attr elision. We explicitly check *after* splat,
+  // as the splat printing is already elided.
+  if (printerFlags.shouldElideElementsAttr(attr)) {
+    os << "...";
+    return;
+  }
+
   // Special case for degenerate tensors.
   auto numElements = type.getNumElements();
   if (numElements == 0) {
diff --git a/mlir/test/IR/pretty-attributes.mlir b/mlir/test/IR/pretty-attributes.mlir
new file mode 100644 (file)
index 0000000..cadb2da
--- /dev/null
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -mlir-elide-elementsattrs-if-larger=2 | FileCheck %s
+
+// CHECK: dense<...> : tensor<3xi32>
+"test.dense_attr"() {foo.dense_attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> ()
+
+// CHECK: dense<[1, 2]> : tensor<2xi32>
+"test.non_elided_dense_attr"() {foo.dense_attr = dense<[1, 2]> : tensor<2xi32>} : () -> ()
+
+// CHECK: sparse<..., -2.{{0+}}e+00> : vector<1x1x1xf16>
+"test.sparse_attr"() {foo.sparse_attr = sparse<[[1, 2, 3]],  -2.0> : vector<1x1x1xf16>} : () -> ()