From c1f6ce0c7322d47f1bb90169585fa54232231ede Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Mon, 17 Aug 2020 12:53:52 +0100 Subject: [PATCH] [DemandedBits] Improve accuracy of Add propagator The current demand propagator for addition will mark all input bits at and to the right of the alive output bit as alive. But carry won't propagate beyond a bit for which both operands are zero (or one/zero in the case of subtraction), so a more accurate answer is possible given known bits. I derived a propagator by working through truth tables and using a bit-reversed addition to make demand ripple to the right, but I'm not sure how to make a convincing argument for its correctness in the comments yet. Nevertheless, here's a minimal implementation and test to get feedback. This would help in a situation where, for example, four bytes (<128) packed into an int are added with four others SIMD-style but only one of the four results is actually read. Known A: 0_______0_______0_______0_______ Known B: 0_______0_______0_______0_______ AOut: 00000000001000000000000000000000 AB, current: 00000000001111111111111111111111 AB, patch: 00000000001111111000000000000000 Committed on behalf of: @rrika (Erika) Differential Revision: https://reviews.llvm.org/D72423 --- llvm/include/llvm/Analysis/DemandedBits.h | 14 +++++ llvm/lib/Analysis/DemandedBits.cpp | 94 +++++++++++++++++++++++++++++++ llvm/test/Analysis/DemandedBits/add.ll | 40 ++++++------- llvm/unittests/IR/CMakeLists.txt | 1 + llvm/unittests/IR/DemandedBitsTest.cpp | 66 ++++++++++++++++++++++ llvm/unittests/Support/KnownBitsTest.cpp | 30 +--------- llvm/unittests/Support/KnownBitsTest.h | 52 +++++++++++++++++ 7 files changed, 248 insertions(+), 49 deletions(-) create mode 100644 llvm/unittests/IR/DemandedBitsTest.cpp create mode 100644 llvm/unittests/Support/KnownBitsTest.h diff --git a/llvm/include/llvm/Analysis/DemandedBits.h b/llvm/include/llvm/Analysis/DemandedBits.h index 04db3eb..7a8618a 100644 --- 
a/llvm/include/llvm/Analysis/DemandedBits.h +++ b/llvm/include/llvm/Analysis/DemandedBits.h @@ -61,6 +61,20 @@ public: void print(raw_ostream &OS); + /// Compute alive bits of one addition operand from alive output and known + /// operand bits + static APInt determineLiveOperandBitsAdd(unsigned OperandNo, + const APInt &AOut, + const KnownBits &LHS, + const KnownBits &RHS); + + /// Compute alive bits of one subtraction operand from alive output and known + /// operand bits + static APInt determineLiveOperandBitsSub(unsigned OperandNo, + const APInt &AOut, + const KnownBits &LHS, + const KnownBits &RHS); + private: void performAnalysis(); void determineLiveOperandBits(const Instruction *UserI, diff --git a/llvm/lib/Analysis/DemandedBits.cpp b/llvm/lib/Analysis/DemandedBits.cpp index aaee8c2..62e08f3 100644 --- a/llvm/lib/Analysis/DemandedBits.cpp +++ b/llvm/lib/Analysis/DemandedBits.cpp @@ -173,7 +173,21 @@ void DemandedBits::determineLiveOperandBits( } break; case Instruction::Add: + if (AOut.isMask()) { + AB = AOut; + } else { + ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1)); + AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2); + } + break; case Instruction::Sub: + if (AOut.isMask()) { + AB = AOut; + } else { + ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1)); + AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2); + } + break; case Instruction::Mul: // Find the highest live output bit. 
We don't need any more input // bits than that (adds, and thus subtracts, ripple only to the @@ -469,6 +483,86 @@ void DemandedBits::print(raw_ostream &OS) { } } +static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo, + const APInt &AOut, + const KnownBits &LHS, + const KnownBits &RHS, + bool CarryZero, bool CarryOne) { + assert(!(CarryZero && CarryOne) && + "Carry can't be zero and one at the same time"); + + // The following check should be done by the caller, as it also indicates + // that LHS and RHS don't need to be computed. + // + // if (AOut.isMask()) + // return AOut; + + // Boundary bits' carry out is unaffected by their carry in. + APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One); + + // First, the alive carry bits are determined from the alive output bits: + // Let demand ripple to the right but only up to any set bit in Bound. + // AOut = -1---- + // Bound = ----1- + // ACarry&~AOut = --111- + APInt RBound = Bound.reverseBits(); + APInt RAOut = AOut.reverseBits(); + APInt RProp = RAOut + (RAOut | ~RBound); + APInt RACarry = RProp ^ ~RBound; + APInt ACarry = RACarry.reverseBits(); + + // Then, the alive input bits are determined from the alive carry bits: + APInt NeededToMaintainCarryZero; + APInt NeededToMaintainCarryOne; + if (OperandNo == 0) { + NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero; + NeededToMaintainCarryOne = LHS.One | ~RHS.One; + } else { + NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero; + NeededToMaintainCarryOne = RHS.One | ~LHS.One; + } + + // As in computeForAddCarry + APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero; + APInt PossibleSumOne = LHS.One + RHS.One + CarryOne; + + // The below is simplified from + // + // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero); + // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One; + // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne); + // + // APInt NeededToMaintainCarry = + // (CarryKnownZero & NeededToMaintainCarryZero) | + // 
(CarryKnownOne & NeededToMaintainCarryOne) | + // CarryUnknown; + + APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) & + (PossibleSumOne | NeededToMaintainCarryOne); + + APInt AB = AOut | (ACarry & NeededToMaintainCarry); + return AB; +} + +APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo, + const APInt &AOut, + const KnownBits &LHS, + const KnownBits &RHS) { + return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, true, + false); +} + +APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo, + const APInt &AOut, + const KnownBits &LHS, + const KnownBits &RHS) { + KnownBits NRHS; + NRHS.Zero = RHS.One; + NRHS.One = RHS.Zero; + return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, NRHS, false, + true); +} + FunctionPass *llvm::createDemandedBitsWrapperPass() { return new DemandedBitsWrapperPass(); } diff --git a/llvm/test/Analysis/DemandedBits/add.ll b/llvm/test/Analysis/DemandedBits/add.ll index 9203ed1..01673f8 100644 --- a/llvm/test/Analysis/DemandedBits/add.ll +++ b/llvm/test/Analysis/DemandedBits/add.ll @@ -1,22 +1,22 @@ -; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s -; RUN: opt -S -disable-output -passes="print" < %s 2>&1 | FileCheck %s - -; CHECK-DAG: DemandedBits: 0x1f for %1 = and i32 %a, 9 -; CHECK-DAG: DemandedBits: 0x1f for %2 = and i32 %b, 9 -; CHECK-DAG: DemandedBits: 0x1f for %3 = and i32 %c, 13 -; CHECK-DAG: DemandedBits: 0x1f for %4 = and i32 %d, 4 -; CHECK-DAG: DemandedBits: 0x1f for %5 = or i32 %2, %3 -; CHECK-DAG: DemandedBits: 0x1f for %6 = or i32 %4, %5 +; RUN: opt -S -demanded-bits -analyze < %s | FileCheck %s +; RUN: opt -S -disable-output -passes="print" < %s 2>&1 | FileCheck %s + +; CHECK-DAG: DemandedBits: 0x1e for %1 = and i32 %a, 9 +; CHECK-DAG: DemandedBits: 0x1a for %2 = and i32 %b, 9 +; CHECK-DAG: DemandedBits: 0x1a for %3 = and i32 %c, 13 +; CHECK-DAG: DemandedBits: 0x1a for %4 = and i32 %d, 4 +; CHECK-DAG: DemandedBits: 0x1a for %5 = or i32 %2, 
%3 +; CHECK-DAG: DemandedBits: 0x1a for %6 = or i32 %4, %5 ; CHECK-DAG: DemandedBits: 0x10 for %7 = add i32 %1, %6 ; CHECK-DAG: DemandedBits: 0xffffffff for %8 = and i32 %7, 16 -define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) { - %1 = and i32 %a, 9 - %2 = and i32 %b, 9 - %3 = and i32 %c, 13 - %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero - %5 = or i32 %2, %3 - %6 = or i32 %4, %5 - %7 = add i32 %1, %6 - %8 = and i32 %7, 16 - ret i32 %8 -} \ No newline at end of file +define i32 @test_add(i32 %a, i32 %b, i32 %c, i32 %d) { + %1 = and i32 %a, 9 + %2 = and i32 %b, 9 + %3 = and i32 %c, 13 + %4 = and i32 %d, 4 ; no bit of %d alive, %4 simplifies to zero + %5 = or i32 %2, %3 + %6 = or i32 %4, %5 + %7 = add i32 %1, %6 + %8 = and i32 %7, 16 + ret i32 %8 +} diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt index 4634bf8..c4386fe 100644 --- a/llvm/unittests/IR/CMakeLists.txt +++ b/llvm/unittests/IR/CMakeLists.txt @@ -18,6 +18,7 @@ add_llvm_unittest(IRTests DataLayoutTest.cpp DebugInfoTest.cpp DebugTypeODRUniquingTest.cpp + DemandedBitsTest.cpp DominatorTreeTest.cpp DominatorTreeBatchUpdatesTest.cpp FunctionTest.cpp diff --git a/llvm/unittests/IR/DemandedBitsTest.cpp b/llvm/unittests/IR/DemandedBitsTest.cpp new file mode 100644 index 0000000..4d15e81 --- /dev/null +++ b/llvm/unittests/IR/DemandedBitsTest.cpp @@ -0,0 +1,66 @@ +//===- DemandedBitsTest.cpp - DemandedBits tests --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DemandedBits.h" +#include "../Support/KnownBitsTest.h" +#include "llvm/Support/KnownBits.h" +#include "gtest/gtest.h" + +using namespace llvm; + +namespace { + +template +static void TestBinOpExhaustive(Fn1 PropagateFn, Fn2 EvalFn) { + unsigned Bits = 4; + unsigned Max = 1 << Bits; + ForeachKnownBits(Bits, [&](const KnownBits &Known1) { + ForeachKnownBits(Bits, [&](const KnownBits &Known2) { + for (unsigned AOut_ = 0; AOut_ < Max; AOut_++) { + APInt AOut(Bits, AOut_); + APInt AB1 = PropagateFn(0, AOut, Known1, Known2); + APInt AB2 = PropagateFn(1, AOut, Known1, Known2); + { + // If the propagator claims that certain known bits + // didn't matter, check it doesn't change its mind + // when they become unknown. + KnownBits Known1Redacted; + KnownBits Known2Redacted; + Known1Redacted.Zero = Known1.Zero & AB1; + Known1Redacted.One = Known1.One & AB1; + Known2Redacted.Zero = Known2.Zero & AB2; + Known2Redacted.One = Known2.One & AB2; + + APInt AB1R = PropagateFn(0, AOut, Known1Redacted, Known2Redacted); + APInt AB2R = PropagateFn(1, AOut, Known1Redacted, Known2Redacted); + EXPECT_EQ(AB1, AB1R); + EXPECT_EQ(AB2, AB2R); + } + ForeachNumInKnownBits(Known1, [&](APInt Value1) { + ForeachNumInKnownBits(Known2, [&](APInt Value2) { + APInt ReferenceResult = EvalFn((Value1 & AB1), (Value2 & AB2)); + APInt Result = EvalFn(Value1, Value2); + EXPECT_EQ(Result & AOut, ReferenceResult & AOut); + }); + }); + } + }); + }); +} + +TEST(DemandedBitsTest, Add) { + TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsAdd, + [](APInt N1, APInt N2) -> APInt { return N1 + N2; }); +} + +TEST(DemandedBitsTest, Sub) { + TestBinOpExhaustive(DemandedBits::determineLiveOperandBitsSub, + [](APInt N1, APInt N2) -> APInt { return N1 - N2; }); +} + +} // anonymous namespace diff --git a/llvm/unittests/Support/KnownBitsTest.cpp 
b/llvm/unittests/Support/KnownBitsTest.cpp index bfd8eb2..694e5c4 100644 --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -11,41 +11,13 @@ //===----------------------------------------------------------------------===// #include "llvm/Support/KnownBits.h" +#include "KnownBitsTest.h" #include "gtest/gtest.h" using namespace llvm; namespace { -template -void ForeachKnownBits(unsigned Bits, FnTy Fn) { - unsigned Max = 1 << Bits; - KnownBits Known(Bits); - for (unsigned Zero = 0; Zero < Max; ++Zero) { - for (unsigned One = 0; One < Max; ++One) { - Known.Zero = Zero; - Known.One = One; - if (Known.hasConflict()) - continue; - - Fn(Known); - } - } -} - -template -void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) { - unsigned Bits = Known.getBitWidth(); - unsigned Max = 1 << Bits; - for (unsigned N = 0; N < Max; ++N) { - APInt Num(Bits, N); - if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0) - continue; - - Fn(Num); - } -} - TEST(KnownBitsTest, AddCarryExhaustive) { unsigned Bits = 4; ForeachKnownBits(Bits, [&](const KnownBits &Known1) { diff --git a/llvm/unittests/Support/KnownBitsTest.h b/llvm/unittests/Support/KnownBitsTest.h new file mode 100644 index 0000000..bc29189 --- /dev/null +++ b/llvm/unittests/Support/KnownBitsTest.h @@ -0,0 +1,52 @@ +//===- llvm/unittest/Support/KnownBitsTest.h - KnownBits tests ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements helpers for KnownBits and DemandedBits unit tests. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H +#define LLVM_UNITTESTS_SUPPORT_KNOWNBITSTEST_H + +#include "llvm/Support/KnownBits.h" + +namespace { + +using namespace llvm; + +template void ForeachKnownBits(unsigned Bits, FnTy Fn) { + unsigned Max = 1 << Bits; + KnownBits Known(Bits); + for (unsigned Zero = 0; Zero < Max; ++Zero) { + for (unsigned One = 0; One < Max; ++One) { + Known.Zero = Zero; + Known.One = One; + if (Known.hasConflict()) + continue; + + Fn(Known); + } + } +} + +template +void ForeachNumInKnownBits(const KnownBits &Known, FnTy Fn) { + unsigned Bits = Known.getBitWidth(); + unsigned Max = 1 << Bits; + for (unsigned N = 0; N < Max; ++N) { + APInt Num(Bits, N); + if ((Num & Known.Zero) != 0 || (~Num & Known.One) != 0) + continue; + + Fn(Num); + } +} + +} // end anonymous namespace + +#endif -- 2.7.4