[mlir][llvm] Add branch weight op interface
authorTobias Gysi <tobias.gysi@nextsilicon.com>
Thu, 20 Jul 2023 08:13:17 +0000 (08:13 +0000)
committerTobias Gysi <tobias.gysi@nextsilicon.com>
Thu, 20 Jul 2023 10:46:04 +0000 (10:46 +0000)
This revision adds a branch weight op interface for the call / branch
operations that support branch weights. It can be used in the LLVM IR
import and export to simplify the branch weight conversion. An
additional mapping between call operations and instructions ensures
the actual conversion can be done in the module translation itself,
rather than in the dialect translation interface. It also has the
benefit that downstream users can amend custom metadata to the call
operation during the export to LLVM IR.

Reviewed By: zero9178, definelicht

Differential Revision: https://reviews.llvm.org/D155702

12 files changed:
mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp
mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
mlir/test/Dialect/LLVMIR/invalid.mlir
mlir/test/Target/LLVMIR/Import/metadata-profiling.ll
mlir/test/Target/LLVMIR/llvmir.mlir

index 9f230bf..7b33ec8 100644 (file)
@@ -30,7 +30,7 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         return op.getFastmathFlagsAttr();
       }]
       >,
@@ -48,6 +48,42 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
   ];
 }
 
+def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
+  let description = [{
+    An interface for operations that can carry branch weights metadata. It
+    provides setters and getters for the operation's branch weights attribute.
+    The default implementation of the interface methods expect the operation to
+    have an attribute of type DenseI32ArrayAttr named branch_weights.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns the branch weights attribute or nullptr",
+      /*returnType=*/  "DenseI32ArrayAttr",
+      /*methodName=*/  "getBranchWeightsOrNull",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getBranchWeightsAttr();
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Sets the branch weights attribute",
+      /*returnType=*/  "void",
+      /*methodName=*/  "setBranchWeights",
+      /*args=*/        (ins "DenseI32ArrayAttr":$attr),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        op.setBranchWeightsAttr(attr);
+      }]
+      >
+  ];
+}
+
 def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
   let description = [{
     An interface for memory operations that can carry access groups metadata.
@@ -67,7 +103,7 @@ def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         return op.getAccessGroupsAttr();
       }]
       >,
@@ -78,7 +114,7 @@ def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
       /*args=*/        (ins "const ArrayAttr":$attr),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         op.setAccessGroupsAttr(attr);
       }]
       >
@@ -105,7 +141,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         return op.getAliasScopesAttr();
       }]
       >,
@@ -116,7 +152,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins "const ArrayAttr":$attr),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         op.setAliasScopesAttr(attr);
       }]
       >,
@@ -127,7 +163,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         return op.getNoaliasScopesAttr();
       }]
       >,
@@ -138,7 +174,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins "const ArrayAttr":$attr),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         op.setNoaliasScopesAttr(attr);
       }]
       >,
@@ -149,7 +185,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         return op.getTbaaAttr();
       }]
       >,
