Prevent unsigned overflow.
authorRichard Trieu <rtrieu@google.com>
Wed, 5 Sep 2018 04:19:15 +0000 (04:19 +0000)
committerRichard Trieu <rtrieu@google.com>
Wed, 5 Sep 2018 04:19:15 +0000 (04:19 +0000)
The sum of the weights is caculated in an APInt, which has a width smaller than
64.  In certain cases, the sum of the widths would overflow when calculations
are done inside an APInt, but would not if done with uint64_t.  Since the
values will be passed as uint64_t in the function call anyways, do all the math
in 64 bits.  Also added an assert in case the probabilities overflow 64 bits.

llvm-svn: 341444

llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp

index 4da8a71..80e16e7 100644 (file)
@@ -614,13 +614,15 @@ static bool CheckMDProf(MDNode *MD, BranchProbability &TrueProb,
   ConstantInt *FalseWeight = mdconst::extract<ConstantInt>(MD->getOperand(2));
   if (!TrueWeight || !FalseWeight)
     return false;
-  APInt TrueWt = TrueWeight->getValue();
-  APInt FalseWt = FalseWeight->getValue();
-  APInt SumWt = TrueWt + FalseWt;
-  TrueProb = BranchProbability::getBranchProbability(TrueWt.getZExtValue(),
-                                                     SumWt.getZExtValue());
-  FalseProb = BranchProbability::getBranchProbability(FalseWt.getZExtValue(),
-                                                      SumWt.getZExtValue());
+  uint64_t TrueWt = TrueWeight->getValue().getZExtValue();
+  uint64_t FalseWt = FalseWeight->getValue().getZExtValue();
+  uint64_t SumWt = TrueWt + FalseWt;
+
+  assert(SumWt >= TrueWt && SumWt >= FalseWt &&
+         "Overflow calculating branch probabilities.");
+
+  TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt);
+  FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt);
   return true;
 }