Make LLVM Linkage a first class attribute instead of using an integer attribute
authorMehdi Amini <joker.eph@gmail.com>
Fri, 3 Sep 2021 21:18:39 +0000 (21:18 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 3 Sep 2021 21:21:46 +0000 (21:21 +0000)
This makes the IR more readable, in particular when this will be used on
the builtin func outside of the LLVM dialect.

Reviewed By: wsmoses

Differential Revision: https://reviews.llvm.org/D109209

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/func.mlir
mlir/test/Dialect/LLVMIR/global.mlir
mlir/test/Target/LLVMIR/llvmir.mlir

index 46b01e3..1338b82 100644 (file)
@@ -20,12 +20,19 @@ def FastmathFlagsAttr : LLVM_Attr<"FMF"> {
   let mnemonic = "fastmath";
 
   // List of type parameters.
-  let parameters = (
-    ins
+  let parameters = (ins
     "FastmathFlags":$flags
   );
 }
 
+// Attribute definition for the LLVM Linkage enum.
+def LinkageAttr : LLVM_Attr<"Linkage"> {
+  let mnemonic = "linkage";
+  let parameters = (ins
+    "linkage::Linkage":$linkage
+  );
+}
+
 def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
   let mnemonic = "loopopts";
 
@@ -39,8 +46,7 @@ def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
   }];
 
   // List of type parameters.
-  let parameters = (
-    ins
+  let parameters = (ins
     ArrayRefParameter<"std::pair<LoopOptionCase, int64_t>", "">:$options
   );
 
index 7fce572..e68cea8 100644 (file)
 #include "llvm/IR/Type.h"
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.h.inc"
+
+namespace mlir {
+namespace LLVM {
+// Inline the LLVM generated Linkage enum and utility.
+// This is only necessary to isolate the "enum generated code" from the
+// attribute definition itself.
+// TODO: this shouldn't be needed after we unify the attribute generation, i.e.
+// --gen-attr-* and --gen-attrdef-*.
+using linkage::Linkage;
+} // namespace LLVM
+} // namespace mlir
+
 #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.h.inc"
 
 namespace llvm {
index 13c463a..8176c3f 100644 (file)
@@ -28,7 +28,7 @@ def FMFafn      : BitEnumAttrCase<"afn", 0x20>;
 def FMFreassoc  : BitEnumAttrCase<"reassoc", 0x40>;
 def FMFfast     : BitEnumAttrCase<"fast", 0x80>;
 
-def FastmathFlags : BitEnumAttr<
+def FastmathFlags_DoNotUse : BitEnumAttr<
     "FastmathFlags",
     "LLVM fastmath flags",
     [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast
@@ -803,16 +803,28 @@ def LinkageWeakODR
 def LinkageExternal
     : LLVM_EnumAttrCase<"External", "external", "ExternalLinkage", 10>;
 
-def Linkage : LLVM_EnumAttr<
+def LinkageEnum : LLVM_EnumAttr<
     "Linkage",
     "::llvm::GlobalValue::LinkageTypes",
     "LLVM linkage types",
     [LinkagePrivate, LinkageInternal, LinkageAvailableExternally,
      LinkageLinkonce, LinkageWeak, LinkageCommon, LinkageAppending,
      LinkageExternWeak, LinkageLinkonceODR, LinkageWeakODR, LinkageExternal]> {
-  let cppNamespace = "::mlir::LLVM";
+  let cppNamespace = "::mlir::LLVM::linkage";
+}
+
+def Linkage : DialectAttr<
+    LLVM_Dialect,
+    CPred<"$_self.isa<::mlir::LLVM::LinkageAttr>()">,
+    "LLVM Linkage specification"> {
+  let storageType = "::mlir::LLVM::LinkageAttr";
+  let returnType = "::mlir::LLVM::Linkage";
+  let convertFromStorage = "$_self.getLinkage()";
+  let constBuilderCall =
+          "::mlir::LLVM::LinkageAttr::get($_builder.getContext(), $0)";
 }
 
+
 def UnnamedAddrNone : LLVM_EnumAttrCase<"None", "", "None", 0>;
 def UnnamedAddrLocal : LLVM_EnumAttrCase<"Local", "local_unnamed_addr", "Local", 1>;
 def UnnamedAddrGlobal : LLVM_EnumAttrCase<"Global", "unnamed_addr", "Global", 2>;
index 3d0cca7..28cebbd 100644 (file)
@@ -36,6 +36,7 @@
 
 using namespace mlir;
 using namespace mlir::LLVM;
+using mlir::LLVM::linkage::getMaxEnumValForLinkage;
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
 
@@ -1400,7 +1401,7 @@ void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
     result.addAttribute("alignment", builder.getI64IntegerAttr(alignment));
 
   result.addAttribute(getLinkageAttrName(),
-                      builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
+                      LinkageAttr::get(builder.getContext(), linkage));
   if (addrSpace != 0)
     result.addAttribute("addr_space", builder.getI32IntegerAttr(addrSpace));
   result.attributes.append(attrs.begin(), attrs.end());
@@ -1463,19 +1464,21 @@ REGISTER_ENUM_TYPE(Linkage);
 REGISTER_ENUM_TYPE(UnnamedAddr);
 } // end namespace
 
-template <typename EnumTy>
-static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser,
-                                            OperationState &result,
-                                            StringRef name) {
+/// Parse an enum from the keyword, or default to the provided default value.
+/// The return type is the enum type by default, unless overriden with the
+/// second template argument.
+template <typename EnumTy, typename RetTy = EnumTy>
+static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser,
+                                      OperationState &result,
+                                      EnumTy defaultValue) {
   SmallVector<StringRef, 10> names;
-  for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i)
+  for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
     names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
 
   int index = parseOptionalKeywordAlternative(parser, names);
   if (index == -1)
-    return failure();
-  result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index));
-  return success();
+    return static_cast<RetTy>(defaultValue);
+  return static_cast<RetTy>(index);
 }
 
 // operation ::= `llvm.mlir.global` linkage? `constant`? `@` identifier
