From: Tobias Gysi Date: Thu, 20 Jul 2023 08:13:17 +0000 (+0000) Subject: [mlir][llvm] Add branch weight op interface X-Git-Tag: upstream/17.0.6~1000 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=10fa27704b3165ddc4efbcf7964042b137e7fa7e;p=platform%2Fupstream%2Fllvm.git [mlir][llvm] Add branch weight op interface This revision adds a branch weight op interface for the call / branch operations that support branch weights. It can be used in the LLVM IR import and export to simplify the branch weight conversion. An additional mapping between call operations and instructions ensures the actual conversion can be done in the module translation itself, rather than in the dialect translation interface. It also has the benefit that downstream users can amend custom metadata to the call operation during the export to LLVM IR. Reviewed By: zero9178, definelicht Differential Revision: https://reviews.llvm.org/D155702 --- diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td index 9f230bf..7b33ec8 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -30,7 +30,7 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> { /*args=*/ (ins), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); return op.getFastmathFlagsAttr(); }] >, @@ -48,6 +48,42 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> { ]; } +def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> { + let description = [{ + An interface for operations that can carry branch weights metadata. It + provides setters and getters for the operation's branch weights attribute. + The default implementation of the interface methods expect the operation to + have an attribute of type DenseI32ArrayAttr named branch_weights. + }]; + + let cppNamespace = "::mlir::LLVM"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns the branch weights attribute or nullptr", + /*returnType=*/ "DenseI32ArrayAttr", + /*methodName=*/ "getBranchWeightsOrNull", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + return op.getBranchWeightsAttr(); + }] + >, + InterfaceMethod< + /*desc=*/ "Sets the branch weights attribute", + /*returnType=*/ "void", + /*methodName=*/ "setBranchWeights", + /*args=*/ (ins "DenseI32ArrayAttr":$attr), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + auto op = cast(this->getOperation()); + op.setBranchWeightsAttr(attr); + }] + > + ]; +} + def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> { let description = [{ An interface for memory operations that can carry access groups metadata. @@ -67,7 +103,7 @@ def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> { /*args=*/ (ins), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); return op.getAccessGroupsAttr(); }] >, @@ -78,7 +114,7 @@ def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> { /*args=*/ (ins "const ArrayAttr":$attr), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); op.setAccessGroupsAttr(attr); }] > @@ -105,7 +141,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); return op.getAliasScopesAttr(); }] >, @@ -116,7 +152,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins "const ArrayAttr":$attr), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); op.setAliasScopesAttr(attr); }] >, @@ -127,7 +163,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); return op.getNoaliasScopesAttr(); }] >, @@ -138,7 +174,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins "const ArrayAttr":$attr), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); op.setNoaliasScopesAttr(attr); }] >, @@ -149,7 +185,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); return op.getTbaaAttr(); }] >, @@ -160,7 +196,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> { /*args=*/ (ins "const ArrayAttr":$attr), /*methodBody=*/ [{}], /*defaultImpl=*/ [{ - ConcreteOp op = cast(this->getOperation()); + auto op = cast(this->getOperation()); op.setTbaaAttr(attr); }] > diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 01070fe..2d0ca91 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -536,12 +536,14 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc", def LLVM_InvokeOp : LLVM_Op<"invoke", [ AttrSizedOperandSegments, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods, Terminator]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + Terminator]> { let arguments = (ins OptionalAttr:$callee, Variadic:$callee_operands, Variadic:$normalDestOperands, Variadic:$unwindDestOperands, - OptionalAttr:$branch_weights); + OptionalAttr:$branch_weights); let results = (outs Variadic); let successors = (successor AnySuccessor:$normalDest, AnySuccessor:$unwindDest); @@ -582,7 +584,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> { def LLVM_CallOp : LLVM_MemAccessOpBase<"call", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Call to an LLVM function."; let description = [{ In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect @@ -616,7 +619,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", Variadic, DefaultValuedAttr:$fastmathFlags, - OptionalAttr:$branch_weights); + OptionalAttr:$branch_weights); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. let arguments = !con(args, aliasAttrs); let results = (outs Optional:$result); @@ -847,12 +850,14 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br", ]; } def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", - [AttrSizedOperandSegments, DeclareOpInterfaceMethods, + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure]> { let arguments = (ins I1:$condition, Variadic:$trueDestOperands, Variadic:$falseDestOperands, - OptionalAttr:$branch_weights, + OptionalAttr:$branch_weights, OptionalAttr:$loop_annotation); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); let assemblyFormat = [{ @@ -874,7 +879,7 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br", falseOperands); }]>, OpBuilder<(ins "Value":$condition, "ValueRange":$trueOperands, "ValueRange":$falseOperands, - "ElementsAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest), + "DenseI32ArrayAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest), [{ build($_builder, $_state, condition, trueOperands, falseOperands, branchWeights, {}, trueDest, falseDest); @@ -934,7 +939,9 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> { } def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", - [AttrSizedOperandSegments, DeclareOpInterfaceMethods, + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, Pure]> { let arguments = (ins AnyInteger:$value, @@ -942,7 +949,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", VariadicOfVariadic:$caseOperands, OptionalAttr:$case_values, DenseI32ArrayAttr:$case_operand_segments, - OptionalAttr:$branch_weights + OptionalAttr:$branch_weights ); let successors = (successor AnySuccessor:$defaultDestination, diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h index da4d43a..0d296aa 100644 --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -118,6 +118,20 @@ public: return branchMapping.lookup(op); } + /// Stores a mapping between an MLIR call operation and a corresponding LLVM + /// call instruction. + void mapCall(Operation *mlir, llvm::CallInst *llvm) { + auto result = callMapping.try_emplace(mlir, llvm); + (void)result; + assert(result.second && "attempting to map a call that is already mapped"); + } + + /// Finds an LLVM call instruction that corresponds to the given MLIR call + /// operation. + llvm::CallInst *lookupCall(Operation *op) const { + return callMapping.lookup(op); + } + /// Removes the mapping for blocks contained in the region and values defined /// in these blocks. void forgetMapping(Region ®ion); @@ -141,6 +155,9 @@ public: /// Sets LLVM TBAA metadata for memory operations that have TBAA attributes. void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst); + /// Sets LLVM profiling metadata for operations that have branch weights. + void setBranchWeightsMetadata(BranchWeightOpInterface op); + /// Sets LLVM loop metadata for branch operations that have a loop annotation /// attribute. void setLoopMetadata(Operation *op, llvm::Instruction *inst); @@ -328,6 +345,11 @@ private: /// values after all operations are converted. DenseMap branchMapping; + /// A mapping between MLIR LLVM dialect call operations and LLVM IR call + /// instructions. This allows for adding branch weights after the operations + /// have been converted. + DenseMap callMapping; + /// Mapping from an alias scope metadata operation to its LLVM metadata. /// This map is populated on module entry. DenseMap aliasScopeMetadataMapping; diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 28e587a..1d32e6e 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -553,10 +553,12 @@ public: matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // If branch weights exist, map them to 32-bit integer vector. - ElementsAttr branchWeights = nullptr; + DenseI32ArrayAttr branchWeights = nullptr; if (auto weights = op.getBranchWeights()) { - VectorType weightType = VectorType::get(2, rewriter.getI32Type()); - branchWeights = DenseElementsAttr::get(weightType, weights->getValue()); + SmallVector weightValues; + for (auto weight : weights->getAsRange()) + weightValues.push_back(weight.getInt()); + branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues); } rewriter.replaceOpWithNewOp( diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 8eee3b2..f4d9c95 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -310,11 +310,11 @@ void CondBrOp::build(OpBuilder &builder, OperationState &result, Value condition, Block *trueDest, ValueRange trueOperands, Block *falseDest, ValueRange falseOperands, std::optional> weights) { - ElementsAttr weightsAttr; + DenseI32ArrayAttr weightsAttr; if (weights) weightsAttr = - builder.getI32VectorAttr({static_cast(weights->first), - static_cast(weights->second)}); + builder.getDenseI32ArrayAttr({static_cast(weights->first), + static_cast(weights->second)}); build(builder, result, condition, trueOperands, falseOperands, weightsAttr, /*loop_annotation=*/{}, trueDest, falseDest); @@ -330,9 +330,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, BlockRange caseDestinations, ArrayRef caseOperands, ArrayRef branchWeights) { - ElementsAttr weightsAttr; + DenseI32ArrayAttr weightsAttr; if (!branchWeights.empty()) - weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); + weightsAttr = builder.getDenseI32ArrayAttr(branchWeights); build(builder, result, value, defaultOperands, caseOperands, caseValues, weightsAttr, defaultDestination, caseDestinations); diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp index a6f0ebe..40d8253 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp @@ -125,13 +125,11 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node, branchWeights.push_back(branchWeight->getZExtValue()); } - return TypeSwitch(op) - .Case([&](auto branchWeightOp) { - branchWeightOp.setBranchWeightsAttr( - builder.getI32VectorAttr(branchWeights)); - return success(); - }) - .Default([](auto) { return failure(); }); + if (auto iface = dyn_cast(op)) { + iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights)); + return success(); + } + return failure(); } /// Searches for the attribute that maps to the given TBAA metadata `node` and diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index a044930..8f7c5d8 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -124,21 +124,6 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, return success(); } -/// Constructs branch weights metadata if the provided `weights` hold a value, -/// otherwise returns nullptr. -static llvm::MDNode * -convertBranchWeights(std::optional weights, - LLVM::ModuleTranslation &moduleTranslation) { - if (!weights) - return nullptr; - SmallVector weightValues; - weightValues.reserve(weights->size()); - for (APInt weight : llvm::cast(*weights)) - weightValues.push_back(weight.getLimitedValue()); - return llvm::MDBuilder(moduleTranslation.getLLVMContext()) - .createBranchWeights(weightValues); -} - static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { @@ -182,10 +167,6 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, callOp.getArgOperands()), operandsRef.front(), operandsRef.drop_front()); } - llvm::MDNode *branchWeights = - convertBranchWeights(callOp.getBranchWeights(), moduleTranslation); - if (branchWeights) - call->setMetadata(llvm::LLVMContext::MD_prof, branchWeights); moduleTranslation.setAccessGroupsMetadata(callOp, call); moduleTranslation.setAliasScopeMetadata(callOp, call); moduleTranslation.setTBAAMetadata(callOp, call); @@ -196,7 +177,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, return success(); } // Check that LLVM call returns void for 0-result functions. - return success(call->getType()->isVoidTy()); + if (!call->getType()->isVoidTy()) + return failure(); + moduleTranslation.mapCall(callOp, call); + return success(); } if (auto inlineAsmOp = dyn_cast(opInst)) { @@ -274,10 +258,6 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front()); } - llvm::MDNode *branchWeights = - convertBranchWeights(invOp.getBranchWeights(), moduleTranslation); - if (branchWeights) - result->setMetadata(llvm::LLVMContext::MD_prof, branchWeights); moduleTranslation.mapBranch(invOp, result); // InvokeOp can only have 0 or 1 result if (invOp->getNumResults() != 0) { @@ -314,23 +294,19 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, return success(); } if (auto condbrOp = dyn_cast(opInst)) { - llvm::MDNode *branchWeights = - convertBranchWeights(condbrOp.getBranchWeights(), moduleTranslation); llvm::BranchInst *branch = builder.CreateCondBr( moduleTranslation.lookupValue(condbrOp.getOperand(0)), moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)), - moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights); + moduleTranslation.lookupBlock(condbrOp.getSuccessor(1))); moduleTranslation.mapBranch(&opInst, branch); moduleTranslation.setLoopMetadata(&opInst, branch); return success(); } if (auto switchOp = dyn_cast(opInst)) { - llvm::MDNode *branchWeights = - convertBranchWeights(switchOp.getBranchWeights(), moduleTranslation); llvm::SwitchInst *switchInst = builder.CreateSwitch( moduleTranslation.lookupValue(switchOp.getValue()), moduleTranslation.lookupBlock(switchOp.getDefaultDestination()), - switchOp.getCaseDestinations().size(), branchWeights); + switchOp.getCaseDestinations().size()); auto *ty = llvm::cast( moduleTranslation.convertType(switchOp.getValue().getType())); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index d363fb8..cd3a645 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -664,6 +664,10 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments, if (failed(convertOperation(op, builder))) return failure(); + + // Set the branch weight metadata on the translated instruction. + if (auto iface = dyn_cast(op)) + setBranchWeightsMetadata(iface); } return success(); @@ -1183,6 +1187,19 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op, inst->setMetadata(llvm::LLVMContext::MD_tbaa, node); } +void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) { + DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull(); + if (!weightsAttr) + return; + + llvm::Instruction *inst = isa(op) ? lookupCall(op) : lookupBranch(op); + assert(inst && "expected the operation to have a mapping to an instruction"); + SmallVector weights(weightsAttr.asArrayRef()); + inst->setMetadata( + llvm::LLVMContext::MD_prof, + llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights)); +} + LogicalResult ModuleTranslation::createTBAAMetadata() { llvm::LLVMContext &ctx = llvmModule->getContext(); llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64); diff --git a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir index 8c58d59..54ef71f 100644 --- a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir @@ -68,7 +68,7 @@ spirv.module Logical GLSL450 { } spirv.func @cond_branch_with_weights(%cond: i1) -> () "None" { - // CHECK: llvm.cond_br %{{.*}} weights(dense<[1, 2]> : vector<2xi32>), ^bb1, ^bb2 + // CHECK: llvm.cond_br %{{.*}} weights([1, 2]), ^bb1, ^bb2 spirv.BranchConditional %cond [1, 2], ^true, ^false // CHECK: ^bb1: ^true: diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index da4799d..09bbc5a 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -874,7 +874,7 @@ func.func @switch_wrong_number_of_weights(%arg0 : i32) { // expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}} llvm.switch %arg0 : i32, ^bb1 [ 42: ^bb2(%arg0, %arg0 : i32, i32) - ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>} + ] {branch_weights = array} ^bb1: // pred: ^bb0 llvm.return diff --git a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll index 688dd10..cc3b47a 100644 --- a/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll +++ b/mlir/test/Target/LLVMIR/Import/metadata-profiling.ll @@ -4,7 +4,7 @@ define i64 @cond_br(i1 %arg1, i64 %arg2) { entry: ; CHECK: llvm.cond_br - ; CHECK-SAME: weights(dense<[0, 3]> : vector<2xi32>) + ; CHECK-SAME: weights([0, 3]) br i1 %arg1, label %bb1, label %bb2, !prof !0 bb1: ret i64 %arg2 @@ -19,7 +19,7 @@ bb2: ; CHECK-LABEL: @simple_switch( define i32 @simple_switch(i32 %arg1) { ; CHECK: llvm.switch - ; CHECK: {branch_weights = dense<[42, 3, 5]> : vector<3xi32>} + ; CHECK: {branch_weights = array} switch i32 %arg1, label %bbd [ i32 0, label %bb1 i32 9, label %bb2 @@ -41,7 +41,7 @@ declare void @fn() ; CHECK-LABEL: @call_branch_weights define void @call_branch_weights() { - ; CHECK: llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} + ; CHECK: llvm.call @fn() {branch_weights = array} call void @fn(), !prof !0 ret void } @@ -55,7 +55,7 @@ declare i32 @__gxx_personality_v0(...) ; CHECK-LABEL: @invoke_branch_weights define i32 @invoke_branch_weights() personality ptr @__gxx_personality_v0 { - ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> () + ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array} : () -> () invoke void @foo() to label %bb2 unwind label %bb1, !prof !0 bb1: %1 = landingpad { ptr, i32 } cleanup diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir index 2500de2..3f97ebd 100644 --- a/mlir/test/Target/LLVMIR/llvmir.mlir +++ b/mlir/test/Target/LLVMIR/llvmir.mlir @@ -1802,7 +1802,7 @@ llvm.func @foo() { // Check that branch weight attributes are exported properly as metadata. llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 { // CHECK: !prof ![[NODE:[0-9]+]] - llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2 + llvm.cond_br %cond weights([5, 10]), ^bb1, ^bb2 ^bb1: // pred: ^bb0 llvm.return %arg0 : i32 ^bb2: // pred: ^bb0 @@ -1818,7 +1818,7 @@ llvm.func @fn() // CHECK-LABEL: @call_branch_weights llvm.func @call_branch_weights() { // CHECK: !prof ![[NODE:[0-9]+]] - llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} : () -> () + llvm.call @fn() {branch_weights = array} : () -> () llvm.return } @@ -1833,7 +1833,7 @@ llvm.func @__gxx_personality_v0(...) -> i32 llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} { %0 = llvm.mlir.constant(1 : i32) : i32 // CHECK: !prof ![[NODE:[0-9]+]] - llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> () + llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array} : () -> () ^bb1: // pred: ^bb0 %1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)> llvm.br ^bb2 @@ -2062,7 +2062,7 @@ llvm.func @switch_weights(%arg0: i32) -> i32 { llvm.switch %arg0 : i32, ^bb1(%0 : i32) [ 9: ^bb2(%1, %2 : i32, i32), 99: ^bb3 - ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>} + ] {branch_weights = array} ^bb1(%3: i32): // pred: ^bb0 llvm.return %3 : i32