@@ -160,7 +196,7 @@ def AliasAnalysisOpInterface : OpInterface<"AliasAnalysisOpInterface"> {
       /*args=*/        (ins "const ArrayAttr":$attr),
       /*methodBody=*/  [{}],
       /*defaultImpl=*/ [{
-        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        auto op = cast<ConcreteOp>(this->getOperation());
         op.setTbaaAttr(attr);
       }]
       >
index 01070fe..2d0ca91 100644 (file)
@@ -536,12 +536,14 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
 def LLVM_InvokeOp : LLVM_Op<"invoke", [
                       AttrSizedOperandSegments,
                       DeclareOpInterfaceMethods<BranchOpInterface>,
-                      DeclareOpInterfaceMethods<CallOpInterface>, Terminator]> {
+                      DeclareOpInterfaceMethods<CallOpInterface>,
+                      DeclareOpInterfaceMethods<BranchWeightOpInterface>,
+                      Terminator]> {
   let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                    Variadic<LLVM_Type>:$callee_operands,
                    Variadic<LLVM_Type>:$normalDestOperands,
                    Variadic<LLVM_Type>:$unwindDestOperands,
-                   OptionalAttr<ElementsAttr>:$branch_weights);
+                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
   let results = (outs Variadic<LLVM_Type>);
   let successors = (successor AnySuccessor:$normalDest,
                               AnySuccessor:$unwindDest);
@@ -582,7 +584,8 @@ def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
 def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                     [DeclareOpInterfaceMethods<FastmathFlagsInterface>,
                      DeclareOpInterfaceMethods<CallOpInterface>,
-                     DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+                     DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+                     DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
   let summary = "Call to an LLVM function.";
   let description = [{
     In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
@@ -616,7 +619,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
                   Variadic<LLVM_Type>,
                   DefaultValuedAttr<LLVM_FastmathFlagsAttr,
                                    "{}">:$fastmathFlags,
-                  OptionalAttr<ElementsAttr>:$branch_weights);
+                  OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
   // Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
   let arguments = !con(args, aliasAttrs);
   let results = (outs Optional<LLVM_Type>:$result);
@@ -847,12 +850,14 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br",
   ];
 }
 def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
-    [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<BranchOpInterface>,
+     DeclareOpInterfaceMethods<BranchWeightOpInterface>,
      Pure]> {
   let arguments = (ins I1:$condition,
                    Variadic<LLVM_Type>:$trueDestOperands,
                    Variadic<LLVM_Type>:$falseDestOperands,
-                   OptionalAttr<ElementsAttr>:$branch_weights,
+                   OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
                    OptionalAttr<LoopAnnotationAttr>:$loop_annotation);
   let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
   let assemblyFormat = [{
@@ -874,7 +879,7 @@ def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
             falseOperands);
   }]>,
   OpBuilder<(ins "Value":$condition, "ValueRange":$trueOperands, "ValueRange":$falseOperands,
-    "ElementsAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest),
+    "DenseI32ArrayAttr":$branchWeights, "Block *":$trueDest, "Block *":$falseDest),
   [{
       build($_builder, $_state, condition, trueOperands, falseOperands, branchWeights,
       {}, trueDest, falseDest);
@@ -934,7 +939,9 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
 }
 
 def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
-    [AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
+    [AttrSizedOperandSegments,
+     DeclareOpInterfaceMethods<BranchOpInterface>,
+     DeclareOpInterfaceMethods<BranchWeightOpInterface>,
      Pure]> {
   let arguments = (ins
     AnyInteger:$value,
@@ -942,7 +949,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
     VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
     OptionalAttr<AnyIntElementsAttr>:$case_values,
     DenseI32ArrayAttr:$case_operand_segments,
-    OptionalAttr<ElementsAttr>:$branch_weights
+    OptionalAttr<DenseI32ArrayAttr>:$branch_weights
   );
   let successors = (successor
     AnySuccessor:$defaultDestination,
index da4d43a..0d296aa 100644 (file)
@@ -118,6 +118,20 @@ public:
     return branchMapping.lookup(op);
   }
 
+  /// Stores a mapping between an MLIR call operation and a corresponding LLVM
+  /// call instruction.
+  void mapCall(Operation *mlir, llvm::CallInst *llvm) {
+    auto result = callMapping.try_emplace(mlir, llvm);
+    (void)result;
+    assert(result.second && "attempting to map a call that is already mapped");
+  }
+
+  /// Finds an LLVM call instruction that corresponds to the given MLIR call
+  /// operation.
+  llvm::CallInst *lookupCall(Operation *op) const {
+    return callMapping.lookup(op);
+  }
+
   /// Removes the mapping for blocks contained in the region and values defined
   /// in these blocks.
   void forgetMapping(Region &region);
@@ -141,6 +155,9 @@ public:
   /// Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
   void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst);
 
+  /// Sets LLVM profiling metadata for operations that have branch weights.
+  void setBranchWeightsMetadata(BranchWeightOpInterface op);
+
   /// Sets LLVM loop metadata for branch operations that have a loop annotation
   /// attribute.
   void setLoopMetadata(Operation *op, llvm::Instruction *inst);
@@ -328,6 +345,11 @@ private:
   /// values after all operations are converted.
   DenseMap<Operation *, llvm::Instruction *> branchMapping;
 
+  /// A mapping between MLIR LLVM dialect call operations and LLVM IR call
+  /// instructions. This allows for adding branch weights after the operations
+  /// have been converted.
+  DenseMap<Operation *, llvm::CallInst *> callMapping;
+
   /// Mapping from an alias scope metadata operation to its LLVM metadata.
   /// This map is populated on module entry.
   DenseMap<Attribute, llvm::MDNode *> aliasScopeMetadataMapping;
index 28e587a..1d32e6e 100644 (file)
@@ -553,10 +553,12 @@ public:
   matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // If branch weights exist, map them to 32-bit integer vector.
-    ElementsAttr branchWeights = nullptr;
+    DenseI32ArrayAttr branchWeights = nullptr;
     if (auto weights = op.getBranchWeights()) {
-      VectorType weightType = VectorType::get(2, rewriter.getI32Type());
-      branchWeights = DenseElementsAttr::get(weightType, weights->getValue());
+      SmallVector<int32_t> weightValues;
+      for (auto weight : weights->getAsRange<IntegerAttr>())
+        weightValues.push_back(weight.getInt());
+      branchWeights = DenseI32ArrayAttr::get(getContext(), weightValues);
     }
 
     rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
index 8eee3b2..f4d9c95 100644 (file)
@@ -310,11 +310,11 @@ void CondBrOp::build(OpBuilder &builder, OperationState &result,
                      Value condition, Block *trueDest, ValueRange trueOperands,
                      Block *falseDest, ValueRange falseOperands,
                      std::optional<std::pair<uint32_t, uint32_t>> weights) {
-  ElementsAttr weightsAttr;
+  DenseI32ArrayAttr weightsAttr;
   if (weights)
     weightsAttr =
-        builder.getI32VectorAttr({static_cast<int32_t>(weights->first),
-                                  static_cast<int32_t>(weights->second)});
+        builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
+                                      static_cast<int32_t>(weights->second)});
 
   build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
         /*loop_annotation=*/{}, trueDest, falseDest);
@@ -330,9 +330,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                      BlockRange caseDestinations,
                      ArrayRef<ValueRange> caseOperands,
                      ArrayRef<int32_t> branchWeights) {
-  ElementsAttr weightsAttr;
+  DenseI32ArrayAttr weightsAttr;
   if (!branchWeights.empty())
-    weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
+    weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
 
   build(builder, result, value, defaultOperands, caseOperands, caseValues,
         weightsAttr, defaultDestination, caseDestinations);
index a6f0ebe..40d8253 100644 (file)
@@ -125,13 +125,11 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
     branchWeights.push_back(branchWeight->getZExtValue());
   }
 