@@ -1485,17 +1488,17 @@ static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser,
 // The type can be omitted for string attributes, in which case it will be
 // inferred from the value of the string as [strlen(value) x i8].
 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
-  if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
-                                               getLinkageAttrName())))
-    result.addAttribute(getLinkageAttrName(),
-                        parser.getBuilder().getI64IntegerAttr(
-                            static_cast<int64_t>(LLVM::Linkage::External)));
-
-  if (failed(parseOptionalLLVMKeyword<UnnamedAddr>(parser, result,
-                                                   getUnnamedAddrAttrName())))
-    result.addAttribute(getUnnamedAddrAttrName(),
-                        parser.getBuilder().getI64IntegerAttr(
-                            static_cast<int64_t>(LLVM::UnnamedAddr::None)));
+  MLIRContext *ctx = parser.getBuilder().getContext();
+  // Parse optional linkage, default to External.
+  result.addAttribute(getLinkageAttrName(),
+                      LLVM::LinkageAttr::get(
+                          ctx, parseOptionalLLVMKeyword<Linkage>(
+                                   parser, result, LLVM::Linkage::External)));
+  // Parse optional UnnamedAddr, default to None.
+  result.addAttribute(getUnnamedAddrAttrName(),
+                      parser.getBuilder().getI64IntegerAttr(
+                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
+                              parser, result, LLVM::UnnamedAddr::None)));
 
   if (succeeded(parser.parseOptionalKeyword("constant")))
     result.addAttribute("constant", parser.getBuilder().getUnitAttr());
@@ -1692,7 +1695,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
                       builder.getStringAttr(name));
   result.addAttribute("type", TypeAttr::get(type));
   result.addAttribute(getLinkageAttrName(),
-                      builder.getI64IntegerAttr(static_cast<int64_t>(linkage)));
+                      LinkageAttr::get(builder.getContext(), linkage));
   result.attributes.append(attrs.begin(), attrs.end());
   if (dsoLocal)
     result.addAttribute("dso_local", builder.getUnitAttr());
