[mlir][llvm] Add branch weights to call and invoke
author Christian Ulmann <christian.ulmann@nextsilicon.com>
Mon, 9 Jan 2023 09:14:21 +0000 (10:14 +0100)
committer Tobias Gysi <tobias.gysi@nextsilicon.com>
Mon, 9 Jan 2023 09:25:07 +0000 (10:25 +0100)
This commit adds an optional branch weights attribute to LLVM::CallOp and
LLVM::InvokeOp, and supports both importing it from LLVM IR and exporting
it to LLVM IR.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D141122

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/test/Target/LLVMIR/Import/profiling-metadata.ll
mlir/test/Target/LLVMIR/llvmir.mlir

index ca4cc48..bfb2226 100644 (file)
@@ -483,7 +483,8 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
   let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                    Variadic<LLVM_Type>:$callee_operands,
                    Variadic<LLVM_Type>:$normalDestOperands,
-                   Variadic<LLVM_Type>:$unwindDestOperands);
+                   Variadic<LLVM_Type>:$unwindDestOperands,
+                   OptionalAttr<ElementsAttr>:$branch_weights);
   let results = (outs Variadic<LLVM_Type>);
   let successors = (successor AnySuccessor:$normalDest,
                               AnySuccessor:$unwindDest);
@@ -500,7 +501,7 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
       "ValueRange":$normalOps, "Block*":$unwind, "ValueRange":$unwindOps),
     [{
       build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps,
-            unwindOps, normal, unwind);
+            unwindOps, nullptr, normal, unwind);
     }]>];
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
@@ -553,13 +554,16 @@ def LLVM_CallOp : LLVM_Op<"call",
   let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                        Variadic<LLVM_Type>,
                        DefaultValuedAttr<LLVM_FastmathFlagsAttr,
-                                         "{}">:$fastmathFlags);
+                                         "{}">:$fastmathFlags,
+                       OptionalAttr<ElementsAttr>:$branch_weights);
   let results = (outs Optional<LLVM_Type>:$result);
 
   let builders = [
     OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$args)>,
     OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
                    CArg<"ValueRange", "{}">:$args)>,
+    OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
+                   CArg<"ValueRange", "{}">:$args)>,
     OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
                    CArg<"ValueRange", "{}">:$args)>
   ];
index 58abf02..c32ca2b 100644 (file)
@@ -1156,7 +1156,13 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
 
 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                    StringAttr callee, ValueRange args) {
-  build(builder, state, results, SymbolRefAttr::get(callee), args, nullptr);
+  build(builder, state, results, SymbolRefAttr::get(callee), args, nullptr,
+        nullptr);
+}
+
+void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
+                   FlatSymbolRefAttr callee, ValueRange args) {
+  build(builder, state, results, callee, args, nullptr, nullptr);
 }
 
 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
@@ -1165,7 +1171,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
   Type resultType = func.getFunctionType().getReturnType();
   if (!resultType.isa<LLVM::LLVMVoidType>())
     results.push_back(resultType);
-  build(builder, state, results, SymbolRefAttr::get(func), args, nullptr);
+  build(builder, state, results, SymbolRefAttr::get(func), args, nullptr,
+        nullptr);
 }
 
 CallInterfaceCallable CallOp::getCallableForCallee() {
index aed8321..24e5ab3 100644 (file)
@@ -18,6 +18,7 @@
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Instructions.h"
@@ -116,15 +117,13 @@ static LogicalResult setProfilingAttrs(OpBuilder &builder, llvm::MDNode *node,
   }
 
   // Attach the branch weights to the operations that support it.
-  if (auto condBrOp = dyn_cast<CondBrOp>(op)) {
-    condBrOp.setBranchWeightsAttr(builder.getI32VectorAttr(branchWeights));
-    return success();
-  }
-  if (auto switchOp = dyn_cast<SwitchOp>(op)) {
-    switchOp.setBranchWeightsAttr(builder.getI32VectorAttr(branchWeights));
-    return success();
-  }
-  return failure();
+  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+      .Case<CondBrOp, SwitchOp, CallOp, InvokeOp>([&](auto branchWeightOp) {
+        branchWeightOp.setBranchWeightsAttr(
+            builder.getI32VectorAttr(branchWeights));
+        return success();
+      })
+      .Default([](auto) { return failure(); });
 }
 
 /// Attaches the given TBAA metadata `node` to the imported operation.
