Add linkage support to LLVMFuncOp
authorAlex Zinenko <zinenko@google.com>
Tue, 3 Dec 2019 08:26:13 +0000 (00:26 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Dec 2019 08:26:44 +0000 (00:26 -0800)
A recent commit introduced the Linkage attribute to the LLVM dialect and used
it in the Global Op. Also use it in LLVMFuncOp. As per LLVM Language Reference,
if the linkage attribute is omitted, the function is assumed to have external
linkage.

PiperOrigin-RevId: 283493299

mlir/g3doc/Dialects/LLVM.md
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/IR/FunctionImplementation.cpp
mlir/test/Dialect/LLVMIR/func.mlir

index ed0cad2..9791352 100644 (file)
@@ -72,6 +72,11 @@ llvm.func @foo(%arg0: !llvm.i64) {
   llvm.return
 }
 
+// A function with `internal` linkage.
+llvm.func internal @internal_func() {
+  llvm.return
+}
+
 ```
 
 ### LLVM IR operations
index 324937a..573542b 100644 (file)
@@ -583,9 +583,12 @@ def LLVM_GlobalOp
   let verifier = "return ::verify(*this);";
 }
 
-def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
-      [NativeOpTrait<"IsIsolatedFromAbove">, NativeOpTrait<"FunctionLike">,
-       Symbol]> {
+def LLVM_LLVMFuncOp
+    : LLVM_ZeroResultOp<"func",
+                        [NativeOpTrait<"IsIsolatedFromAbove">,
+                         NativeOpTrait<"FunctionLike">, Symbol]>,
+      Arguments<(ins DefaultValuedAttr<Linkage,
+                                       "Linkage::External">:$linkage)> {
   let summary = "LLVM dialect function, has wrapped LLVM IR function type";
 
   let regions = (region AnyRegion:$body);
@@ -594,7 +597,8 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
 
   let builders = [
     OpBuilder<"Builder *builder, OperationState &result, StringRef name, "
-              "LLVMType type, ArrayRef<NamedAttribute> attrs = {}, "
+              "LLVMType type, LLVM::Linkage linkage = LLVM::Linkage::External, "
+              "ArrayRef<NamedAttribute> attrs = {}, "
               "ArrayRef<NamedAttributeList> argAttrs = {}">
   ];
 
@@ -627,10 +631,7 @@ def LLVM_LLVMFuncOp : LLVM_ZeroResultOp<"func",
 
   let verifier = [{ return ::verify(*this); }];
   let printer = [{ printLLVMFuncOp(p, *this); }];
-  let parser = [{
-    return impl::parseFunctionLikeOp(parser, result, /*allowVariadic=*/true,
-                                     buildLLVMFunctionType);
-  }];
+  let parser = [{ return parseLLVMFuncOp(parser, result); }];
 }
 
 def LLVM_NullOp
index d681e4c..e06e88b 100644 (file)
@@ -97,7 +97,7 @@ private:
       return llvm::cast<LLVMFuncOp>(*funcOp);
 
     mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
-    return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType, llvm::None);
+    return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
   }
 
   const std::string f32Func;
index d226766..2db02db 100644 (file)
@@ -443,9 +443,11 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
       attributes.push_back(attr);
     }
 
-    // Create an LLVM funcion.
+    // Create an LLVM funcion, use external linkage by default until MLIR
+    // functions have linkage.
     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
-        op->getLoc(), funcOp.getName(), llvmType, attributes);
+        op->getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External,
+        attributes);
     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                 newFuncOp.end());
 
index a8c676f..0091101 100644 (file)
@@ -862,6 +862,10 @@ static ParseResult parseConstantOp(OpAsmParser &parser,
 // Builder, printer and verifier for LLVM::GlobalOp.
 //===----------------------------------------------------------------------===//
 
+/// Returns the name used for the linkge attribute. This *must* correspond to
+/// the name of the attribute in ODS.
+static StringRef getLinkageAttrName() { return "linkage"; }
+
 void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type,
                      bool isConstant, Linkage linkage, StringRef name,
                      Attribute value, ArrayRef<NamedAttribute> attrs) {
@@ -872,52 +876,46 @@ void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type,
     result.addAttribute("constant", builder->getUnitAttr());
   if (value)
     result.addAttribute("value", value);
-  result.addAttribute(
-      "linkage", builder->getI64IntegerAttr(static_cast<int64_t>(linkage)));
+  result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr(
+                                                static_cast<int64_t>(linkage)));
   result.attributes.append(attrs.begin(), attrs.end());
   result.addRegion();
 }
 
-// Prints the keyword for the linkage type using the printer.
-static void printLinkage(OpAsmPrinter &p, LLVM::Linkage linkage) {
+// Returns the textual representation of the given linkage.
+static StringRef linkageToStr(LLVM::Linkage linkage) {
   switch (linkage) {
   case LLVM::Linkage::Private:
-    p << "private";
-    return;
+    return "private";
   case LLVM::Linkage::Internal:
-    p << "internal";
-    return;
+    return "internal";
   case LLVM::Linkage::AvailableExternally:
-    p << "available_externally";
-    return;
+    return "available_externally";
   case LLVM::Linkage::Linkonce:
-    p << "linkonce";
-    return;
+    return "linkonce";
   case LLVM::Linkage::Weak:
-    p << "weak";
-    return;
+    return "weak";
   case LLVM::Linkage::Common:
-    p << "common";
-    return;
+    return "common";
   case LLVM::Linkage::Appending:
-    p << "appending";
-    return;
+    return "appending";
   case LLVM::Linkage::ExternWeak:
-    p << "extern_weak";
-    return;
+    return "extern_weak";
   case LLVM::Linkage::LinkonceODR:
-    p << "linkonce_odr";
-    return;
+    return "linkonce_odr";
   case LLVM::Linkage::WeakODR:
-    p << "weak_odr";
-    return;
+    return "weak_odr";
   case LLVM::Linkage::External:
-    p << "external";
-    return;
+    return "external";
   }
   llvm_unreachable("unknown linkage type");
 }
 
+// Prints the keyword for the linkage type using the printer.
+static void printLinkage(OpAsmPrinter &p, LLVM::Linkage linkage) {
+  p << linkageToStr(linkage);
+}
+
 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
   p << op.getOperationName() << ' ';
   printLinkage(p, op.linkage());
@@ -931,7 +929,7 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
   p << ')';
   p.printOptionalAttrDict(op.getAttrs(),
                           {SymbolTable::getSymbolAttrName(), "type", "constant",
-                           "value", "linkage"});
+                           "value", getLinkageAttrName()});
 
   // Print the trailing type unless it's a string global.
   if (op.getValueOrNull().dyn_cast_or_null<StringAttr>())
@@ -970,7 +968,8 @@ static ParseResult parseOptionalLinkageKeyword(OpAsmParser &parser,
                "weak_odr", "external"});
   if (index == -1)
     return failure();
-  result.addAttribute("linkage", parser.getBuilder().getI64IntegerAttr(index));
+  result.addAttribute(getLinkageAttrName(),
+                      parser.getBuilder().getI64IntegerAttr(index));
   return success();
 }
 
@@ -1118,12 +1117,15 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
 //===----------------------------------------------------------------------===//
 
 void LLVMFuncOp::build(Builder *builder, OperationState &result, StringRef name,
-                       LLVMType type, ArrayRef<NamedAttribute> attrs,
+                       LLVMType type, LLVM::Linkage linkage,
+                       ArrayRef<NamedAttribute> attrs,
                        ArrayRef<NamedAttributeList> argAttrs) {
   result.addRegion();
   result.addAttribute(SymbolTable::getSymbolAttrName(),
                       builder->getStringAttr(name));
   result.addAttribute("type", TypeAttr::get(type));
+  result.addAttribute(getLinkageAttrName(), builder->getI64IntegerAttr(
+                                                static_cast<int64_t>(linkage)));
   result.attributes.append(attrs.begin(), attrs.end());
   if (argAttrs.empty())
     return;
@@ -1137,15 +1139,16 @@ void LLVMFuncOp::build(Builder *builder, OperationState &result, StringRef name,
       result.addAttribute(getArgAttrName(i, argAttrName), argDict);
 }
 
-// Build an LLVM function type from the given lists of input and output types.
+// Builds an LLVM function type from the given lists of input and output types.
 // Returns a null type if any of the types provided are non-LLVM types, or if
 // there is more than one output type.
-static Type buildLLVMFunctionType(Builder &b, ArrayRef<Type> inputs,
-                                  ArrayRef<Type> outputs,
-                                  impl::VariadicFlag variadicFlag,
-                                  std::string &errorMessage) {
+static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
+                                  ArrayRef<Type> inputs, ArrayRef<Type> outputs,
+                                  impl::VariadicFlag variadicFlag) {
+  Builder &b = parser.getBuilder();
   if (outputs.size() > 1) {
-    errorMessage = "expected zero or one function result";
+    parser.emitError(loc, "failed to construct function type: expected zero or "
+                          "one function result");
     return {};
   }
 
@@ -1154,7 +1157,8 @@ static Type buildLLVMFunctionType(Builder &b, ArrayRef<Type> inputs,
   for (auto t : inputs) {
     auto llvmTy = t.dyn_cast<LLVMType>();
     if (!llvmTy) {
-      errorMessage = "expected LLVM type for function arguments";
+      parser.emitError(loc, "failed to construct function type: expected LLVM "
+                            "type for function arguments");
       return {};
     }
     llvmInputs.push_back(llvmTy);
@@ -1170,16 +1174,71 @@ static Type buildLLVMFunctionType(Builder &b, ArrayRef<Type> inputs,
   LLVMType llvmOutput = outputs.empty() ? LLVMType::getVoidTy(dialect)
                                         : outputs.front().dyn_cast<LLVMType>();
   if (!llvmOutput) {
-    errorMessage = "expected LLVM type for function results";
+    parser.emitError(loc, "failed to construct function type: expected LLVM "
+                          "type for function results");
     return {};
   }
   return LLVMType::getFunctionTy(llvmOutput, llvmInputs,
                                  variadicFlag.isVariadic());
 }
 
-// Print the LLVMFuncOp.  Collects argument and result types and passes them
-// to the trait printer.  Drops "void" result since it cannot be parsed back.
+// Parses an LLVM function.
+//
+// operation ::= `llvm.func` linkage? function-signature function-attributes?
+//               function-body
+//
+static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
+                                   OperationState &result) {
+  // Default to external linkage if no keyword is provided.
+  if (failed(parseOptionalLinkageKeyword(parser, result)))
+    result.addAttribute(getLinkageAttrName(),
+                        parser.getBuilder().getI64IntegerAttr(
+                            static_cast<int64_t>(LLVM::Linkage::External)));
+
+  StringAttr nameAttr;
+  SmallVector<OpAsmParser::OperandType, 8> entryArgs;
+  SmallVector<SmallVector<NamedAttribute, 2>, 1> argAttrs;
+  SmallVector<SmallVector<NamedAttribute, 2>, 1> resultAttrs;
+  SmallVector<Type, 8> argTypes;
+  SmallVector<Type, 4> resultTypes;
+  bool isVariadic;
+
+  auto signatureLocation = parser.getCurrentLocation();
+  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+                             result.attributes) ||
+      impl::parseFunctionSignature(parser, /*allowVariadic=*/true, entryArgs,
+                                   argTypes, argAttrs, isVariadic, resultTypes,
+                                   resultAttrs))
+    return failure();
+
+  auto type =
+      buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
+                            impl::VariadicFlag(isVariadic));
+  if (!type)
+    return failure();
+  result.addAttribute(impl::getTypeAttrName(), TypeAttr::get(type));
+
+  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
+    return failure();
+  impl::addArgAndResultAttrs(parser.getBuilder(), result, argAttrs,
+                             resultAttrs);
+
+  auto *body = result.addRegion();
+  return parser.parseOptionalRegion(
+      *body, entryArgs, entryArgs.empty() ? llvm::ArrayRef<Type>() : argTypes);
+}
+
+// Print the LLVMFuncOp. Collects argument and result types and passes them to
+// helper functions. Drops "void" result since it cannot be parsed back. Skips
+// the external linkage since it is the default value.
 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
+  p << op.getOperationName() << ' ';
+  if (op.linkage() != LLVM::Linkage::External) {
+    printLinkage(p, op.linkage());
+    p << ' ';
+  }
+  p.printSymbolName(op.getName());
+
   LLVMType fnType = op.getType();
   SmallVector<Type, 8> argTypes;
   SmallVector<Type, 1> resTypes;
@@ -1191,7 +1250,15 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
   if (!returnType.getUnderlyingType()->isVoidTy())
     resTypes.push_back(returnType);
 
-  impl::printFunctionLikeOp(p, op, argTypes, op.isVarArg(), resTypes);
+  impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes);
+  impl::printFunctionAttributes(p, op, argTypes.size(), resTypes.size(),
+                                {getLinkageAttrName()});
+
+  // Print the body if this is not an external function.
+  Region &body = op.body();
+  if (!body.empty())
+    p.printRegion(body, /*printEntryBlockArgs=*/false,
+                  /*printBlockTerminators=*/true);
 }
 
 // Hook for OpTrait::FunctionLike, called after verifying that the 'type'
@@ -1227,9 +1294,26 @@ unsigned LLVMFuncOp::getNumFuncResults() {
   return 1;
 }
 
+// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
+// - functions don't have 'common' linkage
+// - external functions have 'external' or 'extern_weak' linkage;
+// - vararg is (currently) only supported for external functions;
+// - entry block arguments are of LLVM types and match the function signature.
 static LogicalResult verify(LLVMFuncOp op) {
-  if (op.isExternal())
+  if (op.linkage() == LLVM::Linkage::Common)
+    return op.emitOpError()
+           << "functions cannot have '" << linkageToStr(LLVM::Linkage::Common)
+           << "' linkage";
+
+  if (op.isExternal()) {
+    if (op.linkage() != LLVM::Linkage::External &&
+        op.linkage() != LLVM::Linkage::ExternWeak)
+      return op.emitOpError()
+             << "external functions must have '"
+             << linkageToStr(LLVM::Linkage::External) << "' or '"
+             << linkageToStr(LLVM::Linkage::ExternWeak) << "' linkage";
     return success();
+  }
 
   if (op.isVarArg())
     return op.emitOpError("only external functions can be variadic");
index a1fc21e..66c0d8a 100644 (file)
@@ -71,7 +71,8 @@ parseArgumentList(OpAsmParser &parser, bool allowVariadic,
   };
 
   // Parse the function arguments.
-  if (parser.parseOptionalRParen()) {
+  isVariadic = false;
+  if (failed(parser.parseOptionalRParen())) {
     do {
       unsigned numTypedArguments = argTypes.size();
       if (parseArgument())
index 6955cd0..2db5d35 100644 (file)
@@ -92,6 +92,27 @@ module {
 
   // CHECK: llvm.func @variadic_args(!llvm.i32, !llvm.i32, ...)
   llvm.func @variadic_args(!llvm.i32, !llvm.i32, ...)
+
+  //
+  // Check that functions can have linkage attributes.
+  //
+
+  // CHECK: llvm.func internal
+  llvm.func internal @internal_func() {
+    llvm.return
+  }
+
+  // CHECK: llvm.func weak
+  llvm.func weak @weak_linkage() {
+    llvm.return
+  }
+
+  // Omit the `external` linkage, which is the default, in the custom format.
+  // Check that it is present in the generic format using its numeric value.
+  //
+  // CHECK: llvm.func @external_func
+  // GENERIC: linkage = 10
+  llvm.func external @external_func()
 }
 
 // -----
@@ -188,3 +209,17 @@ module {
   // expected-error@+1 {{variadic arguments must be in the end of the argument list}}
   llvm.func @variadic_inside(%arg0: !llvm.i32, ..., %arg1: !llvm.i32)
 }
+
+// -----
+
+module {
+  // expected-error@+1 {{external functions must have 'external' or 'extern_weak' linkage}}
+  llvm.func internal @internal_external_func()
+}
+
+// -----
+
+module {
+  // expected-error@+1 {{functions cannot have 'common' linkage}}
+  llvm.func common @common_linkage_func()
+}