[mlir] Add function_entry_count to LLVMFuncOp
authorChristian Ulmann <christian.ulmann@nextsilicon.com>
Thu, 5 Jan 2023 12:21:57 +0000 (13:21 +0100)
committerTobias Gysi <tobias.gysi@nextsilicon.com>
Thu, 5 Jan 2023 12:40:56 +0000 (13:40 +0100)
This commit introduces the function_entry_count metadata field to the
LLVMFuncOp and adds both the corresponding import and export
funtionalities.
The import of the function metadata uses the same infrastructure as the
instruction metadata, i.e., it dispatches through a dialect interface.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D141001

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleImport.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Target/LLVMIR/Import/function-attributes.ll
mlir/test/Target/LLVMIR/llvmir.mlir

index 714209c..94e5d9b 100644 (file)
@@ -1330,7 +1330,8 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
     OptionalAttr<StrAttr>:$garbageCollector,
     OptionalAttr<ArrayAttr>:$passthrough,
     OptionalAttr<DictArrayAttr>:$arg_attrs,
-    OptionalAttr<DictArrayAttr>:$res_attrs
+    OptionalAttr<DictArrayAttr>:$res_attrs,
+    OptionalAttr<I64Attr>:$function_entry_count
   );
 
   let regions = (region AnyRegion:$body);
@@ -1343,7 +1344,8 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
       CArg<"bool", "false">:$dsoLocal,
       CArg<"CConv", "CConv::C">:$cconv,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
-      CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)>
+      CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs,
+      CArg<"Optional<uint64_t>", "{}">:$functionEntryCount)>
   ];
 
   let extraClassDeclaration = [{
index 211080e..f77251b 100644 (file)
@@ -1992,7 +1992,8 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
                        StringRef name, Type type, LLVM::Linkage linkage,
                        bool dsoLocal, CConv cconv,
                        ArrayRef<NamedAttribute> attrs,
-                       ArrayRef<DictionaryAttr> argAttrs) {
+                       ArrayRef<DictionaryAttr> argAttrs,
+                       Optional<uint64_t> functionEntryCount) {
   result.addRegion();
   result.addAttribute(SymbolTable::getSymbolAttrName(),
                       builder.getStringAttr(name));
@@ -2004,7 +2005,11 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
                       CConvAttr::get(builder.getContext(), cconv));
   result.attributes.append(attrs.begin(), attrs.end());
   if (dsoLocal)
-    result.addAttribute("dso_local", builder.getUnitAttr());
+    result.addAttribute(getDsoLocalAttrName(result.name),
+                        builder.getUnitAttr());
+  if (functionEntryCount)
+    result.addAttribute(getFunctionEntryCountAttrName(result.name),
+                        builder.getI64IntegerAttr(functionEntryCount.value()));
   if (argAttrs.empty())
     return;
 
index 493be27..faa580c 100644 (file)
@@ -84,12 +84,29 @@ static LogicalResult setProfilingAttrs(OpBuilder &builder, llvm::MDNode *node,
   if (!node->getNumOperands())
     return success();
 
-  // Return failure for non-"branch_weights" metadata.
   auto *name = dyn_cast<llvm::MDString>(node->getOperand(0));
-  if (!name || !name->getString().equals("branch_weights"))
+  if (!name)
     return failure();
 
-  // Copy the branch weights to an array.
+  // Handle function entry count metadata.
+  if (name->getString().equals("function_entry_count")) {
+    // TODO support function entry count metadata with GUID fields.
+    if (node->getNumOperands() != 2)
+      return failure();
+
+    llvm::ConstantInt *entryCount =
+        llvm::mdconst::extract<llvm::ConstantInt>(node->getOperand(1));
+    if (auto funcOp = dyn_cast<LLVMFuncOp>(op)) {
+      funcOp.setFunctionEntryCount(entryCount->getZExtValue());
+      return success();
+    }
+    return failure();
+  }
+
+  if (!name->getString().equals("branch_weights"))
+    return failure();
+
+  // Handle branch weights metadata.
   SmallVector<int32_t> branchWeights;
   branchWeights.reserve(node->getNumOperands() - 1);
   for (unsigned i = 1, e = node->getNumOperands(); i != e; ++i) {
index b7c3457..fb58e34 100644 (file)
@@ -1184,6 +1184,18 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
   // Handle Function attributes.
   processFunctionAttributes(func, funcOp);
 
+  // Convert non-debug metadata by using the dialect interface.
+  SmallVector<std::pair<unsigned, llvm::MDNode *>> allMetadata;
+  func->getAllMetadata(allMetadata);
+  for (auto &[kind, node] : allMetadata) {
+    if (!iface.isConvertibleMetadata(kind))
+      continue;
+    if (failed(iface.setMetadataAttrs(builder, kind, node, funcOp, *this))) {
+      emitWarning(funcOp->getLoc())
+          << "unhandled function metadata (" << kind << ") " << diag(*func);
+    }
+  }
+
   if (func->isDeclaration())
     return success();
 
index c5af256..a214ee5 100644 (file)
@@ -889,6 +889,10 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() {
     if (function->getAttrOfType<UnitAttr>(LLVMDialect::getReadnoneAttrName()))
       llvmFunc->setDoesNotAccessMemory();
 
+    // Convert function_entry_count attribute to metadata.
+    if (std::optional<uint64_t> entryCount = function.getFunctionEntryCount())
+      llvmFunc->setEntryCount(entryCount.value());
+
     // Convert result attributes.
     if (ArrayAttr allResultAttrs = function.getAllResultAttrs()) {
       llvm::AttrBuilder retAttrs(llvmFunc->getContext());
index 622cc20..3fac482 100644 (file)
@@ -38,3 +38,13 @@ define void @func_arg_attrs(
     ptr inalloca(i64) %arg3) {
   ret void
 }
+
+; // -----
+
+; CHECK-LABEL: @entry_count
+; CHECK-SAME:  attributes {function_entry_count = 4242 : i64}
+define void @entry_count() !prof !1 {
+  ret void
+}
+
+!1 = !{!"function_entry_count", i64 4242}
index fd2bd27..7e65b87 100644 (file)
@@ -1545,6 +1545,16 @@ llvm.func @passthrough() attributes {passthrough = ["noinline", ["alignstack", "
 
 // -----
 
+// CHECK-LABEL: @functionEntryCount
+// CHECK-SAME: !prof ![[PROF_ID:[0-9]*]]
+llvm.func @functionEntryCount() attributes {function_entry_count = 4242 : i64} {
+  llvm.return
+}
+
+// CHECK: ![[PROF_ID]] = !{!"function_entry_count", i64 4242}
+
+// -----
+
 // CHECK-LABEL: @constant_bf16
 llvm.func @constant_bf16() -> bf16 {
   %0 = llvm.mlir.constant(1.000000e+01 : bf16) : bf16