index 697121e..7f44db5 100644 (file)
@@ -322,6 +322,21 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp &op, llvm::IRBuilderBase &builder,
   return success();
 }
 
+/// Constructs branch weights metadata if the provided `weights` hold a value,
+/// otherwise returns nullptr.
+static llvm::MDNode *
+convertBranchWeights(std::optional<ElementsAttr> weights,
+                     LLVM::ModuleTranslation &moduleTranslation) {
+  if (!weights)
+    return nullptr;
+  SmallVector<uint32_t> weightValues;
+  weightValues.reserve(weights->size());
+  for (APInt weight : weights->cast<DenseIntElementsAttr>())
+    weightValues.push_back(weight.getLimitedValue());
+  return llvm::MDBuilder(moduleTranslation.getLLVMContext())
+      .createBranchWeights(weightValues);
+}
+
 static LogicalResult
 convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
@@ -336,32 +351,34 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
   // Emit function calls.  If the "callee" attribute is present, this is a
   // direct function call and we also need to look up the remapped function
   // itself.  Otherwise, this is an indirect call and the callee is the first
-  // operand, look it up as a normal value.  Return the llvm::Value
-  // representing the function result, which may be of llvm::VoidTy type.
-  auto convertCall = [&](Operation &op) -> llvm::Value * {
-    auto operands = moduleTranslation.lookupValues(op.getOperands());
+  // operand, look it up as a normal value.
+  if (auto callOp = dyn_cast<LLVM::CallOp>(opInst)) {
+    auto operands = moduleTranslation.lookupValues(callOp.getOperands());
     ArrayRef<llvm::Value *> operandsRef(operands);
-    if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee"))
-      return builder.CreateCall(
+    llvm::CallInst *call;
+    if (auto attr = callOp.getCalleeAttr()) {
+      call = builder.CreateCall(
           moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
-    auto calleeType =
-        op.getOperands().front().getType().cast<LLVMPointerType>();
-    auto *calleeFunctionType = cast<llvm::FunctionType>(
-        moduleTranslation.convertType(calleeType.getElementType()));
-    return builder.CreateCall(calleeFunctionType, operandsRef.front(),
-                              operandsRef.drop_front());
-  };
-
-  // Emit calls.  If the called function has a result, remap the corresponding
-  // value.  Note that LLVM IR dialect CallOp has either 0 or 1 result.
-  if (isa<LLVM::CallOp>(opInst)) {
-    llvm::Value *result = convertCall(opInst);
+    } else {
+      auto calleeType =
+          callOp->getOperands().front().getType().cast<LLVMPointerType>();
+      auto *calleeFunctionType = cast<llvm::FunctionType>(
+          moduleTranslation.convertType(calleeType.getElementType()));
+      call = builder.CreateCall(calleeFunctionType, operandsRef.front(),
+                                operandsRef.drop_front());
+    }
+    llvm::MDNode *branchWeights =
+        convertBranchWeights(callOp.getBranchWeights(), moduleTranslation);
+    if (branchWeights)
+      call->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
+    // If the called function has a result, remap the corresponding value.  Note
+    // that LLVM IR dialect CallOp has either 0 or 1 result.
     if (opInst.getNumResults() != 0) {
-      moduleTranslation.mapValue(opInst.getResult(0), result);
+      moduleTranslation.mapValue(opInst.getResult(0), call);
       return success();
     }
     // Check that LLVM call returns void for 0-result functions.
-    return success(result->getType()->isVoidTy());
+    return success(call->getType()->isVoidTy());
   }
 
   if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
@@ -442,6 +459,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
           moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
           operandsRef.drop_front());
     }
+    llvm::MDNode *branchWeights =
+        convertBranchWeights(invOp.getBranchWeights(), moduleTranslation);
+    if (branchWeights)
+      result->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
     moduleTranslation.mapBranch(invOp, result);
     // InvokeOp can only have 0 or 1 result
     if (invOp->getNumResults() != 0) {
@@ -478,17 +499,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
-    llvm::MDNode *branchWeights = nullptr;
-    if (auto weights = condbrOp.getBranchWeights()) {
-      // Map weight attributes to LLVM metadata.
-      auto weightValues = weights->getValues<APInt>();
-      auto trueWeight = weightValues[0].getSExtValue();
-      auto falseWeight = weightValues[1].getSExtValue();
-      branchWeights =
-          llvm::MDBuilder(moduleTranslation.getLLVMContext())
-              .createBranchWeights(static_cast<uint32_t>(trueWeight),
-                                   static_cast<uint32_t>(falseWeight));
-    }
+    llvm::MDNode *branchWeights =
+        convertBranchWeights(condbrOp.getBranchWeights(), moduleTranslation);
     llvm::BranchInst *branch = builder.CreateCondBr(
         moduleTranslation.lookupValue(condbrOp.getOperand(0)),
         moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
@@ -498,16 +510,8 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     return success();
   }
   if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
-    llvm::MDNode *branchWeights = nullptr;
-    if (auto weights = switchOp.getBranchWeights()) {
-      llvm::SmallVector<uint32_t> weightValues;
-      weightValues.reserve(weights->size());
-      for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>())
-        weightValues.push_back(weight.getLimitedValue());
-      branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext())
-                          .createBranchWeights(weightValues);
-    }
-
+    llvm::MDNode *branchWeights =
+        convertBranchWeights(switchOp.getBranchWeights(), moduleTranslation);
     llvm::SwitchInst *switchInst = builder.CreateSwitch(
         moduleTranslation.lookupValue(switchOp.getValue()),
         moduleTranslation.lookupBlock(switchOp.getDefaultDestination()),
index 402271a..70a66b6 100644 (file)
@@ -33,3 +33,36 @@ bbd:
 }
 
 !0 = !{!"branch_weights", i32 42, i32 3, i32 5}
