Retry: [BPI] Use a safer constructor to calculate branch probabilities
authorVedant Kumar <vsk@apple.com>
Sat, 17 Dec 2016 01:02:08 +0000 (01:02 +0000)
committerVedant Kumar <vsk@apple.com>
Sat, 17 Dec 2016 01:02:08 +0000 (01:02 +0000)
BPI may trigger signed overflow UB while computing branch probabilities for
cold calls or to unreachables. For example, with our current choice of weights,
we'll crash if there are >= 2^12 branches to an unreachable.

Use a safer BranchProbability constructor which is better at handling fractions
with large denominators.

Changes since the initial commit:
  - Use explicit casts to ensure that multiplication operands are 64-bit
    ints.

rdar://problem/29368161

Differential Revision: https://reviews.llvm.org/D27862

llvm-svn: 290022

llvm/lib/Analysis/BranchProbabilityInfo.cpp
llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp [new file with mode: 0644]
llvm/unittests/Analysis/CMakeLists.txt

index a91ac8d..3eabb78 100644 (file)
@@ -162,12 +162,12 @@ bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
     return true;
   }
 
-  BranchProbability UnreachableProb(UR_TAKEN_WEIGHT,
-                                    (UR_TAKEN_WEIGHT + UR_NONTAKEN_WEIGHT) *
-                                        UnreachableEdges.size());
-  BranchProbability ReachableProb(UR_NONTAKEN_WEIGHT,
-                                  (UR_TAKEN_WEIGHT + UR_NONTAKEN_WEIGHT) *
-                                      ReachableEdges.size());
+  auto UnreachableProb = BranchProbability::getBranchProbability(
+      UR_TAKEN_WEIGHT, (UR_TAKEN_WEIGHT + UR_NONTAKEN_WEIGHT) *
+                           uint64_t(UnreachableEdges.size()));
+  auto ReachableProb = BranchProbability::getBranchProbability(
+      UR_NONTAKEN_WEIGHT,
+      (UR_TAKEN_WEIGHT + UR_NONTAKEN_WEIGHT) * uint64_t(ReachableEdges.size()));
 
   for (unsigned SuccIdx : UnreachableEdges)
     setEdgeProbability(BB, SuccIdx, UnreachableProb);
@@ -300,12 +300,12 @@ bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
     return true;
   }
 
-  BranchProbability ColdProb(CC_TAKEN_WEIGHT,
-                             (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) *
-                                 ColdEdges.size());
-  BranchProbability NormalProb(CC_NONTAKEN_WEIGHT,
-                               (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) *
-                                   NormalEdges.size());
+  auto ColdProb = BranchProbability::getBranchProbability(
+      CC_TAKEN_WEIGHT,
+      (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
+  auto NormalProb = BranchProbability::getBranchProbability(
+      CC_NONTAKEN_WEIGHT,
+      (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
 
   for (unsigned SuccIdx : ColdEdges)
     setEdgeProbability(BB, SuccIdx, ColdProb);
diff --git a/llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp b/llvm/unittests/Analysis/BranchProbabilityInfoTest.cpp
new file mode 100644 (file)
index 0000000..cbf8b50
--- /dev/null
@@ -0,0 +1,88 @@
+//===- BranchProbabilityInfoTest.cpp - BranchProbabilityInfo unit tests ---===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/DataTypes.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+namespace llvm {
+namespace {
+
+struct BranchProbabilityInfoTest : public testing::Test {
+  std::unique_ptr<BranchProbabilityInfo> BPI;
+  std::unique_ptr<DominatorTree> DT;
+  std::unique_ptr<LoopInfo> LI;
+  LLVMContext C;
+
+  BranchProbabilityInfo &buildBPI(Function &F) {
+    DT.reset(new DominatorTree(F));
+    LI.reset(new LoopInfo(*DT));
+    BPI.reset(new BranchProbabilityInfo(F, *LI));
+    return *BPI;
+  }
+
+  std::unique_ptr<Module> makeLLVMModule() {
+    const char *ModuleString = "define void @f() { exit: ret void }\n";
+    SMDiagnostic Err;
+    return parseAssemblyString(ModuleString, Err, C);
+  }
+};
+
+TEST_F(BranchProbabilityInfoTest, StressUnreachableHeuristic) {
+  auto M = makeLLVMModule();
+  Function *F = M->getFunction("f");
+
+  // define void @f() {
+  // entry:
+  //   switch i32 undef, label %exit, [
+  //      i32 0, label %preexit
+  //      ...                   ;;< Add lots of cases to stress the heuristic.
+  //   ]
+  // preexit:
+  //   unreachable
+  // exit:
+  //   ret void
+  // }
+
+  auto *ExitBB = &F->back();
+  auto *EntryBB = BasicBlock::Create(C, "entry", F, /*insertBefore=*/ExitBB);
+
+  auto *PreExitBB =
+      BasicBlock::Create(C, "preexit", F, /*insertBefore=*/ExitBB);
+  new UnreachableInst(C, PreExitBB);
+
+  unsigned NumCases = 4096;
+  auto *I32 = IntegerType::get(C, 32);
+  auto *Undef = UndefValue::get(I32);
+  auto *Switch = SwitchInst::Create(Undef, ExitBB, NumCases, EntryBB);
+  for (unsigned I = 0; I < NumCases; ++I)
+    Switch->addCase(ConstantInt::get(I32, I), PreExitBB);
+
+  BranchProbabilityInfo &BPI = buildBPI(*F);
+
+  // FIXME: This doesn't seem optimal. Since all of the cases handled by the
+  // switch have the *same* destination block ("preexit"), shouldn't it be the
+  // hot one? I'd expect the results to be reversed here...
+  EXPECT_FALSE(BPI.isEdgeHot(EntryBB, PreExitBB));
+  EXPECT_TRUE(BPI.isEdgeHot(EntryBB, ExitBB));
+}
+
+} // end anonymous namespace
+} // end namespace llvm
index 347c3be..33fc5c7 100644 (file)
@@ -8,6 +8,7 @@ set(LLVM_LINK_COMPONENTS
 add_llvm_unittest(AnalysisTests
   AliasAnalysisTest.cpp
   BlockFrequencyInfoTest.cpp
+  BranchProbabilityInfoTest.cpp
   CallGraphTest.cpp
   CFGTest.cpp
   CGSCCPassManagerTest.cpp