Add support for nested symbol references.
authorRiver Riddle <riverriddle@google.com>
Tue, 12 Nov 2019 02:18:02 +0000 (18:18 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 12 Nov 2019 02:18:31 +0000 (18:18 -0800)
This change allows for adding additional nested references to a SymbolRefAttr to allow for further resolving a symbol if that symbol also defines a SymbolTable. If a referenced symbol also defines a symbol table, a nested reference can be used to refer to a symbol within that table. Nested references are printed after the main reference in the following form:

  symbol-ref-attribute ::= symbol-ref-id (`::` symbol-ref-id)*

Example:

  module @reference {
    func @nested_reference()
  }

  my_reference_op @reference::@nested_reference

Given that SymbolRefAttr is now more general, the existing functionality centered around a single reference is moved to a derived class FlatSymbolRefAttr. Followup commits will add support to lookups, rauw, etc. for scoped references.

PiperOrigin-RevId: 279860501

38 files changed:
mlir/examples/toy/Ch2/include/toy/Ops.td
mlir/examples/toy/Ch3/include/toy/Ops.td
mlir/examples/toy/Ch4/include/toy/Ops.td
mlir/examples/toy/Ch5/include/toy/Ops.td
mlir/examples/toy/Ch6/include/toy/Ops.td
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/include/toy/Ops.td
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/g3doc/LangRef.md
mlir/g3doc/Tutorials/Toy/Ch-6.md
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/include/mlir/Dialect/StandardOps/Ops.td
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/Function.h
mlir/include/mlir/IR/OpBase.td
mlir/lib/Analysis/CallGraph.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/AttributeDetail.h
mlir/lib/IR/Attributes.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/FunctionSupport.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/IR/parser.mlir
mlir/test/lib/TestDialect/TestOps.td
mlir/test/mlir-tblgen/op-attribute.td

index 799813b..fb41818 100644 (file)
@@ -123,7 +123,7 @@ def GenericCallOp : Toy_Op<"generic_call"> {
 
   // The generic call operation takes a symbol reference attribute as the
   // callee, and inputs for the call.
-  let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
 
   // The generic call operation returns a single value of TensorType.
   let results = (outs F64Tensor);
index 0be2b66..25bf06c 100644 (file)
@@ -123,7 +123,7 @@ def GenericCallOp : Toy_Op<"generic_call"> {
 
   // The generic call operation takes a symbol reference attribute as the
   // callee, and inputs for the call.
-  let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
 
   // The generic call operation returns a single value of TensorType.
   let results = (outs F64Tensor);
index 27fdc34..bbbcc48 100644 (file)
@@ -148,7 +148,7 @@ def GenericCallOp : Toy_Op<"generic_call",
 
   // The generic call operation takes a symbol reference attribute as the
   // callee, and inputs for the call.
-  let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
 
   // The generic call operation returns a single value of TensorType.
   let results = (outs F64Tensor);
index f609a14..d9306f4 100644 (file)
@@ -148,7 +148,7 @@ def GenericCallOp : Toy_Op<"generic_call",
 
   // The generic call operation takes a symbol reference attribute as the
   // callee, and inputs for the call.
-  let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
 
   // The generic call operation returns a single value of TensorType.
   let results = (outs F64Tensor);
index f609a14..d9306f4 100644 (file)
@@ -148,7 +148,7 @@ def GenericCallOp : Toy_Op<"generic_call",
 
   // The generic call operation takes a symbol reference attribute as the
   // callee, and inputs for the call.
-  let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
 
   // The generic call operation returns a single value of TensorType.
   let results = (outs F64Tensor);
index 7e300fb..091eada 100644 (file)
@@ -107,9 +107,9 @@ public:
 private:
   /// Return a symbol reference to the printf function, inserting it into the
   /// module if necessary.
-  static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
-                                         ModuleOp module,
-                                         LLVM::LLVMDialect *llvmDialect) {
+  static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
+                                             ModuleOp module,
+                                             LLVM::LLVMDialect *llvmDialect) {
     auto *context = module.getContext();
     if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
       return SymbolRefAttr::get("printf", context);
index 5e932bb..d41406a 100644 (file)
@@ -160,7 +160,7 @@ def GenericCallOp : Toy_Op<"generic_call",
 
   // The generic call operation takes a symbol reference attribute as the
   // callee, and inputs for the call.
-  let arguments = (ins SymbolRefAttr:$callee, Variadic<Toy_Type>:$inputs);
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<Toy_Type>:$inputs);
 
   // The generic call operation returns a single value of TensorType or
   // StructType.
index 7e300fb..091eada 100644 (file)
@@ -107,9 +107,9 @@ public:
 private:
   /// Return a symbol reference to the printf function, inserting it into the
   /// module if necessary.
-  static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
-                                         ModuleOp module,
-                                         LLVM::LLVMDialect *llvmDialect) {
+  static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
+                                             ModuleOp module,
+                                             LLVM::LLVMDialect *llvmDialect) {
     auto *context = module.getContext();
     if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
       return SymbolRefAttr::get("printf", context);
index 391f773..3409b9f 100644 (file)
@@ -1367,13 +1367,15 @@ A string attribute is an attribute that represents a string literal value.
 Syntax:
 
 ``` {.ebnf}
-symbol-ref-attribute ::= symbol-ref-id
+symbol-ref-attribute ::= symbol-ref-id (`::` symbol-ref-id)*
 ```
 
 A symbol reference attribute is a literal attribute that represents a named
 reference to an operation that is nested within an operation with the
 `OpTrait::SymbolTable` trait. As such, this reference is given meaning by the
-nearest parent operation containing the `OpTrait::SymbolTable` trait.
+nearest parent operation containing the `OpTrait::SymbolTable` trait. It may
+optionally contain a set of nested references that further resolve to a symbol
+nested within a different symbol table.
 
 This attribute can only be held internally by
 [array attributes](#array-attribute) and
index b01dfde..49114b4 100644 (file)
@@ -26,9 +26,9 @@ During lowering we can get, or build, the declaration for printf as so:
 ```c++
 /// Return a symbol reference to the printf function, inserting it into the
 /// module if necessary.
-static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
-                                       ModuleOp module,
-                                       LLVM::LLVMDialect *llvmDialect) {
+static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
+                                           ModuleOp module,
+                                           LLVM::LLVMDialect *llvmDialect) {
   auto *context = module.getContext();
   if (module.lookupSymbol<LLVM::LLVMFuncOp>("printf"))
     return SymbolRefAttr::get("printf", context);
index 2950926..ba337e2 100644 (file)
@@ -341,7 +341,7 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">;
 
 // Call-related operations.
 def LLVM_CallOp : LLVM_Op<"call">,
-                  Arguments<(ins OptionalAttr<SymbolRefAttr>:$callee,
+                  Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                              Variadic<LLVM_Type>)>,
                   Results<(outs Variadic<LLVM_Type>)>,
                   LLVM_TwoBuilders<LLVM_OneResultOpBuilder,
@@ -479,7 +479,7 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
 // to work correctly).
 def LLVM_AddressOfOp
     : LLVM_OneResultOp<"mlir.addressof">,
-      Arguments<(ins SymbolRefAttr:$global_name)> {
+      Arguments<(ins FlatSymbolRefAttr:$global_name)> {
   let builders = [
     OpBuilder<"Builder *builder, OperationState &result, LLVMType resType, "
               "StringRef name, ArrayRef<NamedAttribute> attrs = {}", [{
index 467cd08..dd16498 100644 (file)
@@ -376,7 +376,7 @@ class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
                    I64ArrayAttr:$n_loop_types,
                    I64ArrayAttr:$n_views,
                    OptionalAttr<StrAttr>:$doc,
-                   OptionalAttr<SymbolRefAttr>:$fun,
+                   OptionalAttr<FlatSymbolRefAttr>:$fun,
                    OptionalAttr<StrAttr>:$library_call);
   let regions = (region AnyRegion:$region);
   let extraClassDeclaration = [{
@@ -464,7 +464,7 @@ def GenericOp : GenericOpBase<"generic"> {
 
     Where #trait_attributes is an alias of a dictionary attribute containing:
       - doc [optional]: a documentation string
-      - fun: a SymbolRefAttr that must resolve to an existing function symbol.
+      - fun: a FlatSymbolRefAttr that must resolve to an existing function symbol.
         To support inplace updates in a generic fashion, the signature of the
         function must be:
         ```
@@ -558,7 +558,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
 
     Where #trait_attributes is an alias of a dictionary attribute containing:
       - doc [optional]: a documentation string
-      - fun: a SymbolRefAttr that must resolve to an existing function symbol.
+      - fun: a FlatSymbolRefAttr that must resolve to an existing function symbol.
         To support inplace updates in a generic fashion, the signature of the
         function must be:
         ```
index 070d6a6..8de2aeb 100644 (file)
@@ -250,7 +250,7 @@ def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [
   }];
 
   let arguments = (ins
-    SymbolRefAttr:$callee,
+    FlatSymbolRefAttr:$callee,
     Variadic<SPV_Type>:$arguments
   );
 
index 0e135c5..9be4898 100644 (file)
@@ -273,7 +273,7 @@ def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> {
   }];
 
   let arguments = (ins
-    SymbolRefAttr:$fn,
+    FlatSymbolRefAttr:$fn,
     SPV_ExecutionModeAttr:$execution_mode,
     I32ArrayAttr:$values
   );
index aadd179..fab97bd 100644 (file)
@@ -55,7 +55,7 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> {
   }];
 
   let arguments = (ins
-    SymbolRefAttr:$variable
+    FlatSymbolRefAttr:$variable
   );
 
   let results = (outs
@@ -174,7 +174,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> {
 
   let arguments = (ins
     SPV_ExecutionModelAttr:$execution_model,
-    SymbolRefAttr:$fn,
+    FlatSymbolRefAttr:$fn,
     SymbolRefArrayAttr:$interface
   );
 
@@ -237,7 +237,7 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope, Symbol]> {
   let arguments = (ins
     TypeAttr:$type,
     StrAttr:$sym_name,
-    OptionalAttr<SymbolRefAttr>:$initializer
+    OptionalAttr<FlatSymbolRefAttr>:$initializer
   );
 
   let builders = [
@@ -394,7 +394,7 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> {
   }];
 
   let arguments = (ins
-    SymbolRefAttr:$spec_const
+    FlatSymbolRefAttr:$spec_const
   );
 
   let results = (outs
index d7de155..fa72306 100644 (file)
@@ -239,14 +239,15 @@ def BranchOp : Std_Op<"br", [Terminator]> {
 def CallOp : Std_Op<"call", [CallOpInterface]> {
   let summary = "call operation";
   let description = [{
-    The "call" operation represents a direct call to a function.  The operands
-    and result types of the call must match the specified function type.  The
-    callee is encoded as a function attribute named "callee".
+    The "call" operation represents a direct call to a function that is within
+    the same symbol scope as the call.  The operands and result types of the
+    call must match the specified function type. The callee is encoded as a
+    function attribute named "callee".
 
       %2 = call @my_add(%0, %1) : (f32, f32) -> f32
   }];
 
-  let arguments = (ins SymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
   let results = (outs Variadic<AnyType>);
 
   let builders = [OpBuilder<
index 5d98f6e..8a5e3b5 100644 (file)
@@ -44,6 +44,7 @@ struct IntegerSetAttributeStorage;
 struct FloatAttributeStorage;
 struct OpaqueAttributeStorage;
 struct StringAttributeStorage;
+struct SymbolRefAttributeStorage;
 struct TypeAttributeStorage;
 
 /// Elements Attributes.
@@ -179,6 +180,10 @@ enum Kind {
 };
 } // namespace StandardAttributes
 
+//===----------------------------------------------------------------------===//
+// AffineMapAttr
+//===----------------------------------------------------------------------===//
+
 class AffineMapAttr
     : public Attribute::AttrBase<AffineMapAttr, Attribute,
                                  detail::AffineMapAttributeStorage> {
@@ -196,6 +201,10 @@ public:
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ArrayAttr
+//===----------------------------------------------------------------------===//
+
 /// Array attributes are lists of other attributes.  They are not necessarily
 /// type homogenous given that attributes don't, in general, carry types.
 class ArrayAttr : public Attribute::AttrBase<ArrayAttr, Attribute,
@@ -220,6 +229,10 @@ public:
   }
 };
 
+//===----------------------------------------------------------------------===//
+// BoolAttr
+//===----------------------------------------------------------------------===//
+
 class BoolAttr : public Attribute::AttrBase<BoolAttr, Attribute,
                                             detail::BoolAttributeStorage> {
 public:
@@ -234,6 +247,10 @@ public:
   static bool kindof(unsigned kind) { return kind == StandardAttributes::Bool; }
 };
 
+//===----------------------------------------------------------------------===//
+// DictionaryAttr
+//===----------------------------------------------------------------------===//
+
 /// NamedAttribute is used for dictionary attributes, it holds an identifier for
 /// the name and a value for the attribute. The attribute pointer should always
 /// be non-null.
@@ -271,6 +288,10 @@ public:
   }
 };
 
+//===----------------------------------------------------------------------===//
+// FloatAttr
+//===----------------------------------------------------------------------===//
+
 class FloatAttr : public Attribute::AttrBase<FloatAttr, Attribute,
                                              detail::FloatAttributeStorage> {
 public:
@@ -308,6 +329,10 @@ public:
                                Type type, const APFloat &value);
 };
 
+//===----------------------------------------------------------------------===//
+// IntegerAttr
+//===----------------------------------------------------------------------===//
+
 class IntegerAttr
     : public Attribute::AttrBase<IntegerAttr, Attribute,
                                  detail::IntegerAttributeStorage> {
@@ -328,6 +353,10 @@ public:
   }
 };
 
+//===----------------------------------------------------------------------===//
+// IntegerSetAttr
+//===----------------------------------------------------------------------===//
+
 class IntegerSetAttr
     : public Attribute::AttrBase<IntegerSetAttr, Attribute,
                                  detail::IntegerSetAttributeStorage> {
@@ -345,6 +374,10 @@ public:
   }
 };
 
+//===----------------------------------------------------------------------===//
+// OpaqueAttr
+//===----------------------------------------------------------------------===//
+
 /// Opaque attributes represent attributes of non-registered dialects. These are
 /// attribute represented in their raw string form, and can only usefully be
 /// tested for attribute equality.
@@ -380,6 +413,10 @@ public:
   }
 };
 
+//===----------------------------------------------------------------------===//
+// StringAttr
+//===----------------------------------------------------------------------===//
+
 class StringAttr : public Attribute::AttrBase<StringAttr, Attribute,
                                               detail::StringAttributeStorage> {
 public:
@@ -400,19 +437,40 @@ public:
   }
 };
 
+//===----------------------------------------------------------------------===//
+// SymbolRefAttr
+//===----------------------------------------------------------------------===//
+
+class FlatSymbolRefAttr;
+
 /// A symbol reference attribute represents a symbolic reference to another
 /// operation.
 class SymbolRefAttr
     : public Attribute::AttrBase<SymbolRefAttr, Attribute,
-                                 detail::StringAttributeStorage> {
+                                 detail::SymbolRefAttributeStorage> {
 public:
   using Base::Base;
-  using ValueType = StringRef;
 
-  static SymbolRefAttr get(StringRef value, MLIRContext *ctx);
+  /// Construct a symbol reference for the given value name.
+  static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx);
 
-  /// Returns the name of the held symbol reference.
-  StringRef getValue() const;
+  /// Construct a symbol reference for the given value name, and a set of nested
+  /// references that are further resolve to a nested symbol.
+  static SymbolRefAttr get(StringRef value,
+                           ArrayRef<FlatSymbolRefAttr> references,
+                           MLIRContext *ctx);
+
+  /// Returns the name of the top level symbol reference, i.e. the root of the
+  /// reference path.
+  StringRef getRootReference() const;
+
+  /// Returns the name of the fully resolved symbol, i.e. the leaf of the
+  /// reference path.
+  StringRef getLeafReference() const;
+
+  /// Returns the set of nested references representing the path to the symbol
+  /// nested under the root reference.
+  ArrayRef<FlatSymbolRefAttr> getNestedReferences() const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool kindof(unsigned kind) {
@@ -420,6 +478,36 @@ public:
   }
 };
 
+/// A symbol reference with a reference path containing a single element. This
+/// is used to refer to an operation within the current symbol table.
+class FlatSymbolRefAttr : public SymbolRefAttr {
+public:
+  using SymbolRefAttr::SymbolRefAttr;
+  using ValueType = StringRef;
+
+  /// Construct a symbol reference for the given value name.
+  static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) {
+    return SymbolRefAttr::get(value, ctx);
+  }
+
+  /// Returns the name of the held symbol reference.
+  StringRef getValue() const { return getRootReference(); }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(Attribute attr) {
+    SymbolRefAttr refAttr = attr.dyn_cast<SymbolRefAttr>();
+    return refAttr && refAttr.getNestedReferences().empty();
+  }
+
+private:
+  using SymbolRefAttr::get;
+  using SymbolRefAttr::getNestedReferences;
+};
+
+//===----------------------------------------------------------------------===//
+// Type
+//===----------------------------------------------------------------------===//
+
 class TypeAttr : public Attribute::AttrBase<TypeAttr, Attribute,
                                             detail::TypeAttributeStorage> {
 public:
@@ -434,6 +522,10 @@ public:
   static bool kindof(unsigned kind) { return kind == StandardAttributes::Type; }
 };
 
+//===----------------------------------------------------------------------===//
+// UnitAttr
+//===----------------------------------------------------------------------===//
+
 /// Unit attributes are attributes that hold no specific value and are given
 /// meaning by their existence.
 class UnitAttr : public Attribute::AttrBase<UnitAttr> {
index 0005a39..01ad38c 100644 (file)
@@ -100,8 +100,10 @@ public:
   FloatAttr getFloatAttr(Type type, const APFloat &value);
   StringAttr getStringAttr(StringRef bytes);
   ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
-  SymbolRefAttr getSymbolRefAttr(Operation *value);
-  SymbolRefAttr getSymbolRefAttr(StringRef value);
+  FlatSymbolRefAttr getSymbolRefAttr(Operation *value);
+  FlatSymbolRefAttr getSymbolRefAttr(StringRef value);
+  SymbolRefAttr getSymbolRefAttr(StringRef value,
+                                 ArrayRef<FlatSymbolRefAttr> nestedReferences);
 
   // Returns a 0-valued attribute of the given `type`. This function only
   // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
index 0f435b2..228b030 100644 (file)
@@ -119,7 +119,7 @@ public:
   /// to. This may return null in the case of an external callable object, e.g.
   /// an external function.
   Region *getCallableRegion(CallInterfaceCallable callable) {
-    assert(callable.get<SymbolRefAttr>().getValue() == getName());
+    assert(callable.get<SymbolRefAttr>().getLeafReference() == getName());
     return isExternal() ? nullptr : &getBody();
   }
 
index 62d9542..2f67560 100644 (file)
@@ -1139,8 +1139,16 @@ class StructAttr<string name, Dialect dialect,
 def SymbolRefAttr : Attr<CPred<"$_self.isa<SymbolRefAttr>()">,
                         "symbol reference attribute"> {
   let storageType = [{ SymbolRefAttr }];
+  let returnType = [{ SymbolRefAttr }];
+  let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
+  let convertFromStorage = "$_self";
+}
+def FlatSymbolRefAttr : Attr<CPred<"$_self.isa<FlatSymbolRefAttr>()">,
+                                   "flat symbol reference attribute"> {
+  let storageType = [{ FlatSymbolRefAttr }];
   let returnType = [{ StringRef }];
   let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
+  let convertFromStorage = "$_self.getValue()";
 }
 
 def SymbolRefArrayAttr :
@@ -1241,12 +1249,14 @@ class IntArrayNthElemMinValue<int index, int min> : AttrConstraint<
 def IsNullAttr : AttrConstraint<
     CPred<"!$_self">, "empty attribute (for optional attributes)">;
 
-// An attribute constraint on SymbolRefAttr that requires the SymbolRefAttr
-// pointing to an op of `opClass` within the closest parent with a symbol table.
+// An attribute constraint on FlatSymbolRefAttr that requires that the
+// reference point to an op of `opClass` within the closest parent with a symbol
+// table.
+// TODO(riverriddle) Add support for nested symbol references.
 class ReferToOp<string opClass> : AttrConstraint<
     CPred<"isa_and_nonnull<" # opClass # ">("
             "::mlir::SymbolTable::lookupNearestSymbolFrom("
-              "&$_op, $_self.cast<SymbolRefAttr>().getValue()))">,
+              "&$_op, $_self.cast<FlatSymbolRefAttr>().getValue()))">,
     "referencing to a '" # opClass # "' symbol">;
 
 //===----------------------------------------------------------------------===//
index 3c02802..2b5894f 100644 (file)
@@ -184,7 +184,9 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable,
   // Get the callee operation from the callable.
   Operation *callee;
   if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
-    callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef.getValue());
+    // TODO(riverriddle) Support nested references.
+    callee = SymbolTable::lookupNearestSymbolFrom(from,
+                                                  symbolRef.getRootReference());
   else
     callee = callable.get<Value *>()->getDefiningOp();
 
index d70d51f..bfd094d 100644 (file)
@@ -533,7 +533,8 @@ unsigned LaunchFuncOp::getNumKernelOperands() {
 }
 
 StringRef LaunchFuncOp::getKernelModuleName() {
-  return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName()).getValue();
+  return getAttrOfType<SymbolRefAttr>(getKernelModuleAttrName())
+      .getRootReference();
 }
 
 Value *LaunchFuncOp::getKernelOperand(unsigned i) {
index 672beee..420b234 100644 (file)
@@ -189,7 +189,8 @@ private:
       if (Optional<SymbolTable::UseRange> symbolUses =
               SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
         for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
-          StringRef symbolName = symbolUse.getSymbolRef().getValue();
+          StringRef symbolName =
+              symbolUse.getSymbolRef().cast<FlatSymbolRefAttr>().getValue();
           if (moduleManager.lookupSymbol(symbolName))
             continue;
 
index 7a8bc71..2dc46bf 100644 (file)
@@ -359,8 +359,8 @@ public:
 // Get a SymbolRefAttr containing the library function name for the LinalgOp.
 // If the library function does not exist, insert a declaration.
 template <typename LinalgOp>
-static SymbolRefAttr getLibraryCallSymbolRef(Operation *op,
-                                             PatternRewriter &rewriter) {
+static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
+                                                 PatternRewriter &rewriter) {
   auto linalgOp = cast<LinalgOp>(op);
   auto fnName = linalgOp.getLibraryCallName();
   if (fnName.empty()) {
@@ -369,7 +369,7 @@ static SymbolRefAttr getLibraryCallSymbolRef(Operation *op,
   }
 
   // fnName is a dynamic std::String, unique it via a SymbolRefAttr.
-  SymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
+  FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
   auto module = op->getParentOfType<ModuleOp>();
   if (module.lookupSymbol(fnName)) {
     return fnNameAttr;
index b3a9f6f..3c1563e 100644 (file)
@@ -658,7 +658,7 @@ void spirv::AddressOfOp::build(Builder *builder, OperationState &state,
 
 static ParseResult parseAddressOfOp(OpAsmParser &parser,
                                     OperationState &state) {
-  SymbolRefAttr varRefAttr;
+  FlatSymbolRefAttr varRefAttr;
   Type type;
   if (parser.parseAttribute(varRefAttr, Type(), kVariableAttrName,
                             state.attributes) ||
@@ -1088,7 +1088,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
   SmallVector<Type, 0> idTypes;
   SmallVector<Attribute, 4> interfaceVars;
 
-  SymbolRefAttr fn;
+  FlatSymbolRefAttr fn;
   if (parseEnumAttribute(execModel, parser, state) ||
       parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
     return failure();
@@ -1099,7 +1099,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser,
     do {
       // The name of the interface variable attribute isnt important
       auto attrName = "var_symbol";
-      SymbolRefAttr var;
+      FlatSymbolRefAttr var;
       SmallVector<NamedAttribute, 1> attrs;
       if (parser.parseAttribute(var, Type(), attrName, attrs)) {
         return failure();
@@ -1186,7 +1186,7 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
 
 static ParseResult parseFunctionCallOp(OpAsmParser &parser,
                                        OperationState &state) {
-  SymbolRefAttr calleeAttr;
+  FlatSymbolRefAttr calleeAttr;
   FunctionType type;
   SmallVector<OpAsmParser::OperandType, 4> operands;
   auto loc = parser.getNameLoc();
@@ -1305,7 +1305,7 @@ static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
 
   // Parse optional initializer
   if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
-    SymbolRefAttr initSymbol;
+    FlatSymbolRefAttr initSymbol;
     if (parser.parseLParen() ||
         parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
                               state.attributes) ||
@@ -1361,7 +1361,8 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
   if (varOp.storageClass() == spirv::StorageClass::Generic)
     return varOp.emitOpError("storage class cannot be 'Generic'");
 
-  if (auto init = varOp.getAttrOfType<SymbolRefAttr>(kInitializerAttrName)) {
+  if (auto init =
+          varOp.getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
     auto moduleOp = varOp.getParentOfType<spirv::ModuleOp>();
     auto *initOp = moduleOp.lookupSymbol(init.getValue());
     // TODO: Currently only variable initialization with specialization
@@ -1713,7 +1714,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
         }
         if (auto interface = entryPointOp.interface()) {
           for (Attribute varRef : interface) {
-            auto varSymRef = varRef.dyn_cast<SymbolRefAttr>();
+            auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
             if (!varSymRef) {
               return entryPointOp.emitError(
                          "expected symbol reference for interface "
@@ -1790,7 +1791,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
 
 static ParseResult parseReferenceOfOp(OpAsmParser &parser,
                                       OperationState &state) {
-  SymbolRefAttr constRefAttr;
+  FlatSymbolRefAttr constRefAttr;
   Type type;
   if (parser.parseAttribute(constRefAttr, Type(), kSpecConstAttrName,
                             state.attributes) ||
index 11660ed..40b5318 100644 (file)
@@ -970,7 +970,7 @@ LogicalResult Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
   wordIndex++;
 
   // Initializer.
-  SymbolRefAttr initializer = nullptr;
+  FlatSymbolRefAttr initializer = nullptr;
   if (wordIndex < operands.size()) {
     auto initializerOp = getGlobalVariable(operands[wordIndex]);
     if (!initializerOp) {
index f92b9ae..805a339 100644 (file)
@@ -1696,7 +1696,7 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
   // Add the interface values.
   if (auto interface = op.interface()) {
     for (auto var : interface.getValue()) {
-      auto id = getVariableID(var.cast<SymbolRefAttr>().getValue());
+      auto id = getVariableID(var.cast<FlatSymbolRefAttr>().getValue());
       if (!id) {
         return op.emitError("referencing undefined global variable."
                             "spv.EntryPoint is at the end of spv.module. All "
index 1202924..8c08868 100644 (file)
@@ -528,7 +528,7 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
 //===----------------------------------------------------------------------===//
 
 static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
-  SymbolRefAttr calleeAttr;
+  FlatSymbolRefAttr calleeAttr;
   FunctionType calleeType;
   SmallVector<OpAsmParser::OperandType, 4> operands;
   auto calleeLoc = parser.getNameLoc();
@@ -555,7 +555,7 @@ static void print(OpAsmPrinter &p, CallOp op) {
 
 static LogicalResult verify(CallOp op) {
   // Check that the callee attribute was specified.
-  auto fnAttr = op.getAttrOfType<SymbolRefAttr>("callee");
+  auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
   if (!fnAttr)
     return op.emitOpError("requires a 'callee' symbol reference attribute");
   auto fn =
@@ -608,8 +608,8 @@ struct SimplifyIndirectCallWithKnownCallee
     // Replace with a direct call.
     SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
     SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
-    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(),
-                                        callResults, callOperands);
+    rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, callResults,
+                                        callOperands);
     return matchSuccess();
   }
 };
@@ -1206,7 +1206,7 @@ static LogicalResult verify(ConstantOp &op) {
   }
 
   if (type.isa<FunctionType>()) {
-    auto fnAttr = value.dyn_cast<SymbolRefAttr>();
+    auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
     if (!fnAttr)
       return op.emitOpError("requires 'value' to be a function reference");
 
index 43452b2..6f77de0 100644 (file)
@@ -801,9 +801,15 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
   case StandardAttributes::Type:
     printType(attr.cast<TypeAttr>().getValue());
     break;
-  case StandardAttributes::SymbolRef:
-    printSymbolReference(attr.cast<SymbolRefAttr>().getValue(), os);
+  case StandardAttributes::SymbolRef: {
+    auto refAttr = attr.dyn_cast<SymbolRefAttr>();
+    printSymbolReference(refAttr.getRootReference(), os);
+    for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
+      os << "::";
+      printSymbolReference(nestedRef.getValue(), os);
+    }
     break;
+  }
   case StandardAttributes::OpaqueElements: {
     auto eltsAttr = attr.cast<OpaqueElementsAttr>();
     os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
index 21f8b68..da4aa69 100644 (file)
@@ -321,6 +321,43 @@ struct StringAttributeStorage : public AttributeStorage {
   StringRef value;
 };
 
+/// An attribute representing a symbol reference.
+struct SymbolRefAttributeStorage final
+    : public AttributeStorage,
+      public llvm::TrailingObjects<SymbolRefAttributeStorage,
+                                   FlatSymbolRefAttr> {
+  using KeyTy = std::pair<StringRef, ArrayRef<FlatSymbolRefAttr>>;
+
+  SymbolRefAttributeStorage(StringRef value, size_t numNestedRefs)
+      : value(value), numNestedRefs(numNestedRefs) {}
+
+  /// Key equality function.
+  bool operator==(const KeyTy &key) const {
+    return key == KeyTy(value, getNestedRefs());
+  }
+
+  /// Construct a new storage instance.
+  static SymbolRefAttributeStorage *
+  construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
+    auto size = SymbolRefAttributeStorage::totalSizeToAlloc<FlatSymbolRefAttr>(
+        key.second.size());
+    auto rawMem = allocator.allocate(size, alignof(SymbolRefAttributeStorage));
+    auto result = ::new (rawMem) SymbolRefAttributeStorage(
+        allocator.copyInto(key.first), key.second.size());
+    std::uninitialized_copy(key.second.begin(), key.second.end(),
+                            result->getTrailingObjects<FlatSymbolRefAttr>());
+    return result;
+  }
+
+  /// Returns the set of nested references.
+  ArrayRef<FlatSymbolRefAttr> getNestedRefs() const {
+    return {getTrailingObjects<FlatSymbolRefAttr>(), numNestedRefs};
+  }
+
+  StringRef value;
+  size_t numNestedRefs;
+};
+
 /// An attribute representing a reference to a type.
 struct TypeAttributeStorage : public AttributeStorage {
   using KeyTy = Type;
index d74cacb..80ac4a5 100644 (file)
@@ -249,12 +249,27 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
 // SymbolRefAttr
 //===----------------------------------------------------------------------===//
 
-SymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
-  return Base::get(ctx, StandardAttributes::SymbolRef, value,
-                   NoneType::get(ctx));
+FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) {
+  return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None)
+      .cast<FlatSymbolRefAttr>();
 }
 
-StringRef SymbolRefAttr::getValue() const { return getImpl()->value; }
+SymbolRefAttr SymbolRefAttr::get(StringRef value,
+                                 ArrayRef<FlatSymbolRefAttr> nestedReferences,
+                                 MLIRContext *ctx) {
+  return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences);
+}
+
+StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; }
+
+StringRef SymbolRefAttr::getLeafReference() const {
+  ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
+  return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
+}
+
+ArrayRef<FlatSymbolRefAttr> SymbolRefAttr::getNestedReferences() const {
+  return getImpl()->getNestedRefs();
+}
 
 //===----------------------------------------------------------------------===//
 // IntegerAttr
index 24ae207..afdeefd 100644 (file)
@@ -150,15 +150,20 @@ ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
   return ArrayAttr::get(value, context);
 }
 
-SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
+FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
   auto symName =
       value->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
   assert(symName && "value does not have a valid symbol name");
   return getSymbolRefAttr(symName.getValue());
 }
-SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
+FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
   return SymbolRefAttr::get(value, getContext());
 }
+SymbolRefAttr
+Builder::getSymbolRefAttr(StringRef value,
+                          ArrayRef<FlatSymbolRefAttr> nestedReferences) {
+  return SymbolRefAttr::get(value, nestedReferences, getContext());
+}
 
 ArrayAttr Builder::getI32ArrayAttr(ArrayRef<int32_t> values) {
   auto attrs = functional::map(
index d1ba2d3..29cae17 100644 (file)
@@ -159,7 +159,7 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result,
   auto &builder = parser.getBuilder();
 
   // Parse the name as a symbol reference attribute.
-  SymbolRefAttr nameAttr;
+  FlatSymbolRefAttr nameAttr;
   if (parser.parseAttribute(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(),
                             result.attributes))
     return failure();
index 35c694b..2843aae 100644 (file)
@@ -1400,7 +1400,7 @@ static std::string extractSymbolReference(Token tok) {
 ///                    | type
 ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
 ///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
-///                    | symbol-ref-id
+///                    | symbol-ref-id (`::` symbol-ref-id)*
 ///                    | `dense` `<` attribute-value `>` `:`
 ///                      (tensor-type | vector-type)
 ///                    | `sparse` `<` attribute-value `,` attribute-value `>`
@@ -1509,7 +1509,31 @@ Attribute Parser::parseAttribute(Type type) {
   case Token::at_identifier: {
     std::string nameStr = extractSymbolReference(getToken());
     consumeToken(Token::at_identifier);
-    return builder.getSymbolRefAttr(nameStr);
+
+    // Parse any nested references.
+    std::vector<FlatSymbolRefAttr> nestedRefs;
+    while (getToken().is(Token::colon)) {
+      // Check for the '::' prefix.
+      const char *curPointer = getToken().getLoc().getPointer();
+      consumeToken(Token::colon);
+      if (!consumeIf(Token::colon)) {
+        state.lex.resetPointer(curPointer);
+        consumeToken();
+        break;
+      }
+      // Parse the reference itself.
+      auto curLoc = getToken().getLoc();
+      if (getToken().isNot(Token::at_identifier)) {
+        emitError(curLoc, "expected nested symbol reference identifier");
+        return Attribute();
+      }
+
+      std::string nameStr = extractSymbolReference(getToken());
+      consumeToken(Token::at_identifier);
+      nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
+    }
+
+    return builder.getSymbolRefAttr(nameStr, nestedRefs);
   }
 
   // Parse a 'unit' attribute.
index 69f7e93..7f3ce5a 100644 (file)
@@ -52,7 +52,7 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
     return llvm::ConstantInt::get(llvmType, intAttr.getValue());
   if (auto floatAttr = attr.dyn_cast<FloatAttr>())
     return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
-  if (auto funcAttr = attr.dyn_cast<SymbolRefAttr>())
+  if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
     return functionMapping.lookup(funcAttr.getValue());
   if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
     auto *sequentialType = cast<llvm::SequentialType>(llvmType);
@@ -194,7 +194,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
   auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
     auto operands = lookupValues(op.getOperands());
     ArrayRef<llvm::Value *> operandsRef(operands);
-    if (auto attr = op.getAttrOfType<SymbolRefAttr>("callee")) {
+    if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee")) {
       return builder.CreateCall(functionMapping.lookup(attr.getValue()),
                                 operandsRef);
     } else {
index 37f85e7..dc85fbb 100644 (file)
@@ -1112,3 +1112,7 @@ func @"\"_string_symbol_reference\""() {
   "foo.symbol_reference"() {ref = @"\"_string_symbol_reference\""} : () -> ()
   return
 }
+
+// CHECK-LABEL: func @nested_reference
+// CHECK-NEXT: ref = @some_symbol::@some_nested_symbol
+func @nested_reference() attributes {test.ref = @some_symbol::@some_nested_symbol }
index 4071f7e..2972793 100644 (file)
@@ -206,7 +206,7 @@ def UpdateFloatElementsAttr : Pat<
 
 def SymbolRefOp : TEST_Op<"symbol_ref_attr"> {
   let arguments = (ins
-    Confined<SymbolRefAttr, [ReferToOp<"FuncOp">]>:$symbol
+    Confined<FlatSymbolRefAttr, [ReferToOp<"FuncOp">]>:$symbol
   );
 }
 
@@ -232,7 +232,7 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> {
 
 def ConversionCallOp : TEST_Op<"conversion_call_op",
     [CallOpInterface]> {
-  let arguments = (ins Variadic<AnyType>:$inputs, SymbolRefAttr:$callee);
+  let arguments = (ins Variadic<AnyType>:$inputs, FlatSymbolRefAttr:$callee);
   let results = (outs Variadic<AnyType>);
 
   let extraClassDeclaration = [{
@@ -241,7 +241,7 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
 
     /// Return the callee of this operation.
     CallInterfaceCallable getCallableForCallee() {
-      return getAttrOfType<SymbolRefAttr>("callee");
+      return getAttrOfType<FlatSymbolRefAttr>("callee");
     }
   }];
 }
index 82702bf..7fe249b 100644 (file)
@@ -72,7 +72,7 @@ def AOp : NS_Op<"a_op", []> {
 // CHECK:      auto tblgen_cAttr = this->getAttr("cAttr");
 // CHECK-NEXT: if (tblgen_cAttr) {
 // CHECK-NEXT:   if (!((some-condition))) return emitOpError("attribute 'cAttr' failed to satisfy constraint: some attribute kind");
+
 def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">;
 
 def BOp : NS_Op<"b_op", []> {
@@ -85,7 +85,7 @@ def BOp : NS_Op<"b_op", []> {
     F64Attr:$f64_attr,
     StrAttr:$str_attr,
     ElementsAttr:$elements_attr,
-    SymbolRefAttr:$function_attr,
+    FlatSymbolRefAttr:$function_attr,
     SomeTypeAttr:$type_attr,
     ArrayAttr:$array_attr,
     TypedArrayAttrBase<SomeAttr, "SomeAttr array">:$some_attr_array,
@@ -122,7 +122,7 @@ def BOp : NS_Op<"b_op", []> {
 // CHECK: if (!(((tblgen_f64_attr.isa<FloatAttr>())) && ((tblgen_f64_attr.cast<FloatAttr>().getType().isF64()))))
 // CHECK: if (!((tblgen_str_attr.isa<StringAttr>())))
 // CHECK: if (!((tblgen_elements_attr.isa<ElementsAttr>())))
-// CHECK: if (!((tblgen_function_attr.isa<SymbolRefAttr>())))
+// CHECK: if (!((tblgen_function_attr.isa<FlatSymbolRefAttr>())))
 // CHECK: if (!(((tblgen_type_attr.isa<TypeAttr>())) && ((tblgen_type_attr.cast<TypeAttr>().getValue().isa<SomeType>()))))
 // CHECK: if (!((tblgen_array_attr.isa<ArrayAttr>())))
 // CHECK: if (!(((tblgen_some_attr_array.isa<ArrayAttr>())) && (llvm::all_of(tblgen_some_attr_array.cast<ArrayAttr>(), [](Attribute attr) { return (some-condition); }))))