@@ -1751,11 +1754,11 @@ buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
 static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
                                    OperationState &result) {
   // Default to external linkage if no keyword is provided.
-  if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
-                                               getLinkageAttrName())))
-    result.addAttribute(getLinkageAttrName(),
-                        parser.getBuilder().getI64IntegerAttr(
-                            static_cast<int64_t>(LLVM::Linkage::External)));
+  result.addAttribute(
+      getLinkageAttrName(),
+      LinkageAttr::get(parser.getBuilder().getContext(),
+                       parseOptionalLLVMKeyword<Linkage>(
+                           parser, result, LLVM::Linkage::External)));
 
   StringAttr nameAttr;
   SmallVector<OpAsmParser::OperandType, 8> entryArgs;
@@ -2175,7 +2178,7 @@ static LogicalResult verify(FenceOp &op) {
 //===----------------------------------------------------------------------===//
 
 void LLVMDialect::initialize() {
-  addAttributes<FMFAttr, LoopOptionsAttr>();
+  addAttributes<FMFAttr, LinkageAttr, LoopOptionsAttr>();
 
   // clang-format off
   addTypes<LLVMVoidType,
@@ -2397,6 +2400,30 @@ Attribute FMFAttr::parse(MLIRContext *context, DialectAsmParser &parser,
   return FMFAttr::get(parser.getBuilder().getContext(), flags);
 }
 
+void LinkageAttr::print(DialectAsmPrinter &printer) const {
+  printer << "linkage<";
+  if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
+    printer << stringifyEnum(getLinkage());
+  else
+    printer << static_cast<uint64_t>(getLinkage());
+  printer << ">";
+}
+
+Attribute LinkageAttr::parse(MLIRContext *context, DialectAsmParser &parser,
+                             Type type) {
+  StringRef elemName;
+  if (parser.parseLess() || parser.parseKeyword(&elemName) ||
+      parser.parseGreater())
+    return {};
+  auto elem = linkage::symbolizeLinkage(elemName);
+  if (!elem) {
+    parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName;
+    return {};
+  }
+  Linkage linkage = *elem;
+  return LinkageAttr::get(context, linkage);
+}
+
 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
     : options(attr.getOptions().begin(), attr.getOptions().end()) {}
 
index 48a6609..7ebc4c5 100644 (file)
@@ -121,7 +121,7 @@ module {
   // Check that it is present in the generic format using its numeric value.
   //
   // CHECK: llvm.func @external_func
-  // GENERIC: linkage = 10
+  // GENERIC: linkage = #llvm.linkage<external>
   llvm.func external @external_func()
 }
 
index efce9a4..f9109fa 100644 (file)
@@ -96,12 +96,12 @@ llvm.mlir.global internal constant @constant(37.0) : !llvm.label
 // -----
 
 // expected-error @+1 {{'addr_space' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}}
-"llvm.mlir.global"() ({}) {sym_name = "foo", type = i64, value = 42 : i64, addr_space = -1 : i32, linkage = 0} : () -> ()
+"llvm.mlir.global"() ({}) {sym_name = "foo", type = i64, value = 42 : i64, addr_space = -1 : i32, linkage = #llvm.linkage<private>} : () -> ()
 
 // -----
 
 // expected-error @+1 {{'addr_space' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}}
-"llvm.mlir.global"() ({}) {sym_name = "foo", type = i64, value = 42 : i64, addr_space = 1.0 : f32, linkage = 0} : () -> ()
+"llvm.mlir.global"() ({}) {sym_name = "foo", type = i64, value = 42 : i64, addr_space = 1.0 : f32, linkage = #llvm.linkage<private>} : () -> ()
 
 // -----
 
index 2455175..8905f31 100644 (file)
@@ -1,7 +1,7 @@
 // RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
 
 // CHECK: @global_aligned32 = private global i64 42, align 32
-"llvm.mlir.global"() ({}) {sym_name = "global_aligned32", type = i64, value = 42 : i64, linkage = 0, alignment = 32} : () -> ()
+"llvm.mlir.global"() ({}) {sym_name = "global_aligned32", type = i64, value = 42 : i64, linkage = #llvm.linkage<private>, alignment = 32} : () -> ()
 
 // CHECK: @global_aligned64 = private global i64 42, align 64
 llvm.mlir.global private @global_aligned64(42 : i64) {alignment = 64 : i64} : i64