/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ auto op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathFlagsAttr();
}]
>,
];
}
+def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
+ let description = [{
+ An interface for operations that can carry branch weights metadata. It
+ provides setters and getters for the operation's branch weights attribute.
+ The default implementation of the interface methods expect the operation to
+ have an attribute of type DenseI32ArrayAttr named branch_weights.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns the branch weights attribute or nullptr",
+ /*returnType=*/ "DenseI32ArrayAttr",
+ /*methodName=*/ "getBranchWeightsOrNull",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getBranchWeightsAttr();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Sets the branch weights attribute",
+ /*returnType=*/ "void",
+ /*methodName=*/ "setBranchWeights",
+ /*args=*/ (ins "DenseI32ArrayAttr":$attr),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ op.setBranchWeightsAttr(attr);
+ }]
+ >
+ ];
+}
+
def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
let description = [{
An interface for memory operations that can carry access groups metadata.
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ auto op = cast<ConcreteOp>(this->getOperation());
return op.getAccessGroupsAttr();
}]
>,
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ auto op = cast<ConcreteOp>(this->getOperation());
op.setAccessGroupsAttr(attr);
}]
>
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ auto op = cast<ConcreteOp>(this->getOperation());
return op.getAliasScopesAttr();
}]
>,
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ auto op = cast<ConcreteOp>(this->getOperation());
op.setAliasScopesAttr(attr);
}]
>,
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ auto op = cast<ConcreteOp>(this->getOperation());
return op.getNoaliasScopesAttr();
}]
>,
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ auto op = cast<ConcreteOp>(this->getOperation());
op.setNoaliasScopesAttr(attr);
}]
>,
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ auto op = cast<ConcreteOp>(this->getOperation());
return op.getTbaaAttr();
}]
>,
/*args=*/ (ins "const ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
- ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+ auto op = cast<ConcreteOp>(this->getOperation());
op.setTbaaAttr(attr);
}]
>
def LLVM_InvokeOp : LLVM_Op<"invoke", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
- DeclareOpInterfaceMethods<CallOpInterface>, Terminator]> {
+ DeclareOpInterfaceMethods<CallOpInterface>,
+ DeclareOpInterfaceMethods<BranchWeightOpInterface>,
+ Terminator]> {
let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
Variadic<LLVM_Type>:$normalDestOperands,
Variadic<LLVM_Type>:$unwindDestOperands,
- OptionalAttr<ElementsAttr>:$branch_weights);
+ OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
let results = (outs Variadic<LLVM_Type>);
let successors = (successor AnySuccessor:$normalDest,
AnySuccessor:$unwindDest);
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
[DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
- DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+ DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
let summary = "Call to an LLVM function.";
let description = [{
In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
Variadic<LLVM_Type>,
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
"{}">:$fastmathFlags,
- OptionalAttr<ElementsAttr>:$branch_weights);
+ OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
];
}
def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
- [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<BranchOpInterface>,
+ DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Pure]> {
let arguments = (ins I1:$condition,
Variadic<LLVM_Type>:$trueDestOperands,
Variadic<LLVM_Type>:$falseDestOperands,
- OptionalAttr<ElementsAttr>:$branch_weights,
+ OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
OptionalAttr<LoopAnnotationAttr>:$loop_annotation);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
let assemblyFormat = [{
falseOperands);
}]>,
OpBuilder<(ins "Value":$condition, "ValueRange":$trueOperands, "ValueRange":$falseOperands,
- "ElementsAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest),
+ "DenseI32ArrayAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest),
[{
build($_builder, $_state, condition, trueOperands, falseOperands, branchWeights,
{}, trueDest, falseDest);
}
def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
- [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
+ [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<BranchOpInterface>,
+ DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Pure]> {
let arguments = (ins
AnyInteger:$value,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
OptionalAttr<AnyIntElementsAttr>:$case_values,
DenseI32ArrayAttr:$case_operand_segments,
- OptionalAttr<ElementsAttr>:$branch_weights
+ OptionalAttr<DenseI32ArrayAttr>:$branch_weights
);
let successors = (successor
AnySuccessor:$defaultDestination,
return branchMapping.lookup(op);
}
+ /// Stores a mapping between an MLIR call operation and a corresponding LLVM
+ /// call instruction.
+ void mapCall(Operation *mlir, llvm::CallInst *llvm) {
+ auto result = callMapping.try_emplace(mlir, llvm);
+ (void)result;
+ assert(result.second && "attempting to map a call that is already mapped");
+ }
+
+ /// Finds an LLVM call instruction that corresponds to the given MLIR call
+ /// operation.
+ llvm::CallInst *lookupCall(Operation *op) const {
+ return callMapping.lookup(op);
+ }
+
/// Removes the mapping for blocks contained in the region and values defined
/// in these blocks.
void forgetMapping(Region ®ion);
/// Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst);
+ /// Sets LLVM profiling metadata for operations that have branch weights.
+ void setBranchWeightsMetadata(BranchWeightOpInterface op);
+
/// Sets LLVM loop metadata for branch operations that have a loop annotation
/// attribute.
void setLoopMetadata(Operation *op, llvm::Instruction *inst);
/// values after all operations are converted.
DenseMap<Operation *, llvm::Instruction *> branchMapping;
+ /// A mapping between MLIR LLVM dialect call operations and LLVM IR call
+ /// instructions. This allows for adding branch weights after the operations
+ /// have been converted.
+ DenseMap<Operation *, llvm::CallInst *> callMapping;
+
/// Mapping from an alias scope metadata operation to its LLVM metadata.
/// This map is populated on module entry.
DenseMap<Attribute, llvm::MDNode *> aliasScopeMetadataMapping;
matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// If branch weights exist, map them to 32-bit integer vector.
- ElementsAttr branchWeights = nullptr;
+ DenseI32ArrayAttr branchWeights = nullptr;
if (auto weights = op.getBranchWeights()) {
- VectorType weightType = VectorType::get(2, rewriter.getI32Type());
- branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
+ SmallVector<int32_t> weightValues;
+ for (auto weight : weights->getAsRange<IntegerAttr>())
+ weightValues.push_back(weight.getInt());
+ branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
}
rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
Value condition, Block *trueDest, ValueRange trueOperands,
Block *falseDest, ValueRange falseOperands,
std::optional<std::pair<uint32_t, uint32_t>> weights) {
- ElementsAttr weightsAttr;
+ DenseI32ArrayAttr weightsAttr;
if (weights)
weightsAttr =
- builder.getI32VectorAttr({static_cast<int32_t>(weights->first),
- static_cast<int32_t>(weights->second)});
+ builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
+ static_cast<int32_t>(weights->second)});
build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
/*loop_annotation=*/{}, trueDest, falseDest);
BlockRange caseDestinations,
ArrayRef<ValueRange> caseOperands,
ArrayRef<int32_t> branchWeights) {
- ElementsAttr weightsAttr;
+ DenseI32ArrayAttr weightsAttr;
if (!branchWeights.empty())
- weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
+ weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
build(builder, result, value, defaultOperands, caseOperands, caseValues,
weightsAttr, defaultDestination, caseDestinations);
branchWeights.push_back(branchWeight->getZExtValue());
}
- return TypeSwitch<Operation *, LogicalResult>(op)
- .Case<CondBrOp, SwitchOp, CallOp, InvokeOp>([&](auto branchWeightOp) {
- branchWeightOp.setBranchWeightsAttr(
- builder.getI32VectorAttr(branchWeights));
- return success();
- })
- .Default([](auto) { return failure(); });
+ if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) {
+ iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
+ return success();
+ }
+ return failure();
}
/// Searches for the attribute that maps to the given TBAA metadata `node` and
return success();
}
-/// Constructs branch weights metadata if the provided `weights` hold a value,
-/// otherwise returns nullptr.
-static llvm::MDNode *
-convertBranchWeights(std::optional<ElementsAttr> weights,
- LLVM::ModuleTranslation &moduleTranslation) {
- if (!weights)
- return nullptr;
- SmallVector<uint32_t> weightValues;
- weightValues.reserve(weights->size());
- for (APInt weight : llvm::cast<DenseIntElementsAttr>(*weights))
- weightValues.push_back(weight.getLimitedValue());
- return llvm::MDBuilder(moduleTranslation.getLLVMContext())
- .createBranchWeights(weightValues);
-}
-
static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
callOp.getArgOperands()),
operandsRef.front(), operandsRef.drop_front());
}
- llvm::MDNode *branchWeights =
- convertBranchWeights(callOp.getBranchWeights(), moduleTranslation);
- if (branchWeights)
- call->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
moduleTranslation.setAccessGroupsMetadata(callOp, call);
moduleTranslation.setAliasScopeMetadata(callOp, call);
moduleTranslation.setTBAAMetadata(callOp, call);
return success();
}
// Check that LLVM call returns void for 0-result functions.
- return success(call->getType()->isVoidTy());
+ if (!call->getType()->isVoidTy())
+ return failure();
+ moduleTranslation.mapCall(callOp, call);
+ return success();
}
if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
operandsRef.drop_front());
}
- llvm::MDNode *branchWeights =
- convertBranchWeights(invOp.getBranchWeights(), moduleTranslation);
- if (branchWeights)
- result->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
moduleTranslation.mapBranch(invOp, result);
// InvokeOp can only have 0 or 1 result
if (invOp->getNumResults() != 0) {
return success();
}
if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
- llvm::MDNode *branchWeights =
- convertBranchWeights(condbrOp.getBranchWeights(), moduleTranslation);
llvm::BranchInst *branch = builder.CreateCondBr(
moduleTranslation.lookupValue(condbrOp.getOperand(0)),
moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
- moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
+ moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)));
moduleTranslation.mapBranch(&opInst, branch);
moduleTranslation.setLoopMetadata(&opInst, branch);
return success();
}
if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
- llvm::MDNode *branchWeights =
- convertBranchWeights(switchOp.getBranchWeights(), moduleTranslation);
llvm::SwitchInst *switchInst = builder.CreateSwitch(
moduleTranslation.lookupValue(switchOp.getValue()),
moduleTranslation.lookupBlock(switchOp.getDefaultDestination()),
- switchOp.getCaseDestinations().size(), branchWeights);
+ switchOp.getCaseDestinations().size());
auto *ty = llvm::cast<llvm::IntegerType>(
moduleTranslation.convertType(switchOp.getValue().getType()));
if (failed(convertOperation(op, builder)))
return failure();
+
+ // Set the branch weight metadata on the translated instruction.
+ if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
+ setBranchWeightsMetadata(iface);
}
return success();
inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
}
+void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
+ DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
+ if (!weightsAttr)
+ return;
+
+ llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
+ assert(inst && "expected the operation to have a mapping to an instruction");
+ SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
+ inst->setMetadata(
+ llvm::LLVMContext::MD_prof,
+ llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));
+}
+
LogicalResult ModuleTranslation::createTBAAMetadata() {
llvm::LLVMContext &ctx = llvmModule->getContext();
llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);
}
spirv.func @cond_branch_with_weights(%cond: i1) -> () "None" {
- // CHECK: llvm.cond_br %{{.*}} weights(dense<[1, 2]> : vector<2xi32>), ^bb1, ^bb2
+ // CHECK: llvm.cond_br %{{.*}} weights([1, 2]), ^bb1, ^bb2
spirv.BranchConditional %cond [1, 2], ^true, ^false
// CHECK: ^bb1:
^true:
// expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}}
llvm.switch %arg0 : i32, ^bb1 [
42: ^bb2(%arg0, %arg0 : i32, i32)
- ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
+ ] {branch_weights = array<i32: 13, 17, 19>}
^bb1: // pred: ^bb0
llvm.return
define i64 @cond_br(i1 %arg1, i64 %arg2) {
entry:
; CHECK: llvm.cond_br
- ; CHECK-SAME: weights(dense<[0, 3]> : vector<2xi32>)
+ ; CHECK-SAME: weights([0, 3])
br i1 %arg1, label %bb1, label %bb2, !prof !0
bb1:
ret i64 %arg2
; CHECK-LABEL: @simple_switch(
define i32 @simple_switch(i32 %arg1) {
; CHECK: llvm.switch
- ; CHECK: {branch_weights = dense<[42, 3, 5]> : vector<3xi32>}
+ ; CHECK: {branch_weights = array<i32: 42, 3, 5>}
switch i32 %arg1, label %bbd [
i32 0, label %bb1
i32 9, label %bb2
; CHECK-LABEL: @call_branch_weights
define void @call_branch_weights() {
- ; CHECK: llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>}
+ ; CHECK: llvm.call @fn() {branch_weights = array<i32: 42>}
call void @fn(), !prof !0
ret void
}
; CHECK-LABEL: @invoke_branch_weights
define i32 @invoke_branch_weights() personality ptr @__gxx_personality_v0 {
- ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
+ ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array<i32: 42, 99>} : () -> ()
invoke void @foo() to label %bb2 unwind label %bb1, !prof !0
bb1:
%1 = landingpad { ptr, i32 } cleanup
// Check that branch weight attributes are exported properly as metadata.
llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 {
// CHECK: !prof ![[NODE:[0-9]+]]
- llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2
+ llvm.cond_br %cond weights([5, 10]), ^bb1, ^bb2
^bb1: // pred: ^bb0
llvm.return %arg0 : i32
^bb2: // pred: ^bb0
// CHECK-LABEL: @call_branch_weights
llvm.func @call_branch_weights() {
// CHECK: !prof ![[NODE:[0-9]+]]
- llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} : () -> ()
+ llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> ()
llvm.return
}
llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: !prof ![[NODE:[0-9]+]]
- llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
+ llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array<i32 : 42, 99>} : () -> ()
^bb1: // pred: ^bb0
%1 = llvm.landingpad cleanup : !llvm.struct<(ptr<i8>, i32)>
llvm.br ^bb2
llvm.switch %arg0 : i32, ^bb1(%0 : i32) [
9: ^bb2(%1, %2 : i32, i32),
99: ^bb3
- ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
+ ] {branch_weights = array<i32 : 13, 17, 19>}
^bb1(%3: i32): // pred: ^bb0
llvm.return %3 : i32