[DemandedBits] Improve accuracy of Add propagator
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 17 Aug 2020 11:53:52 +0000 (12:53 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 17 Aug 2020 11:54:09 +0000 (12:54 +0100)
The current demand propagator for addition will mark all input bits at and right of the alive output bit as alive. But carry won't propagate beyond a bit for which both operands are zero (or one/zero in the case of subtraction) so a more accurate answer is possible given known bits.

I derived a propagator by working through truth tables and using a bit-reversed addition to make demand ripple to the right, but I'm not sure how to make a convincing argument for its correctness in the comments yet. Nevertheless, here's a minimal implementation and test to get feedback.

This would help in a situation where, for example, four bytes (<128) packed into an int are added with four others SIMD-style but only one of the four results is actually read.

Known A:     0_______0_______0_______0_______
Known B:     0_______0_______0_______0_______
AOut:        00000000001000000000000000000000
AB, current: 00000000001111111111111111111111
AB, patch:   00000000001111111000000000000000

Committed on behalf of: @rrika (Erika)

Differential Revision: https://reviews.llvm.org/D72423

llvm/include/llvm/Analysis/DemandedBits.h
llvm/lib/Analysis/DemandedBits.cpp
llvm/test/Analysis/DemandedBits/add.ll
llvm/unittests/IR/CMakeLists.txt
llvm/unittests/IR/DemandedBitsTest.cpp [new file with mode: 0644]
llvm/unittests/Support/KnownBitsTest.cpp
llvm/unittests/Support/KnownBitsTest.h [new file with mode: 0644]

index 04db3eb..7a8618a 100644 (file)
@@ -61,6 +61,20 @@ public:
 
   void print(raw_ostream &OS);
 
+  /// Compute alive bits of one addition operand from alive output and known
+  /// operand bits
+  static APInt determineLiveOperandBitsAdd(unsigned OperandNo,
+                                           const APInt &AOut,
+                                           const KnownBits &LHS,
+                                           const KnownBits &RHS);
+
+  /// Compute alive bits of one subtraction operand from alive output and known
+  /// operand bits
+  static APInt determineLiveOperandBitsSub(unsigned OperandNo,
+                                           const APInt &AOut,
+                                           const KnownBits &LHS,
+                                           const KnownBits &RHS);
+
 private:
   void performAnalysis();
   void determineLiveOperandBits(const Instruction *UserI,
index aaee8c2..62e08f3 100644 (file)
@@ -173,7 +173,21 @@ void DemandedBits::determineLiveOperandBits(
       }
     break;
   case Instruction::Add:
+    if (AOut.isMask()) {
+      AB = AOut;
+    } else {
+      ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
+      AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
+    }
+    break;
   case Instruction::Sub:
+    if (AOut.isMask()) {
+      AB = AOut;
+    } else {
+      ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
+      AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
+    }
+    break;
   case Instruction::Mul:
     // Find the highest live output bit. We don't need any more input
     // bits than that (adds, and thus subtracts, ripple only to the
@@ -469,6 +483,86 @@ void DemandedBits::print(raw_ostream &OS) {
   }
 }
 
+static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
+                                              const APInt &AOut,
+                                              const KnownBits &LHS,
+                                              const KnownBits &RHS,
+                                              bool CarryZero, bool CarryOne) {
+  assert(!(CarryZero && CarryOne) &&
+         "Carry can't be zero and one at the same time");
+
+  // The following check should be done by the caller, as it also indicates
+  // that LHS and RHS don't need to be computed.
+  //
+  // if (AOut.isMask())
+  //   return AOut;
+
+  // Boundary bits' carry out is unaffected by their carry in.
+  APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);
+
+  // First, the alive carry bits are determined from the alive output bits:
+  // Let demand ripple to the right but only up to any set bit in Bound.
+  //   AOut         = -1----
+  //   Bound        = ----1-
+  //   ACarry&~AOut = --111-
+  APInt RBound = Bound.reverseBits();
+  APInt RAOut = AOut.reverseBits();
+  APInt RProp = RAOut + (RAOut | ~RBound);
+  APInt RACarry = RProp ^ ~RBound;
+  APInt ACarry = RACarry.reverseBits();
+
+  // Then, the alive input bits are determined from the alive carry bits:
+  APInt NeededToMaintainCarryZero;
+  APInt NeededToMaintainCarryOne;
+  if (OperandNo == 0) {
+    NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero;
+    NeededToMaintainCarryOne = LHS.One | ~RHS.One;
+  } else {
+    NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero;
+    NeededToMaintainCarryOne = RHS.One | ~LHS.One;
+  }
+
+  // As in computeForAddCarry
+  APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
+  APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;
+
+  // The below is simplified from
+  //
+  // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
+  // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
+  // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
+  //
+  // APInt NeededToMaintainCarry =
+  //   (CarryKnownZero & NeededToMaintainCarryZero) |
+  //   (CarryKnownOne  & NeededToMaintainCarryOne) |
+  //   CarryUnknown;
+
+  APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) &
+                                (PossibleSumOne | NeededToMaintainCarryOne);
+
+  APInt AB = AOut | (ACarry & NeededToMaintainCarry);
+  return AB;
+}
+
+APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
+                                                const APInt &AOut,
+                                                const KnownBits &LHS,
+                                                const KnownBits &RHS) {
+  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, true,
+                                          false);
+}
+
+APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
+                                                const APInt &AOut,
+                                                const KnownBits &LHS,
+                                                const KnownBits &RHS) {
+  KnownBits NRHS;
+  NRHS.Zero = RHS.One;
+  NRHS.One = RHS.Zero;
+  return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, NRHS, false,
+                                          true);
+}
+
 FunctionPass *llvm::createDemandedBitsWrapperPass() {
   return new DemandedBitsWrapperPass();
 }