-  return TypeSwitch<Operation *, LogicalResult>(op)
-      .Case<CondBrOp, SwitchOp, CallOp, InvokeOp>([&](auto branchWeightOp) {
-        branchWeightOp.setBranchWeightsAttr(
-            builder.getI32VectorAttr(branchWeights));
-        return success();
-      })
-      .Default([](auto) { return failure(); });
+  if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) {
+    iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
+    return success();
+  }
+  return failure();
 }
 
 /// Searches for the attribute that maps to the given TBAA metadata `node` and
index a044930..8f7c5d8 100644 (file)
@@ -124,21 +124,6 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
   return success();
 }
 
-/// Constructs branch weights metadata if the provided `weights` hold a value,
-/// otherwise returns nullptr.
-static llvm::MDNode *
-convertBranchWeights(std::optional<ElementsAttr> weights,
-                     LLVM::ModuleTranslation &moduleTranslation) {
-  if (!weights)
-    return nullptr;
-  SmallVector<uint32_t> weightValues;
-  weightValues.reserve(weights->size());
-  for (APInt weight : llvm::cast<DenseIntElementsAttr>(*weights))
-    weightValues.push_back(weight.getLimitedValue());
-  return llvm::MDBuilder(moduleTranslation.getLLVMContext())
-      .createBranchWeights(weightValues);
-}
-
 static LogicalResult
 convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
@@ -182,10 +167,6 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
                                                       callOp.getArgOperands()),
                                 operandsRef.front(), operandsRef.drop_front());
     }
