Allow dialect to create friendly names for region arguments
authorFrank Laub <frank.laub@intel.com>
Fri, 20 Dec 2019 06:15:31 +0000 (22:15 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 20 Dec 2019 06:16:07 +0000 (22:16 -0800)
This is the block argument equivalent of the existing `getAsmResultNames` hook.

Closes tensorflow/mlir#329

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/329 from plaidml:flaub-region-arg-names fc7876f2d1335024e441083cd25263fd6247eb7d
PiperOrigin-RevId: 286523299

mlir/include/mlir/IR/OpImplementation.h
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/pretty-region-args.mlir [new file with mode: 0644]
mlir/test/lib/TestDialect/TestDialect.cpp

index 97569cc..7dd11d0 100644 (file)
@@ -661,6 +661,11 @@ public:
   /// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
   virtual void getAsmResultNames(Operation *op,
                                  OpAsmSetValueNameFn setNameFn) const {}
+
+  /// Get a special name to use when printing the entry block arguments of the
+  /// region contained by an operation in this dialect.
+  virtual void getAsmBlockArgumentNames(Block *block,
+                                        OpAsmSetValueNameFn setNameFn) const {}
 };
 
 //===--------------------------------------------------------------------===//
index e1903d5..f3c92ad 100644 (file)
@@ -1619,13 +1619,28 @@ void OperationPrinter::numberValuesInRegion(Region &region) {
 }
 
 void OperationPrinter::numberValuesInBlock(Block &block) {
+  auto setArgNameFn = [&](Value *arg, StringRef name) {
+    assert(!valueIDs.count(arg) && "arg numbered multiple times");
+    assert(cast<BlockArgument>(arg)->getOwner() == &block &&
+           "arg not defined in 'block'");
+    setValueName(arg, name);
+  };
+
   bool isEntryBlock = block.isEntryBlock();
+  if (isEntryBlock && state) {
+    if (auto *op = block.getParentOp()) {
+      if (auto dialectAsmInterface = state->getOpAsmInterface(op->getDialect()))
+        dialectAsmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
+    }
+  }
 
   // Number the block arguments. We give entry block arguments a special name
   // 'arg'.
   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
   llvm::raw_svector_ostream specialName(specialNameBuffer);
   for (auto *arg : block.getArguments()) {
+    if (valueIDs.count(arg))
+      continue;
     if (isEntryBlock) {
       specialNameBuffer.resize(strlen("arg"));
       specialName << nextArgumentID++;
diff --git a/mlir/test/IR/pretty-region-args.mlir b/mlir/test/IR/pretty-region-args.mlir
new file mode 100644 (file)
index 0000000..59a9ebc
--- /dev/null
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK-LABEL: func @custom_region_names
+func @custom_region_names() -> () {
+  "test.polyfor"() ( {
+  ^bb0(%arg0: index, %arg1: index, %arg2: index):
+    "foo"() : () -> ()
+  }) { arg_names = ["i", "j", "k"] } : () -> ()
+  // CHECK: test.polyfor
+  // CHECK-NEXT: ^bb{{.*}}(%i: index, %j: index, %k: index):
+  return
+}
index 059cfb3..7462db4 100644 (file)
@@ -41,6 +41,20 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
       setNameFn(asmOp, "result");
   }
+
+  void getAsmBlockArgumentNames(Block *block,
+                                OpAsmSetValueNameFn setNameFn) const final {
+    auto op = block->getParentOp();
+    auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
+    if (!arrayAttr)
+      return;
+    auto args = block->getArguments();
+    auto e = std::min(arrayAttr.size(), args.size());
+    for (unsigned i = 0; i < e; ++i) {
+      if (auto strAttr = arrayAttr.getValue()[i].dyn_cast<StringAttr>())
+        setNameFn(args[i], strAttr.getValue());
+    }
+  }
 };
 
 struct TestOpFolderDialectInterface : public OpFolderDialectInterface {