+
+; // -----
+
+; CHECK: llvm.func @fn()
+declare void @fn()
+
+; CHECK-LABEL: @call_branch_weights
+define void @call_branch_weights() {
+  ; CHECK:  llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>}
+  call void @fn(), !prof !0
+  ret void
+}
+
+!0 = !{!"branch_weights", i32 42}
+
+; // -----
+
+declare void @foo()
+declare i32 @__gxx_personality_v0(...)
+
+; CHECK-LABEL: @invoke_branch_weights
+define i32 @invoke_branch_weights() personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) {
+  ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
+  invoke void @foo() to label %bb2 unwind label %bb1, !prof !0
+bb1:
+  %1 = landingpad { i8*, i32 } cleanup
+  br label %bb2
+bb2:
+  ret i32 1
+
+}
+
+!0 = !{!"branch_weights", i32 42, i32 99}
index 7e65b87..eb6738e 100644 (file)
@@ -1631,6 +1631,38 @@ llvm.func @cond_br_weights(%cond : i1, %arg0 : i32,  %arg1 : i32) -> i32 {
 
 // -----
 
+llvm.func @fn()
+
+// CHECK-LABEL: @call_branch_weights
+llvm.func @call_branch_weights() {
+  // CHECK: !prof ![[NODE:[0-9]+]]
+  llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} : () -> ()
+  llvm.return
+}
+
+// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
+
+// -----
+
+llvm.func @foo()
+llvm.func @__gxx_personality_v0(...) -> i32
+
+// CHECK-LABEL: @invoke_branch_weights
+llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: !prof ![[NODE:[0-9]+]]
+  llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
+^bb1:  // pred: ^bb0
+  %1 = llvm.landingpad cleanup : !llvm.struct<(ptr<i8>, i32)>
+  llvm.br ^bb2
+^bb2:  // 2 preds: ^bb0, ^bb1
+  llvm.return %0 : i32
+}
+
+// CHECK: ![[NODE]] = !{!"branch_weights", i32 42, i32 99}
+
+// -----
+
 llvm.func @volatile_store_and_load() {
   %val = llvm.mlir.constant(5 : i32) : i32
   %size = llvm.mlir.constant(1 : i64) : i64