From 9b9c647cefea0a81fdf7d2bf6586a13f99d9a2cf Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 11 Nov 2019 18:18:02 -0800 Subject: [PATCH] Add support for nested symbol references. This change allows for adding additional nested references to a SymbolRefAttr to allow for further resolving a symbol if that symbol also defines a SymbolTable. If a referenced symbol also defines a symbol table, a nested reference can be used to refer to a symbol within that table. Nested references are printed after the main reference in the following form: symbol-ref-attribute ::= symbol-ref-id (`::` symbol-ref-id)* Example: module @reference { func @nested_reference() } my_reference_op @reference::@nested_reference Given that SymbolRefAttr is now more general, the existing functionality centered around a single reference is moved to a derived class FlatSymbolRefAttr. Followup commits will add support to lookups, rauw, etc. for scoped references. PiperOrigin-RevId: 279860501 --- mlir/examples/toy/Ch2/include/toy/Ops.td | 2 +- mlir/examples/toy/Ch3/include/toy/Ops.td | 2 +- mlir/examples/toy/Ch4/include/toy/Ops.td | 2 +- mlir/examples/toy/Ch5/include/toy/Ops.td | 2 +- mlir/examples/toy/Ch6/include/toy/Ops.td | 2 +- mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp | 6 +- mlir/examples/toy/Ch7/include/toy/Ops.td | 2 +- mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp | 6 +- mlir/g3doc/LangRef.md | 6 +- mlir/g3doc/Tutorials/Toy/Ch-6.md | 6 +- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 4 +- .../mlir/Dialect/Linalg/IR/LinalgLibraryOps.td | 6 +- .../mlir/Dialect/SPIRV/SPIRVControlFlowOps.td | 2 +- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 2 +- .../mlir/Dialect/SPIRV/SPIRVStructureOps.td | 8 +- mlir/include/mlir/Dialect/StandardOps/Ops.td | 9 +- mlir/include/mlir/IR/Attributes.h | 102 ++++++++++++++++++++- mlir/include/mlir/IR/Builders.h | 6 +- mlir/include/mlir/IR/Function.h | 2 +- mlir/include/mlir/IR/OpBase.td | 16 +++- mlir/lib/Analysis/CallGraph.cpp | 4 +- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 3 +- .../lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 3 +- .../Linalg/Transforms/LowerToLLVMDialect.cpp | 6 +- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 17 ++-- .../Dialect/SPIRV/Serialization/Deserializer.cpp | 2 +- .../lib/Dialect/SPIRV/Serialization/Serializer.cpp | 2 +- mlir/lib/Dialect/StandardOps/Ops.cpp | 10 +- mlir/lib/IR/AsmPrinter.cpp | 10 +- mlir/lib/IR/AttributeDetail.h | 37 ++++++++ mlir/lib/IR/Attributes.cpp | 23 ++++- mlir/lib/IR/Builders.cpp | 9 +- mlir/lib/IR/FunctionSupport.cpp | 2 +- mlir/lib/Parser/Parser.cpp | 28 +++++- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 4 +- mlir/test/IR/parser.mlir | 4 + mlir/test/lib/TestDialect/TestOps.td | 6 +- mlir/test/mlir-tblgen/op-attribute.td | 6 +- 38 files changed, 286 insertions(+), 83 deletions(-) diff --git a/mlir/examples/toy/Ch2/include/toy/Ops.td b/mlir/examples/toy/Ch2/include/toy/Ops.td index 799813b..fb41818 100644 --- a/mlir/examples/toy/Ch2/include/toy/Ops.td +++ b/mlir/examples/toy/Ch2/include/toy/Ops.td @@ -123,7 +123,7 @@ def GenericCallOp : Toy_Op<"generic_call"> { // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins SymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td index 0be2b66..25bf06c 100644 --- a/mlir/examples/toy/Ch3/include/toy/Ops.td +++ b/mlir/examples/toy/Ch3/include/toy/Ops.td @@ -123,7 +123,7 @@ def GenericCallOp : Toy_Op<"generic_call"> { // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins SymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td index 27fdc34..bbbcc48 100644 --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -148,7 +148,7 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins SymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td index f609a14..d9306f4 100644 --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -148,7 +148,7 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins SymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td index f609a14..d9306f4 100644 --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -148,7 +148,7 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins SymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index 7e300fb..091eada 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -107,9 +107,9 @@ public: private: /// Return a symbol reference to the printf function, inserting it into the /// module if necessary. - static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { + static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { auto *context = module.getContext(); if (module.lookupSymbol("printf")) return SymbolRefAttr::get("printf", context); diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td index 5e932bb..d41406a 100644 --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -160,7 +160,7 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins SymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); // The generic call operation returns a single value of TensorType or // StructType. diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 7e300fb..091eada 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -107,9 +107,9 @@ public: private: /// Return a symbol reference to the printf function, inserting it into the /// module if necessary. - static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { + static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { auto *context = module.getContext(); if (module.lookupSymbol("printf")) return SymbolRefAttr::get("printf", context); diff --git a/mlir/g3doc/LangRef.md b/mlir/g3doc/LangRef.md index 391f773..3409b9f 100644 --- a/mlir/g3doc/LangRef.md +++ b/mlir/g3doc/LangRef.md @@ -1367,13 +1367,15 @@ A string attribute is an attribute that represents a string literal value. Syntax: ``` {.ebnf} -symbol-ref-attribute ::= symbol-ref-id +symbol-ref-attribute ::= symbol-ref-id (`::` symbol-ref-id)* ``` A symbol reference attribute is a literal attribute that represents a named reference to an operation that is nested within an operation with the `OpTrait::SymbolTable` trait. As such, this reference is given meaning by the -nearest parent operation containing the `OpTrait::SymbolTable` trait. +nearest parent operation containing the `OpTrait::SymbolTable` trait. It may +optionally contain a set of nested references that further resolve to a symbol +nested within a different symbol table. This attribute can only be held internally by [array attributes](#array-attribute) and diff --git a/mlir/g3doc/Tutorials/Toy/Ch-6.md b/mlir/g3doc/Tutorials/Toy/Ch-6.md index b01dfde..49114b4 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-6.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-6.md @@ -26,9 +26,9 @@ During lowering we can get, or build, the declaration for printf as so: ```c++ /// Return a symbol reference to the printf function, inserting it into the /// module if necessary. -static SymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { +static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter, + ModuleOp module, + LLVM::LLVMDialect *llvmDialect) { auto *context = module.getContext(); if (module.lookupSymbol("printf")) return SymbolRefAttr::get("printf", context); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 2950926..ba337e2 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -341,7 +341,7 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">; // Call-related operations. def LLVM_CallOp : LLVM_Op<"call">, - Arguments<(ins OptionalAttr:$callee, + Arguments<(ins OptionalAttr:$callee, Variadic)>, Results<(outs Variadic)>, LLVM_TwoBuilders { // to work correctly). def LLVM_AddressOfOp : LLVM_OneResultOp<"mlir.addressof">, - Arguments<(ins SymbolRefAttr:$global_name)> { + Arguments<(ins FlatSymbolRefAttr:$global_name)> { let builders = [ OpBuilder<"Builder *builder, OperationState &result, LLVMType resType, " "StringRef name, ArrayRef attrs = {}", [{ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td index 467cd08..dd16498 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td @@ -376,7 +376,7 @@ class GenericOpBase : LinalgLibraryBase_Op { I64ArrayAttr:$n_loop_types, I64ArrayAttr:$n_views, OptionalAttr:$doc, - OptionalAttr:$fun, + OptionalAttr:$fun, OptionalAttr:$library_call); let regions = (region AnyRegion:$region); let extraClassDeclaration = [{ @@ -464,7 +464,7 @@ def GenericOp : GenericOpBase<"generic"> { Where #trait_attributes is an alias of a dictionary attribute containing: - doc [optional]: a documentation string - - fun: a SymbolRefAttr that must resolve to an existing function symbol. + - fun: a FlatSymbolRefAttr that must resolve to an existing function symbol. To support inplace updates in a generic fashion, the signature of the function must be: ``` @@ -558,7 +558,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> { Where #trait_attributes is an alias of a dictionary attribute containing: - doc [optional]: a documentation string - - fun: a SymbolRefAttr that must resolve to an existing function symbol. + - fun: a FlatSymbolRefAttr that must resolve to an existing function symbol. To support inplace updates in a generic fashion, the signature of the function must be: ``` diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index 070d6a6..8de2aeb 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -250,7 +250,7 @@ def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [ }]; let arguments = (ins - SymbolRefAttr:$callee, + FlatSymbolRefAttr:$callee, Variadic:$arguments ); diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 0e135c5..9be4898 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -273,7 +273,7 @@ def SPV_ExecutionModeOp : SPV_Op<"ExecutionMode", [InModuleScope]> { }]; let arguments = (ins - SymbolRefAttr:$fn, + FlatSymbolRefAttr:$fn, SPV_ExecutionModeAttr:$execution_mode, I32ArrayAttr:$values ); diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td index aadd179..fab97bd 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -55,7 +55,7 @@ def SPV_AddressOfOp : SPV_Op<"_address_of", [InFunctionScope, NoSideEffect]> { }]; let arguments = (ins - SymbolRefAttr:$variable + FlatSymbolRefAttr:$variable ); let results = (outs @@ -174,7 +174,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [InModuleScope]> { let arguments = (ins SPV_ExecutionModelAttr:$execution_model, - SymbolRefAttr:$fn, + FlatSymbolRefAttr:$fn, SymbolRefArrayAttr:$interface ); @@ -237,7 +237,7 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope, Symbol]> { let arguments = (ins TypeAttr:$type, StrAttr:$sym_name, - OptionalAttr:$initializer + OptionalAttr:$initializer ); let builders = [ @@ -394,7 +394,7 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> { }]; let arguments = (ins - SymbolRefAttr:$spec_const + FlatSymbolRefAttr:$spec_const ); let results = (outs diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index d7de155..fa72306 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -239,14 +239,15 @@ def BranchOp : Std_Op<"br", [Terminator]> { def CallOp : Std_Op<"call", [CallOpInterface]> { let summary = "call operation"; let description = [{ - The "call" operation represents a direct call to a function. The operands - and result types of the call must match the specified function type. The - callee is encoded as a function attribute named "callee". + The "call" operation represents a direct call to a function that is within + the same symbol scope as the call. The operands and result types of the + call must match the specified function type. The callee is encoded as a + function attribute named "callee". %2 = call @my_add(%0, %1) : (f32, f32) -> f32 }]; - let arguments = (ins SymbolRefAttr:$callee, Variadic:$operands); + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); let results = (outs Variadic); let builders = [OpBuilder< diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 5d98f6e..8a5e3b5 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -44,6 +44,7 @@ struct IntegerSetAttributeStorage; struct FloatAttributeStorage; struct OpaqueAttributeStorage; struct StringAttributeStorage; +struct SymbolRefAttributeStorage; struct TypeAttributeStorage; /// Elements Attributes. @@ -179,6 +180,10 @@ enum Kind { }; } // namespace StandardAttributes +//===----------------------------------------------------------------------===// +// AffineMapAttr +//===----------------------------------------------------------------------===// + class AffineMapAttr : public Attribute::AttrBase { @@ -196,6 +201,10 @@ public: } }; +//===----------------------------------------------------------------------===// +// ArrayAttr +//===----------------------------------------------------------------------===// + /// Array attributes are lists of other attributes. They are not necessarily /// type homogenous given that attributes don't, in general, carry types. class ArrayAttr : public Attribute::AttrBase { public: @@ -234,6 +247,10 @@ public: static bool kindof(unsigned kind) { return kind == StandardAttributes::Bool; } }; +//===----------------------------------------------------------------------===// +// DictionaryAttr +//===----------------------------------------------------------------------===// + /// NamedAttribute is used for dictionary attributes, it holds an identifier for /// the name and a value for the attribute. The attribute pointer should always /// be non-null. @@ -271,6 +288,10 @@ public: } }; +//===----------------------------------------------------------------------===// +// FloatAttr +//===----------------------------------------------------------------------===// + class FloatAttr : public Attribute::AttrBase { public: @@ -308,6 +329,10 @@ public: Type type, const APFloat &value); }; +//===----------------------------------------------------------------------===// +// IntegerAttr +//===----------------------------------------------------------------------===// + class IntegerAttr : public Attribute::AttrBase { @@ -328,6 +353,10 @@ public: } }; +//===----------------------------------------------------------------------===// +// IntegerSetAttr +//===----------------------------------------------------------------------===// + class IntegerSetAttr : public Attribute::AttrBase { @@ -345,6 +374,10 @@ public: } }; +//===----------------------------------------------------------------------===// +// OpaqueAttr +//===----------------------------------------------------------------------===// + /// Opaque attributes represent attributes of non-registered dialects. These are /// attribute represented in their raw string form, and can only usefully be /// tested for attribute equality. @@ -380,6 +413,10 @@ public: } }; +//===----------------------------------------------------------------------===// +// StringAttr +//===----------------------------------------------------------------------===// + class StringAttr : public Attribute::AttrBase { public: @@ -400,19 +437,40 @@ public: } }; +//===----------------------------------------------------------------------===// +// SymbolRefAttr +//===----------------------------------------------------------------------===// + +class FlatSymbolRefAttr; + /// A symbol reference attribute represents a symbolic reference to another /// operation. class SymbolRefAttr : public Attribute::AttrBase { + detail::SymbolRefAttributeStorage> { public: using Base::Base; - using ValueType = StringRef; - static SymbolRefAttr get(StringRef value, MLIRContext *ctx); + /// Construct a symbol reference for the given value name. + static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx); - /// Returns the name of the held symbol reference. - StringRef getValue() const; + /// Construct a symbol reference for the given value name, and a set of nested + /// references that are further resolve to a nested symbol. + static SymbolRefAttr get(StringRef value, + ArrayRef references, + MLIRContext *ctx); + + /// Returns the name of the top level symbol reference, i.e. the root of the + /// reference path. + StringRef getRootReference() const; + + /// Returns the name of the fully resolved symbol, i.e. the leaf of the + /// reference path. + StringRef getLeafReference() const; + + /// Returns the set of nested references representing the path to the symbol + /// nested under the root reference. + ArrayRef getNestedReferences() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool kindof(unsigned kind) { @@ -420,6 +478,36 @@ public: } }; +/// A symbol reference with a reference path containing a single element. This +/// is used to refer to an operation within the current symbol table. +class FlatSymbolRefAttr : public SymbolRefAttr { +public: + using SymbolRefAttr::SymbolRefAttr; + using ValueType = StringRef; + + /// Construct a symbol reference for the given value name. + static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) { + return SymbolRefAttr::get(value, ctx); + } + + /// Returns the name of the held symbol reference. + StringRef getValue() const { return getRootReference(); } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(Attribute attr) { + SymbolRefAttr refAttr = attr.dyn_cast(); + return refAttr && refAttr.getNestedReferences().empty(); + } + +private: + using SymbolRefAttr::get; + using SymbolRefAttr::getNestedReferences; +}; + +//===----------------------------------------------------------------------===// +// Type +//===----------------------------------------------------------------------===// + class TypeAttr : public Attribute::AttrBase { public: @@ -434,6 +522,10 @@ public: static bool kindof(unsigned kind) { return kind == StandardAttributes::Type; } }; +//===----------------------------------------------------------------------===// +// UnitAttr +//===----------------------------------------------------------------------===// + /// Unit attributes are attributes that hold no specific value and are given /// meaning by their existence. class UnitAttr : public Attribute::AttrBase { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 0005a39..01ad38c 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -100,8 +100,10 @@ public: FloatAttr getFloatAttr(Type type, const APFloat &value); StringAttr getStringAttr(StringRef bytes); ArrayAttr getArrayAttr(ArrayRef value); - SymbolRefAttr getSymbolRefAttr(Operation *value); - SymbolRefAttr getSymbolRefAttr(StringRef value); + FlatSymbolRefAttr getSymbolRefAttr(Operation *value); + FlatSymbolRefAttr getSymbolRefAttr(StringRef value); + SymbolRefAttr getSymbolRefAttr(StringRef value, + ArrayRef nestedReferences); // Returns a 0-valued attribute of the given `type`. This function only // supports boolean, integer, and 16-/32-/64-bit float types, and vector or diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 0f435b2..228b030 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -119,7 +119,7 @@ public: /// to. This may return null in the case of an external callable object, e.g. /// an external function. Region *getCallableRegion(CallInterfaceCallable callable) { - assert(callable.get().getValue() == getName()); + assert(callable.get().getLeafReference() == getName()); return isExternal() ? nullptr : &getBody(); } diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 62d9542..2f67560 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1139,8 +1139,16 @@ class StructAttr()">, "symbol reference attribute"> { let storageType = [{ SymbolRefAttr }]; + let returnType = [{ SymbolRefAttr }]; + let constBuilderCall = "$_builder.getSymbolRefAttr($0)"; + let convertFromStorage = "$_self"; +} +def FlatSymbolRefAttr : Attr()">, + "flat symbol reference attribute"> { + let storageType = [{ FlatSymbolRefAttr }]; let returnType = [{ StringRef }]; let constBuilderCall = "$_builder.getSymbolRefAttr($0)"; + let convertFromStorage = "$_self.getValue()"; } def SymbolRefArrayAttr : @@ -1241,12 +1249,14 @@ class IntArrayNthElemMinValue : AttrConstraint< def IsNullAttr : AttrConstraint< CPred<"!$_self">, "empty attribute (for optional attributes)">; -// An attribute constraint on SymbolRefAttr that requires the SymbolRefAttr -// pointing to an op of `opClass` within the closest parent with a symbol table. +// An attribute constraint on FlatSymbolRefAttr that requires that the +// reference point to an op of `opClass` within the closest parent with a symbol +// table. +// TODO(riverriddle) Add support for nested symbol references. class ReferToOp : AttrConstraint< CPred<"isa_and_nonnull<" # opClass # ">(" "::mlir::SymbolTable::lookupNearestSymbolFrom(" - "&$_op, $_self.cast().getValue()))">, + "&$_op, $_self.cast().getValue()))">, "referencing to a '" # opClass # "' symbol">; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp index 3c02802..2b5894f 100644 --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -184,7 +184,9 @@ CallGraphNode *CallGraph::resolveCallable(CallInterfaceCallable callable, // Get the callee operation from the callable. Operation *callee; if (auto symbolRef = callable.dyn_cast()) - callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef.getValue()); + // TODO(riverriddle) Support nested references. + callee = SymbolTable::lookupNearestSymbolFrom(from, + symbolRef.getRootReference()); else callee = callable.get()->getDefiningOp(); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index d70d51f..bfd094d 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -533,7 +533,8 @@ unsigned LaunchFuncOp::getNumKernelOperands() { } StringRef LaunchFuncOp::getKernelModuleName() { - return getAttrOfType(getKernelModuleAttrName()).getValue(); + return getAttrOfType(getKernelModuleAttrName()) + .getRootReference(); } Value *LaunchFuncOp::getKernelOperand(unsigned i) { diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index 672beee..420b234 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -189,7 +189,8 @@ private: if (Optional symbolUses = SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { for (SymbolTable::SymbolUse symbolUse : *symbolUses) { - StringRef symbolName = symbolUse.getSymbolRef().getValue(); + StringRef symbolName = + symbolUse.getSymbolRef().cast().getValue(); if (moduleManager.lookupSymbol(symbolName)) continue; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 7a8bc71..2dc46bf 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -359,8 +359,8 @@ public: // Get a SymbolRefAttr containing the library function name for the LinalgOp. // If the library function does not exist, insert a declaration. template -static SymbolRefAttr getLibraryCallSymbolRef(Operation *op, - PatternRewriter &rewriter) { +static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op, + PatternRewriter &rewriter) { auto linalgOp = cast(op); auto fnName = linalgOp.getLibraryCallName(); if (fnName.empty()) { @@ -369,7 +369,7 @@ static SymbolRefAttr getLibraryCallSymbolRef(Operation *op, } // fnName is a dynamic std::String, unique it via a SymbolRefAttr. - SymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); + FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); auto module = op->getParentOfType(); if (module.lookupSymbol(fnName)) { return fnNameAttr; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index b3a9f6f..3c1563e 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -658,7 +658,7 @@ void spirv::AddressOfOp::build(Builder *builder, OperationState &state, static ParseResult parseAddressOfOp(OpAsmParser &parser, OperationState &state) { - SymbolRefAttr varRefAttr; + FlatSymbolRefAttr varRefAttr; Type type; if (parser.parseAttribute(varRefAttr, Type(), kVariableAttrName, state.attributes) || @@ -1088,7 +1088,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser, SmallVector idTypes; SmallVector interfaceVars; - SymbolRefAttr fn; + FlatSymbolRefAttr fn; if (parseEnumAttribute(execModel, parser, state) || parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) { return failure(); @@ -1099,7 +1099,7 @@ static ParseResult parseEntryPointOp(OpAsmParser &parser, do { // The name of the interface variable attribute isnt important auto attrName = "var_symbol"; - SymbolRefAttr var; + FlatSymbolRefAttr var; SmallVector attrs; if (parser.parseAttribute(var, Type(), attrName, attrs)) { return failure(); @@ -1186,7 +1186,7 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) { static ParseResult parseFunctionCallOp(OpAsmParser &parser, OperationState &state) { - SymbolRefAttr calleeAttr; + FlatSymbolRefAttr calleeAttr; FunctionType type; SmallVector operands; auto loc = parser.getNameLoc(); @@ -1305,7 +1305,7 @@ static ParseResult parseGlobalVariableOp(OpAsmParser &parser, // Parse optional initializer if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) { - SymbolRefAttr initSymbol; + FlatSymbolRefAttr initSymbol; if (parser.parseLParen() || parser.parseAttribute(initSymbol, Type(), kInitializerAttrName, state.attributes) || @@ -1361,7 +1361,8 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) { if (varOp.storageClass() == spirv::StorageClass::Generic) return varOp.emitOpError("storage class cannot be 'Generic'"); - if (auto init = varOp.getAttrOfType(kInitializerAttrName)) { + if (auto init = + varOp.getAttrOfType(kInitializerAttrName)) { auto moduleOp = varOp.getParentOfType(); auto *initOp = moduleOp.lookupSymbol(init.getValue()); // TODO: Currently only variable initialization with specialization @@ -1713,7 +1714,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) { } if (auto interface = entryPointOp.interface()) { for (Attribute varRef : interface) { - auto varSymRef = varRef.dyn_cast(); + auto varSymRef = varRef.dyn_cast(); if (!varSymRef) { return entryPointOp.emitError( "expected symbol reference for interface " @@ -1790,7 +1791,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) { static ParseResult parseReferenceOfOp(OpAsmParser &parser, OperationState &state) { - SymbolRefAttr constRefAttr; + FlatSymbolRefAttr constRefAttr; Type type; if (parser.parseAttribute(constRefAttr, Type(), kSpecConstAttrName, state.attributes) || diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 11660ed..40b5318 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -970,7 +970,7 @@ LogicalResult Deserializer::processGlobalVariable(ArrayRef operands) { wordIndex++; // Initializer. - SymbolRefAttr initializer = nullptr; + FlatSymbolRefAttr initializer = nullptr; if (wordIndex < operands.size()) { auto initializerOp = getGlobalVariable(operands[wordIndex]); if (!initializerOp) { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index f92b9ae..805a339 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -1696,7 +1696,7 @@ Serializer::processOp(spirv::EntryPointOp op) { // Add the interface values. if (auto interface = op.interface()) { for (auto var : interface.getValue()) { - auto id = getVariableID(var.cast().getValue()); + auto id = getVariableID(var.cast().getValue()); if (!id) { return op.emitError("referencing undefined global variable." "spv.EntryPoint is at the end of spv.module. All " diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 1202924..8c08868 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -528,7 +528,7 @@ void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, //===----------------------------------------------------------------------===// static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) { - SymbolRefAttr calleeAttr; + FlatSymbolRefAttr calleeAttr; FunctionType calleeType; SmallVector operands; auto calleeLoc = parser.getNameLoc(); @@ -555,7 +555,7 @@ static void print(OpAsmPrinter &p, CallOp op) { static LogicalResult verify(CallOp op) { // Check that the callee attribute was specified. - auto fnAttr = op.getAttrOfType("callee"); + auto fnAttr = op.getAttrOfType("callee"); if (!fnAttr) return op.emitOpError("requires a 'callee' symbol reference attribute"); auto fn = @@ -608,8 +608,8 @@ struct SimplifyIndirectCallWithKnownCallee // Replace with a direct call. SmallVector callResults(indirectCall.getResultTypes()); SmallVector callOperands(indirectCall.getArgOperands()); - rewriter.replaceOpWithNewOp(indirectCall, calledFn.getValue(), - callResults, callOperands); + rewriter.replaceOpWithNewOp(indirectCall, calledFn, callResults, + callOperands); return matchSuccess(); } }; @@ -1206,7 +1206,7 @@ static LogicalResult verify(ConstantOp &op) { } if (type.isa()) { - auto fnAttr = value.dyn_cast(); + auto fnAttr = value.dyn_cast(); if (!fnAttr) return op.emitOpError("requires 'value' to be a function reference"); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 43452b2..6f77de0 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -801,9 +801,15 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) { case StandardAttributes::Type: printType(attr.cast().getValue()); break; - case StandardAttributes::SymbolRef: - printSymbolReference(attr.cast().getValue(), os); + case StandardAttributes::SymbolRef: { + auto refAttr = attr.dyn_cast(); + printSymbolReference(refAttr.getRootReference(), os); + for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) { + os << "::"; + printSymbolReference(nestedRef.getValue(), os); + } break; + } case StandardAttributes::OpaqueElements: { auto eltsAttr = attr.cast(); os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", "; diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 21f8b68..da4aa69 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -321,6 +321,43 @@ struct StringAttributeStorage : public AttributeStorage { StringRef value; }; +/// An attribute representing a symbol reference. +struct SymbolRefAttributeStorage final + : public AttributeStorage, + public llvm::TrailingObjects { + using KeyTy = std::pair>; + + SymbolRefAttributeStorage(StringRef value, size_t numNestedRefs) + : value(value), numNestedRefs(numNestedRefs) {} + + /// Key equality function. + bool operator==(const KeyTy &key) const { + return key == KeyTy(value, getNestedRefs()); + } + + /// Construct a new storage instance. + static SymbolRefAttributeStorage * + construct(AttributeStorageAllocator &allocator, const KeyTy &key) { + auto size = SymbolRefAttributeStorage::totalSizeToAlloc( + key.second.size()); + auto rawMem = allocator.allocate(size, alignof(SymbolRefAttributeStorage)); + auto result = ::new (rawMem) SymbolRefAttributeStorage( + allocator.copyInto(key.first), key.second.size()); + std::uninitialized_copy(key.second.begin(), key.second.end(), + result->getTrailingObjects()); + return result; + } + + /// Returns the set of nested references. + ArrayRef getNestedRefs() const { + return {getTrailingObjects(), numNestedRefs}; + } + + StringRef value; + size_t numNestedRefs; +}; + /// An attribute representing a reference to a type. struct TypeAttributeStorage : public AttributeStorage { using KeyTy = Type; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index d74cacb..80ac4a5 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -249,12 +249,27 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional loc, // SymbolRefAttr //===----------------------------------------------------------------------===// -SymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { - return Base::get(ctx, StandardAttributes::SymbolRef, value, - NoneType::get(ctx)); +FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { + return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) + .cast(); } -StringRef SymbolRefAttr::getValue() const { return getImpl()->value; } +SymbolRefAttr SymbolRefAttr::get(StringRef value, + ArrayRef nestedReferences, + MLIRContext *ctx) { + return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); +} + +StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } + +StringRef SymbolRefAttr::getLeafReference() const { + ArrayRef nestedRefs = getNestedReferences(); + return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); +} + +ArrayRef SymbolRefAttr::getNestedReferences() const { + return getImpl()->getNestedRefs(); +} //===----------------------------------------------------------------------===// // IntegerAttr diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 24ae207..afdeefd 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -150,15 +150,20 @@ ArrayAttr Builder::getArrayAttr(ArrayRef value) { return ArrayAttr::get(value, context); } -SymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { +FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { auto symName = value->getAttrOfType(SymbolTable::getSymbolAttrName()); assert(symName && "value does not have a valid symbol name"); return getSymbolRefAttr(symName.getValue()); } -SymbolRefAttr Builder::getSymbolRefAttr(StringRef value) { +FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) { return SymbolRefAttr::get(value, getContext()); } +SymbolRefAttr +Builder::getSymbolRefAttr(StringRef value, + ArrayRef nestedReferences) { + return SymbolRefAttr::get(value, nestedReferences, getContext()); +} ArrayAttr Builder::getI32ArrayAttr(ArrayRef values) { auto attrs = functional::map( diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp index d1ba2d3..29cae17 100644 --- a/mlir/lib/IR/FunctionSupport.cpp +++ b/mlir/lib/IR/FunctionSupport.cpp @@ -159,7 +159,7 @@ mlir::impl::parseFunctionLikeOp(OpAsmParser &parser, OperationState &result, auto &builder = parser.getBuilder(); // Parse the name as a symbol reference attribute. - SymbolRefAttr nameAttr; + FlatSymbolRefAttr nameAttr; if (parser.parseAttribute(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), result.attributes)) return failure(); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 35c694b..2843aae 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1400,7 +1400,7 @@ static std::string extractSymbolReference(Token tok) { /// | type /// | `[` (attribute-value (`,` attribute-value)*)? `]` /// | `{` (attribute-entry (`,` attribute-entry)*)? `}` -/// | symbol-ref-id +/// | symbol-ref-id (`::` symbol-ref-id)* /// | `dense` `<` attribute-value `>` `:` /// (tensor-type | vector-type) /// | `sparse` `<` attribute-value `,` attribute-value `>` @@ -1509,7 +1509,31 @@ Attribute Parser::parseAttribute(Type type) { case Token::at_identifier: { std::string nameStr = extractSymbolReference(getToken()); consumeToken(Token::at_identifier); - return builder.getSymbolRefAttr(nameStr); + + // Parse any nested references. + std::vector nestedRefs; + while (getToken().is(Token::colon)) { + // Check for the '::' prefix. + const char *curPointer = getToken().getLoc().getPointer(); + consumeToken(Token::colon); + if (!consumeIf(Token::colon)) { + state.lex.resetPointer(curPointer); + consumeToken(); + break; + } + // Parse the reference itself. + auto curLoc = getToken().getLoc(); + if (getToken().isNot(Token::at_identifier)) { + emitError(curLoc, "expected nested symbol reference identifier"); + return Attribute(); + } + + std::string nameStr = extractSymbolReference(getToken()); + consumeToken(Token::at_identifier); + nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); + } + + return builder.getSymbolRefAttr(nameStr, nestedRefs); } // Parse a 'unit' attribute. diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index 69f7e93..7f3ce5a 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -52,7 +52,7 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, return llvm::ConstantInt::get(llvmType, intAttr.getValue()); if (auto floatAttr = attr.dyn_cast()) return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); - if (auto funcAttr = attr.dyn_cast()) + if (auto funcAttr = attr.dyn_cast()) return functionMapping.lookup(funcAttr.getValue()); if (auto splatAttr = attr.dyn_cast()) { auto *sequentialType = cast(llvmType); @@ -194,7 +194,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst, auto convertCall = [this, &builder](Operation &op) -> llvm::Value * { auto operands = lookupValues(op.getOperands()); ArrayRef operandsRef(operands); - if (auto attr = op.getAttrOfType("callee")) { + if (auto attr = op.getAttrOfType("callee")) { return builder.CreateCall(functionMapping.lookup(attr.getValue()), operandsRef); } else { diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 37f85e7..dc85fbb 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1112,3 +1112,7 @@ func @"\"_string_symbol_reference\""() { "foo.symbol_reference"() {ref = @"\"_string_symbol_reference\""} : () -> () return } + +// CHECK-LABEL: func @nested_reference +// CHECK-NEXT: ref = @some_symbol::@some_nested_symbol +func @nested_reference() attributes {test.ref = @some_symbol::@some_nested_symbol } diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 4071f7e..2972793 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -206,7 +206,7 @@ def UpdateFloatElementsAttr : Pat< def SymbolRefOp : TEST_Op<"symbol_ref_attr"> { let arguments = (ins - Confined]>:$symbol + Confined]>:$symbol ); } @@ -232,7 +232,7 @@ def SizedRegionOp : TEST_Op<"sized_region_op", []> { def ConversionCallOp : TEST_Op<"conversion_call_op", [CallOpInterface]> { - let arguments = (ins Variadic:$inputs, SymbolRefAttr:$callee); + let arguments = (ins Variadic:$inputs, FlatSymbolRefAttr:$callee); let results = (outs Variadic); let extraClassDeclaration = [{ @@ -241,7 +241,7 @@ def ConversionCallOp : TEST_Op<"conversion_call_op", /// Return the callee of this operation. CallInterfaceCallable getCallableForCallee() { - return getAttrOfType("callee"); + return getAttrOfType("callee"); } }]; } diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index 82702bf..7fe249b 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -72,7 +72,7 @@ def AOp : NS_Op<"a_op", []> { // CHECK: auto tblgen_cAttr = this->getAttr("cAttr"); // CHECK-NEXT: if (tblgen_cAttr) { // CHECK-NEXT: if (!((some-condition))) return emitOpError("attribute 'cAttr' failed to satisfy constraint: some attribute kind"); - + def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">; def BOp : NS_Op<"b_op", []> { @@ -85,7 +85,7 @@ def BOp : NS_Op<"b_op", []> { F64Attr:$f64_attr, StrAttr:$str_attr, ElementsAttr:$elements_attr, - SymbolRefAttr:$function_attr, + FlatSymbolRefAttr:$function_attr, SomeTypeAttr:$type_attr, ArrayAttr:$array_attr, TypedArrayAttrBase:$some_attr_array, @@ -122,7 +122,7 @@ def BOp : NS_Op<"b_op", []> { // CHECK: if (!(((tblgen_f64_attr.isa())) && ((tblgen_f64_attr.cast().getType().isF64())))) // CHECK: if (!((tblgen_str_attr.isa()))) // CHECK: if (!((tblgen_elements_attr.isa()))) -// CHECK: if (!((tblgen_function_attr.isa()))) +// CHECK: if (!((tblgen_function_attr.isa()))) // CHECK: if (!(((tblgen_type_attr.isa())) && ((tblgen_type_attr.cast().getValue().isa())))) // CHECK: if (!((tblgen_array_attr.isa()))) // CHECK: if (!(((tblgen_some_attr_array.isa())) && (llvm::all_of(tblgen_some_attr_array.cast(), [](Attribute attr) { return (some-condition); })))) -- 2.7.4