-    llvm::MDNode *branchWeights =
-        convertBranchWeights(callOp.getBranchWeights(), moduleTranslation);
-    if (branchWeights)
-      call->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
     moduleTranslation.setAccessGroupsMetadata(callOp, call);
     moduleTranslation.setAliasScopeMetadata(callOp, call);
     moduleTranslation.setTBAAMetadata(callOp, call);
@@ -196,7 +177,10 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
       return success();
     }
     // Check that LLVM call returns void for 0-result functions.
-    return success(call->getType()->isVoidTy());
+    if (!call->getType()->isVoidTy())
+      return failure();
+    moduleTranslation.mapCall(callOp, call);
+    return success();
   }
 
   if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
@@ -274,10 +258,6 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
           moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
           operandsRef.drop_front());
     }
-    llvm::MDNode *branchWeights =
-        convertBranchWeights(invOp.getBranchWeights(), moduleTranslation);
-    if (branchWeights)
-      result->setMetadata(llvm::LLVMContext::MD_prof, branchWeights);
     moduleTranslation.mapBranch(invOp, result);
     // InvokeOp can only have 0 or 1 result
     if (invOp->getNumResults() != 0) {
@@ -314,23 +294,19 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
     return success();
   }
   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
-    llvm::MDNode *branchWeights =
-        convertBranchWeights(condbrOp.getBranchWeights(), moduleTranslation);
     llvm::BranchInst *branch = builder.CreateCondBr(
         moduleTranslation.lookupValue(condbrOp.getOperand(0)),
         moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
-        moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
+        moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)));
     moduleTranslation.mapBranch(&opInst, branch);
     moduleTranslation.setLoopMetadata(&opInst, branch);
     return success();
   }
   if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
-    llvm::MDNode *branchWeights =
-        convertBranchWeights(switchOp.getBranchWeights(), moduleTranslation);
     llvm::SwitchInst *switchInst = builder.CreateSwitch(
         moduleTranslation.lookupValue(switchOp.getValue()),
         moduleTranslation.lookupBlock(switchOp.getDefaultDestination()),
-        switchOp.getCaseDestinations().size(), branchWeights);
+        switchOp.getCaseDestinations().size());
 
     auto *ty = llvm::cast<llvm::IntegerType>(
         moduleTranslation.convertType(switchOp.getValue().getType()));
index d363fb8..cd3a645 100644 (file)
@@ -664,6 +664,10 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
 
     if (failed(convertOperation(op, builder)))
       return failure();
+
+    // Set the branch weight metadata on the translated instruction.
+    if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
+      setBranchWeightsMetadata(iface);
   }
 
   return success();
@@ -1183,6 +1187,19 @@ void ModuleTranslation::setTBAAMetadata(AliasAnalysisOpInterface op,
   inst->setMetadata(llvm::LLVMContext::MD_tbaa, node);
 }
 
