Give modules a name
authorAlex Zinenko <zinenko@google.com>
Thu, 3 Oct 2019 15:56:12 +0000 (08:56 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 3 Oct 2019 15:56:38 +0000 (08:56 -0700)
Modules are now Ops and, as such, can be nested. They do not produce an SSA
value so there is no possibility to refer to them in the IR. Introduce support
for symbol names attached to the module Op so that it can be referred to using
SymbolRefAttrs. The name is optional, for example the implicit top-level module
does not have a name.

PiperOrigin-RevId: 272671600

mlir/include/mlir/IR/Module.h
mlir/lib/IR/Module.cpp
mlir/test/IR/module-op.mlir

index 803de4a..40f27d5 100644 (file)
@@ -46,10 +46,11 @@ public:
 
   static StringRef getOperationName() { return "module"; }
 
-  static void build(Builder *builder, OperationState &result);
+  static void build(Builder *builder, OperationState &result,
+                    StringRef name = {});
 
-  /// Construct a module from the given location.
-  static ModuleOp create(Location loc);
+  /// Construct a module from the given location with an optional name.
+  static ModuleOp create(Location loc, StringRef name = {});
 
   /// Operation hooks.
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
@@ -60,6 +61,9 @@ public:
   Region &getBodyRegion();
   Block *getBody();
 
+  /// Return the name of this module if present.
+  StringRef getName();
+
   /// Print the this module in the custom top-level form.
   void print(raw_ostream &os);
   void dump();
index 6960962..990d4af 100644 (file)
@@ -25,19 +25,28 @@ using namespace mlir;
 // Module Operation.
 //===----------------------------------------------------------------------===//
 
-void ModuleOp::build(Builder *builder, OperationState &result) {
+void ModuleOp::build(Builder *builder, OperationState &result, StringRef name) {
   ensureTerminator(*result.addRegion(), *builder, result.location);
+  if (!name.empty())
+    result.attributes.push_back(
+        builder->getNamedAttr(mlir::SymbolTable::getSymbolAttrName(),
+                              builder->getSymbolRefAttr(name)));
 }
 
 /// Construct a module from the given context.
-ModuleOp ModuleOp::create(Location loc) {
+ModuleOp ModuleOp::create(Location loc, StringRef name) {
   OperationState state(loc, "module");
   Builder builder(loc->getContext());
-  ModuleOp::build(&builder, state);
+  ModuleOp::build(&builder, state, name);
   return llvm::cast<ModuleOp>(Operation::create(state));
 }
 
 ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) {
+  // If the name is present, parse it.
+  StringAttr nameAttr;
+  (void)parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
+                               result.attributes);
+
   // If module attributes are present, parse them.
   if (succeeded(parser.parseOptionalKeyword("attributes")))
     if (parser.parseOptionalAttributeDict(result.attributes))
@@ -56,11 +65,17 @@ ParseResult ModuleOp::parse(OpAsmParser &parser, OperationState &result) {
 void ModuleOp::print(OpAsmPrinter &p) {
   p << "module";
 
+  StringRef name = getName();
+  if (!name.empty())
+    p << " @" << name;
+
   // Print the module attributes.
   auto attrs = getAttrs();
-  if (!attrs.empty()) {
+  if (!attrs.empty() &&
+      !(attrs.size() == 1 && attrs.front().first.strref() ==
+                                 mlir::SymbolTable::getSymbolAttrName())) {
     p << " attributes";
-    p.printOptionalAttrDict(attrs, {});
+    p.printOptionalAttrDict(attrs, {mlir::SymbolTable::getSymbolAttrName()});
   }
 
   // Print the region.
@@ -80,9 +95,11 @@ LogicalResult ModuleOp::verify() {
   if (body->getNumArguments() != 0)
     return emitOpError("expected body to have no arguments");
 
-  // Check that none of the attributes are non-dialect attributes.
+  // Check that none of the attributes are non-dialect attributes, except for
+  // the symbol name attribute.
   for (auto attr : getOperation()->getAttrList().getAttrs()) {
-    if (!attr.first.strref().contains('.'))
+    if (!attr.first.strref().contains('.') &&
+        attr.first.strref() != mlir::SymbolTable::getSymbolAttrName())
       return emitOpError(
                  "can only contain dialect-specific attributes, found: '")
              << attr.first << "'";
@@ -94,3 +111,10 @@ LogicalResult ModuleOp::verify() {
 /// Return body of this module.
 Region &ModuleOp::getBodyRegion() { return getOperation()->getRegion(0); }
 Block *ModuleOp::getBody() { return &getBodyRegion().front(); }
+
+StringRef ModuleOp::getName() {
+  if (auto nameAttr =
+          getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
+    return nameAttr.getValue();
+  return {};
+}
index a4ca57d..84a404b 100644 (file)
@@ -43,3 +43,15 @@ module {
 // CHECK: }
 %0 = "op"() : () -> i32
 
+// -----
+
+// CHECK-LABEL: module @foo
+// CHECK-NOT: attributes
+module @foo {
+  // CHECK: module
+  module {
+    // CHECK: module @bar attributes
+    module @bar attributes {foo.bar} {
+    }
+  }
+}