From: Frank Laub Date: Fri, 20 Dec 2019 06:15:31 +0000 (-0800) Subject: Allow dialect to create friendly names for region arguments X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=7811ad3c2b312fb5eda5ed5f3a1d15b8e6085b24;p=platform%2Fupstream%2Fllvm.git Allow dialect to create friendly names for region arguments This is the block argument equivalent of the existing `getAsmResultNames` hook. Closes tensorflow/mlir#329 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/329 from plaidml:flaub-region-arg-names fc7876f2d1335024e441083cd25263fd6247eb7d PiperOrigin-RevId: 286523299 --- diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 97569cc..7dd11d0 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -661,6 +661,11 @@ public: /// OpAsmInterface.td#getAsmResultNames for usage details and documentation. virtual void getAsmResultNames(Operation *op, OpAsmSetValueNameFn setNameFn) const {} + + /// Get a special name to use when printing the entry block arguments of the + /// region contained by an operation in this dialect. + virtual void getAsmBlockArgumentNames(Block *block, + OpAsmSetValueNameFn setNameFn) const {} }; //===--------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index e1903d5..f3c92ad 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1619,13 +1619,28 @@ void OperationPrinter::numberValuesInRegion(Region ®ion) { } void OperationPrinter::numberValuesInBlock(Block &block) { + auto setArgNameFn = [&](Value *arg, StringRef name) { + assert(!valueIDs.count(arg) && "arg numbered multiple times"); + assert(cast(arg)->getOwner() == &block && + "arg not defined in 'block'"); + setValueName(arg, name); + }; + bool isEntryBlock = block.isEntryBlock(); + if (isEntryBlock && state) { + if (auto *op = block.getParentOp()) { + if (auto dialectAsmInterface = state->getOpAsmInterface(op->getDialect())) + dialectAsmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); + } + } // Number the block arguments. We give entry block arguments a special name // 'arg'. SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); llvm::raw_svector_ostream specialName(specialNameBuffer); for (auto *arg : block.getArguments()) { + if (valueIDs.count(arg)) + continue; if (isEntryBlock) { specialNameBuffer.resize(strlen("arg")); specialName << nextArgumentID++; diff --git a/mlir/test/IR/pretty-region-args.mlir b/mlir/test/IR/pretty-region-args.mlir new file mode 100644 index 0000000..59a9ebc --- /dev/null +++ b/mlir/test/IR/pretty-region-args.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt %s | FileCheck %s + +// CHECK-LABEL: func @custom_region_names +func @custom_region_names() -> () { + "test.polyfor"() ( { + ^bb0(%arg0: index, %arg1: index, %arg2: index): + "foo"() : () -> () + }) { arg_names = ["i", "j", "k"] } : () -> () + // CHECK: test.polyfor + // CHECK-NEXT: ^bb{{.*}}(%i: index, %j: index, %k: index): + return +} diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index 059cfb3..7462db4 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -41,6 +41,20 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { if (auto asmOp = dyn_cast(op)) setNameFn(asmOp, "result"); } + + void getAsmBlockArgumentNames(Block *block, + OpAsmSetValueNameFn setNameFn) const final { + auto op = block->getParentOp(); + auto arrayAttr = op->getAttrOfType("arg_names"); + if (!arrayAttr) + return; + auto args = block->getArguments(); + auto e = std::min(arrayAttr.size(), args.size()); + for (unsigned i = 0; i < e; ++i) { + if (auto strAttr = arrayAttr.getValue()[i].dyn_cast()) + setNameFn(args[i], strAttr.getValue()); + } + } }; struct TestOpFolderDialectInterface : public OpFolderDialectInterface {