+void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
+  DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
+  if (!weightsAttr)
+    return;
+
+  llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
+  assert(inst && "expected the operation to have a mapping to an instruction");
+  SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
+  inst->setMetadata(
+      llvm::LLVMContext::MD_prof,
+      llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));
+}
+
 LogicalResult ModuleTranslation::createTBAAMetadata() {
   llvm::LLVMContext &ctx = llvmModule->getContext();
   llvm::IntegerType *offsetTy = llvm::IntegerType::get(ctx, 64);
index 8c58d59..54ef71f 100644 (file)
@@ -68,7 +68,7 @@ spirv.module Logical GLSL450 {
   }
 
   spirv.func @cond_branch_with_weights(%cond: i1) -> () "None" {
-    // CHECK: llvm.cond_br %{{.*}} weights(dense<[1, 2]> : vector<2xi32>), ^bb1, ^bb2
+    // CHECK: llvm.cond_br %{{.*}} weights([1, 2]), ^bb1, ^bb2
     spirv.BranchConditional %cond [1, 2], ^true, ^false
   // CHECK: ^bb1:
   ^true:
index da4799d..09bbc5a 100644 (file)
@@ -874,7 +874,7 @@ func.func @switch_wrong_number_of_weights(%arg0 : i32) {
   // expected-error@+1 {{expects number of branch weights to match number of successors: 3 vs 2}}
   llvm.switch %arg0 : i32, ^bb1 [
     42: ^bb2(%arg0, %arg0 : i32, i32)
-  ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
+  ] {branch_weights = array<i32: 13, 17, 19>}
 
 ^bb1: // pred: ^bb0
   llvm.return
index 688dd10..cc3b47a 100644 (file)
@@ -4,7 +4,7 @@
 define i64 @cond_br(i1 %arg1, i64 %arg2) {
 entry:
   ; CHECK: llvm.cond_br
-  ; CHECK-SAME: weights(dense<[0, 3]> : vector<2xi32>)
+  ; CHECK-SAME: weights([0, 3])
   br i1 %arg1, label %bb1, label %bb2, !prof !0
 bb1:
   ret i64 %arg2
@@ -19,7 +19,7 @@ bb2:
 ; CHECK-LABEL: @simple_switch(
 define i32 @simple_switch(i32 %arg1) {
   ; CHECK: llvm.switch
-  ; CHECK: {branch_weights = dense<[42, 3, 5]> : vector<3xi32>}
+  ; CHECK: {branch_weights = array<i32: 42, 3, 5>}
   switch i32 %arg1, label %bbd [
     i32 0, label %bb1
     i32 9, label %bb2
@@ -41,7 +41,7 @@ declare void @fn()
 
 ; CHECK-LABEL: @call_branch_weights
 define void @call_branch_weights() {
-  ; CHECK:  llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>}
+  ; CHECK:  llvm.call @fn() {branch_weights = array<i32: 42>}
   call void @fn(), !prof !0
   ret void
 }
@@ -55,7 +55,7 @@ declare i32 @__gxx_personality_v0(...)
 
 ; CHECK-LABEL: @invoke_branch_weights
 define i32 @invoke_branch_weights() personality ptr @__gxx_personality_v0 {
-  ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
+  ; CHECK: llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array<i32: 42, 99>} : () -> ()
   invoke void @foo() to label %bb2 unwind label %bb1, !prof !0
 bb1:
   %1 = landingpad { ptr, i32 } cleanup
index 2500de2..3f97ebd 100644 (file)
@@ -1802,7 +1802,7 @@ llvm.func @foo() {
 // Check that branch weight attributes are exported properly as metadata.
 llvm.func @cond_br_weights(%cond : i1, %arg0 : i32,  %arg1 : i32) -> i32 {
   // CHECK: !prof ![[NODE:[0-9]+]]
-  llvm.cond_br %cond weights(dense<[5, 10]> : vector<2xi32>), ^bb1, ^bb2
+  llvm.cond_br %cond weights([5, 10]), ^bb1, ^bb2
 ^bb1:  // pred: ^bb0
   llvm.return %arg0 : i32
 ^bb2:  // pred: ^bb0
@@ -1818,7 +1818,7 @@ llvm.func @fn()
 // CHECK-LABEL: @call_branch_weights
 llvm.func @call_branch_weights() {
   // CHECK: !prof ![[NODE:[0-9]+]]
-  llvm.call @fn() {branch_weights = dense<42> : vector<1xi32>} : () -> ()
+  llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> ()
   llvm.return
 }
 
@@ -1833,7 +1833,7 @@ llvm.func @__gxx_personality_v0(...) -> i32
 llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} {
   %0 = llvm.mlir.constant(1 : i32) : i32
   // CHECK: !prof ![[NODE:[0-9]+]]
-  llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = dense<[42, 99]> : vector<2xi32>} : () -> ()
+  llvm.invoke @foo() to ^bb2 unwind ^bb1 {branch_weights = array<i32 : 42, 99>} : () -> ()
 ^bb1:  // pred: ^bb0
   %1 = llvm.landingpad cleanup : !llvm.struct<(ptr<i8>, i32)>
   llvm.br ^bb2
@@ -2062,7 +2062,7 @@ llvm.func @switch_weights(%arg0: i32) -> i32 {
   llvm.switch %arg0 : i32, ^bb1(%0 : i32) [
     9: ^bb2(%1, %2 : i32, i32),
     99: ^bb3
-  ] {branch_weights = dense<[13, 17, 19]> : vector<3xi32>}
+  ] {branch_weights = array<i32 : 13, 17, 19>}
 
 ^bb1(%3: i32):  // pred: ^bb0
   llvm.return %3 : i32