index 9203ed1..01673f8 100644 (file)
@@ -1,22 +1,22 @@
-; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s\r
-; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s\r
-\r
-; CHECK-DAG: DemandedBits: 0x1f for   %1 = and i32 %a, 9\r
-; CHECK-DAG: DemandedBits: 0x1f for   %2 = and i32 %b, 9
-; CHECK-DAG: DemandedBits: 0x1f for   %3 = and i32 %c, 13\r
-; CHECK-DAG: DemandedBits: 0x1f for   %4 = and i32 %d, 4
-; CHECK-DAG: DemandedBits: 0x1f for   %5 = or i32 %2, %3
-; CHECK-DAG: DemandedBits: 0x1f for   %6 = or i32 %4, %5\r
+; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s
+; RUN: opt -S -disable-output -passes="print<demanded-bits>" < %s 2>&1 | FileCheck %s
+
+; CHECK-DAG: DemandedBits: 0x1e for   %1 = and i32 %a, 9
+; CHECK-DAG: DemandedBits: 0x1a for   %2 = and i32 %b, 9
+; CHECK-DAG: DemandedBits: 0x1a for   %3 = and i32 %c, 13
+; CHECK-DAG: DemandedBits: 0x1a for   %4 = and i32 %d, 4
+; CHECK-DAG: DemandedBits: 0x1a for   %5 = or i32 %2, %3
+; CHECK-DAG: DemandedBits: 0x1a for   %6 = or i32 %4, %5
 ; CHECK-DAG: DemandedBits: 0x10 for   %7 = add i32 %1, %6
 ; CHECK-DAG: DemandedBits: 0xffffffff for   %8 = and i32 %7, 16
-define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {\r
-  %1 = and i32 %a, 9\r
-  %2 = and i32 %b, 9\r
-  %3 = and i32 %c, 13\r
-  %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero\r
-  %5 = or i32 %2, %3\r
-  %6 = or i32 %4, %5\r
-  %7 = add i32 %1, %6\r
-  %8 = and i32 %7, 16\r
-  ret i32 %8\r
-}
\ No newline at end of file
+define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) {
+  %1 = and i32 %a, 9
+  %2 = and i32 %b, 9
+  %3 = and i32 %c, 13
+  %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero
+  %5 = or i32 %2, %3
+  %6 = or i32 %4, %5
+  %7 = add i32 %1, %6
+  %8 = and i32 %7, 16
+  ret i32 %8
+}
index 4634bf8..c4386fe 100644 (file)
@@ -18,6 +18,7 @@ add_llvm_unittest(IRTests
   DataLayoutTest.cpp
   DebugInfoTest.cpp
   DebugTypeODRUniquingTest.cpp
+  DemandedBitsTest.cpp
   DominatorTreeTest.cpp
   DominatorTreeBatchUpdatesTest.cpp
   FunctionTest.cpp
diff --git a/llvm/unittests/IR/DemandedBitsTest.cpp b/llvm/unittests/IR/DemandedBitsTest.cpp
new file mode 100644 (file)
index 0000000..4d15e81
--- /dev/null
@@ -0,0 +1,66 @@
+//===- DemandedBitsTest.cpp - DemandedBits tests --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/DemandedBits.h"
+#include "../Support/KnownBitsTest.h"
+#include "llvm/Support/KnownBits.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+template <typename Fn1, typename Fn2>
+static void TestBinOpExhaustive(Fn1 PropagateFn, Fn2 EvalFn) {
+  unsigned Bits = 4;
+  unsigned Max = 1 << Bits;
+  ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
+    ForeachKnownBits(Bits, [&](const KnownBits &Known2) {
+      for (unsigned AOut_ = 0; AOut_ < Max; AOut_++) {
+        APInt AOut(Bits, AOut_);
+        APInt AB1 = PropagateFn(0, AOut, Known1, Known2);
+        APInt AB2 = PropagateFn(1, AOut, Known1, Known2);
+        {
+          // If the propagator claims that certain known bits
+          // didn't matter, check it doesn't change its mind
+          // when they become unknown.
+          KnownBits Known1Redacted;
+          KnownBits Known2Redacted;
+          Known1Redacted.Zero = Known1.Zero & AB1;
+          Known1Redacted.One = Known1.One & AB1;
+          Known2Redacted.Zero = Known2.Zero & AB2;
+          Known2Redacted.One = Known2.One & AB2;
+
+          APInt AB1R = PropagateFn(0, AOut, Known1Redacted, Known2Redacted);
+          APInt AB2R = PropagateFn(1, AOut, Known1Redacted, Known2Redacted);
+          EXPECT_EQ(AB1, AB1R);
+          EXPECT_EQ(AB2, AB2R);
+        }
+        ForeachNumInKnownBits(Known1, [&](APInt Value1) {
+          ForeachNumInKnownBits(Known2, [&](APInt Value2) {
+            APInt ReferenceResult = EvalFn((Value1 & AB1), (Value2 & AB2));
+            APInt Result = EvalFn(Value1, Value2);
+            EXPECT_EQ(Result & AOut, ReferenceResult & AOut);
+          });
+        });
+      }
+    });
+  });
+}
+
+TEST(DemandedBitsTest, Add) {
+  TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsAdd,
+                      [](APInt N1, APInt N2) -> APInt { return N1 + N2; });
+}
+
+TEST(DemandedBitsTest, Sub) {
+  TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsSub,
+                      [](APInt N1, APInt N2) -> APInt { return N1 - N2; });
+}
+
+} // anonymous namespace
index bfd8eb2..694e5c4 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Support/KnownBits.h"
+#include "KnownBitsTest.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
 
 namespace {
 
-template<typename FnTy>
-void ForeachKnownBits(unsigned Bits, FnTy Fn) {
-  unsigned Max = 1 << Bits;
-  KnownBits Known(Bits);
-  for (unsigned Zero = 0; Zero < Max; ++Zero) {
-    for (unsigned One = 0; One < Max; ++One) {
-      Known.Zero = Zero;
-      Known.One = One;
-      if (Known.hasConflict())
-        continue;
-
-      Fn(Known);
-    }
-  }
-}
-
-template<typename FnTy>
-void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
-  unsigned Bits = Known.getBitWidth();
-  unsigned Max = 1 << Bits;
-  for (unsigned N = 0; N < Max; ++N) {
-    APInt Num(Bits, N);
-    if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0)
-      continue;
-
-    Fn(Num);
-  }
-}
-
 TEST(KnownBitsTest, AddCarryExhaustive) {
   unsigned Bits = 4;
   ForeachKnownBits(Bits, [&](const KnownBits &Known1) {
diff --git a/llvm/unittests/Support/KnownBitsTest.h b/llvm/unittests/Support/KnownBitsTest.h
new file mode 100644 (file)
index 0000000..bc29189
--- /dev/null
@@ -0,0 +1,52 @@
+//===- llvm/unittest/Support/KnownBitsTest.h - KnownBits tests ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements helpers for KnownBits and DemandedBits unit tests.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
+#define LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H
+
+#include "llvm/Support/KnownBits.h"
+
+namespace {
+
+using namespace llvm;
+
+template <typename FnTy> void ForeachKnownBits(unsigned Bits, FnTy Fn) {
+  unsigned Max = 1 << Bits;
+  KnownBits Known(Bits);
+  for (unsigned Zero = 0; Zero < Max; ++Zero) {
+    for (unsigned One = 0; One < Max; ++One) {
+      Known.Zero = Zero;
+      Known.One = One;
+      if (Known.hasConflict())
+        continue;
+
+      Fn(Known);
+    }
+  }
+}
+
+template <typename FnTy>
+void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) {
+  unsigned Bits = Known.getBitWidth();
+  unsigned Max = 1 << Bits;
+  for (unsigned N = 0; N < Max; ++N) {
+    APInt Num(Bits, N);
+    if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0)
+      continue;
+
+    Fn(Num);
+  }
+}
+
+} // end anonymous namespace
+
+#endif