From 1977c53b2ae425541a0ef329ca10cc8d5cacd0cd Mon Sep 17 00:00:00 2001 From: Juneyoung Lee Date: Wed, 28 Apr 2021 11:34:45 +0900 Subject: [PATCH] [InstCombine] Fold overflow bit of [u|s]mul.with.overflow in a poison-safe way As discussed in D101191, this patch adds a poison-safe folding of overflow bit check: ``` %Op0 = icmp ne i4 %X, 0 %Agg = call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %Y) %Op1 = extractvalue { i4, i1 } %Agg, 1 %ret = select i1 %Op0, i1 %Op1, i1 false => %Y.fr = freeze %Y %Agg = call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %Y.fr) %Op1 = extractvalue { i4, i1 } %Agg, 1 %ret = %Op1 ``` https://alive2.llvm.org/ce/z/zgPUGT https://alive2.llvm.org/ce/z/h2gZ_6 Note that there are cases where inserting freeze is not necessary: e.g. %Y is `noundef`. In this case, LLVM is already good because `%ret` is already successfully folded into `and`, triggering the pre-existing optimization in InstSimplify: https://godbolt.org/z/v6qena15K Differential Revision: https://reviews.llvm.org/D101423 --- llvm/include/llvm/Analysis/OverflowInstAnalysis.h | 45 +++++++++++ llvm/lib/Analysis/CMakeLists.txt | 1 + llvm/lib/Analysis/InstructionSimplify.cpp | 88 +++------------------- llvm/lib/Analysis/OverflowInstAnalysis.cpp | 71 +++++++++++++++++ .../Transforms/InstCombine/InstCombineSelect.cpp | 13 ++++ .../div-by-0-guard-before-smul_ov-not.ll | 7 +- .../InstCombine/div-by-0-guard-before-smul_ov.ll | 7 +- .../div-by-0-guard-before-umul_ov-not.ll | 7 +- .../InstCombine/div-by-0-guard-before-umul_ov.ll | 7 +- 9 files changed, 151 insertions(+), 95 deletions(-) create mode 100644 llvm/include/llvm/Analysis/OverflowInstAnalysis.h create mode 100644 llvm/lib/Analysis/OverflowInstAnalysis.cpp diff --git a/llvm/include/llvm/Analysis/OverflowInstAnalysis.h b/llvm/include/llvm/Analysis/OverflowInstAnalysis.h new file mode 100644 index 0000000..7523fb9 --- /dev/null +++ b/llvm/include/llvm/Analysis/OverflowInstAnalysis.h @@ -0,0 +1,45 @@ +//===-- 
OverflowInstAnalysis.h - Utils to fold overflow insts ----*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file holds routines to help analyse overflow instructions +// and fold them into constants or other overflow instructions +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_OVERFLOWINSTANALYSIS_H +#define LLVM_ANALYSIS_OVERFLOWINSTANALYSIS_H + +#include "llvm/IR/InstrTypes.h" + +namespace llvm { +class Value; +class Use; + +/// Match one of the patterns up to the select/logic op: +/// %Op0 = icmp ne i4 %X, 0 +/// %Agg = call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %Y) +/// %Op1 = extractvalue { i4, i1 } %Agg, 1 +/// %ret = select i1 %Op0, i1 %Op1, i1 false / %ret = and i1 %Op0, %Op1 +/// +/// %Op0 = icmp eq i4 %X, 0 +/// %Agg = call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %Y) +/// %NotOp1 = extractvalue { i4, i1 } %Agg, 1 +/// %Op1 = xor i1 %NotOp1, true +/// %ret = select i1 %Op0, i1 true, i1 %Op1 / %ret = or i1 %Op0, %Op1 +/// +/// Callers are expected to align that with the operands of the select/logic. +/// IsAnd is set to true if the Op0 and Op1 are used as the first pattern. +/// If Op0 and Op1 match one of the patterns above, return true and fill Y's +/// use. 
+ +bool isCheckForZeroAndMulWithOverflow(Value *Op0, Value *Op1, bool IsAnd, + Use *&Y); +bool isCheckForZeroAndMulWithOverflow(Value *Op0, Value *Op1, bool IsAnd); +} // end namespace llvm + +#endif diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt index 114542d..559947c 100644 --- a/llvm/lib/Analysis/CMakeLists.txt +++ b/llvm/lib/Analysis/CMakeLists.txt @@ -101,6 +101,7 @@ add_llvm_component_library(LLVMAnalysis ObjCARCAnalysisUtils.cpp ObjCARCInstKind.cpp OptimizationRemarkEmitter.cpp + OverflowInstAnalysis.cpp PHITransAddr.cpp PhiValues.cpp PostDominators.cpp diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 73b760f..a916c77 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -26,6 +26,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/OverflowInstAnalysis.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" @@ -1947,77 +1948,6 @@ static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q, return nullptr; } -/// Check that the Op1 is in expected form, i.e.: -/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???) -/// %Op1 = extractvalue { i4, i1 } %Agg, 1 -static bool omitCheckForZeroBeforeMulWithOverflowInternal(Value *Op1, - Value *X) { - auto *Extract = dyn_cast(Op1); - // We should only be extracting the overflow bit. - if (!Extract || !Extract->getIndices().equals(1)) - return false; - Value *Agg = Extract->getAggregateOperand(); - // This should be a multiplication-with-overflow intrinsic. - if (!match(Agg, m_CombineOr(m_Intrinsic(), - m_Intrinsic()))) - return false; - // One of its multipliers should be the value we checked for zero before. 
- if (!match(Agg, m_CombineOr(m_Argument<0>(m_Specific(X)), - m_Argument<1>(m_Specific(X))))) - return false; - return true; -} - -/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some -/// other form of check, e.g. one that was using division; it may have been -/// guarded against division-by-zero. We can drop that check now. -/// Look for: -/// %Op0 = icmp ne i4 %X, 0 -/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???) -/// %Op1 = extractvalue { i4, i1 } %Agg, 1 -/// %??? = and i1 %Op0, %Op1 -/// We can just return %Op1 -static Value *omitCheckForZeroBeforeMulWithOverflow(Value *Op0, Value *Op1) { - ICmpInst::Predicate Pred; - Value *X; - if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) || - Pred != ICmpInst::Predicate::ICMP_NE) - return nullptr; - // Is Op1 in expected form? - if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X)) - return nullptr; - // Can omit 'and', and just return the overflow bit. - return Op1; -} - -/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some -/// other form of check, e.g. one that was using division; it may have been -/// guarded against division-by-zero. We can drop that check now. -/// Look for: -/// %Op0 = icmp eq i4 %X, 0 -/// %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???) -/// %Op1 = extractvalue { i4, i1 } %Agg, 1 -/// %NotOp1 = xor i1 %Op1, true -/// %or = or i1 %Op0, %NotOp1 -/// We can just return %NotOp1 -static Value *omitCheckForZeroBeforeInvertedMulWithOverflow(Value *Op0, - Value *NotOp1) { - ICmpInst::Predicate Pred; - Value *X; - if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) || - Pred != ICmpInst::Predicate::ICMP_EQ) - return nullptr; - // We expect the other hand of an 'or' to be a 'not'. - Value *Op1; - if (!match(NotOp1, m_Not(m_Value(Op1)))) - return nullptr; - // Is Op1 in expected form? 
- if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X)) - return nullptr; - // Can omit 'and', and just return the inverted overflow bit. - return NotOp1; -} - /// Given a bitwise logic op, check if the operands are add/sub with a common /// source value and inverted constant (identity: C - X -> ~(X + ~C)). static Value *simplifyLogicOfAddSub(Value *Op0, Value *Op1, @@ -2102,10 +2032,10 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // If we have a multiplication overflow check that is being 'and'ed with a // check that one of the multipliers is not zero, we can omit the 'and', and // only keep the overflow check. - if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op0, Op1)) - return V; - if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op1, Op0)) - return V; + if (isCheckForZeroAndMulWithOverflow(Op0, Op1, true)) + return Op1; + if (isCheckForZeroAndMulWithOverflow(Op1, Op0, true)) + return Op0; // A & (-A) = A if A is a power of two or zero. if (match(Op0, m_Neg(m_Specific(Op1))) || @@ -2316,10 +2246,10 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // If we have a multiplication overflow check that is being 'and'ed with a // check that one of the multipliers is not zero, we can omit the 'and', and // only keep the overflow check. - if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op0, Op1)) - return V; - if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op1, Op0)) - return V; + if (isCheckForZeroAndMulWithOverflow(Op0, Op1, false)) + return Op1; + if (isCheckForZeroAndMulWithOverflow(Op1, Op0, false)) + return Op0; // Try some generic simplifications for associative operations. 
if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q,
diff --git a/llvm/lib/Analysis/OverflowInstAnalysis.cpp b/llvm/lib/Analysis/OverflowInstAnalysis.cpp
new file mode 100644
index 0000000..2456ae9
--- /dev/null
+++ b/llvm/lib/Analysis/OverflowInstAnalysis.cpp
@@ -0,0 +1,82 @@
+//==-- OverflowInstAnalysis.cpp - Utils to fold overflow insts ----*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file holds routines to help analyse overflow instructions
+// and fold them into constants or other overflow instructions
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/OverflowInstAnalysis.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/PatternMatch.h"
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+/// Match a zero-check on X (Op0) combined with the overflow bit of a
+/// [us]mul.with.overflow that has X as one multiplicand (Op1, or its
+/// inversion when !IsAnd). On success, Y is set to the use of the other
+/// multiplicand in the intrinsic call; on failure Y is left untouched.
+bool llvm::isCheckForZeroAndMulWithOverflow(Value *Op0, Value *Op1, bool IsAnd,
+                                            Use *&Y) {
+  ICmpInst::Predicate Pred;
+  Value *X, *NotOp1;
+  int XIdx = 0;
+  IntrinsicInst *II = nullptr;
+
+  if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())))
+    return false;
+
+  // Match V against the overflow bit of a multiply on X:
+  //   %Agg = call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
+  //   %V = extractvalue { i4, i1 } %Agg, 1
+  auto matchMulOverflowCheck = [X, &II, &XIdx](Value *V) {
+    auto *Extract = dyn_cast<ExtractValueInst>(V);
+    // We should only be extracting the overflow bit.
+    if (!Extract || !Extract->getIndices().equals(1))
+      return false;
+
+    // dyn_cast yields null when the aggregate is not an intrinsic call; reject
+    // that up front rather than handing a null pointer to match().
+    II = dyn_cast<IntrinsicInst>(Extract->getAggregateOperand());
+    if (!II ||
+        !match(II, m_CombineOr(m_Intrinsic<Intrinsic::umul_with_overflow>(),
+                               m_Intrinsic<Intrinsic::smul_with_overflow>())))
+      return false;
+
+    // Remember which multiplicand is X so that the caller gets the other one.
+    if (II->getArgOperand(0) == X)
+      XIdx = 0;
+    else if (II->getArgOperand(1) == X)
+      XIdx = 1;
+    else
+      return false;
+    return true;
+  };
+
+  bool Matched =
+      (IsAnd && Pred == ICmpInst::Predicate::ICMP_NE &&
+       matchMulOverflowCheck(Op1)) ||
+      (!IsAnd && Pred == ICmpInst::Predicate::ICMP_EQ &&
+       match(Op1, m_Not(m_Value(NotOp1))) && matchMulOverflowCheck(NotOp1));
+
+  if (!Matched)
+    return false;
+
+  Y = &II->getArgOperandUse(!XIdx);
+  return true;
+}
+
+/// Convenience overload for callers that only need to know whether the
+/// pattern matches and have no use for the other multiplicand.
+bool llvm::isCheckForZeroAndMulWithOverflow(Value *Op0, Value *Op1,
+                                            bool IsAnd) {
+  Use *Y = nullptr;
+  return isCheckForZeroAndMulWithOverflow(Op0, Op1, IsAnd, Y);
+}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 74d92aa..cb791e5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -18,6 +18,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/OverflowInstAnalysis.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -2697,6 +2698,22 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     if (Value *S = SimplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ,
                                           /* AllowRefinement */ true))
       return replaceOperand(SI, 2, S);
+
+    // A zero-check on one multiplicand guarding the overflow bit of
+    // [us]mul.with.overflow is redundant, but the select also blocks poison
+    // coming from the other multiplicand. Freeze that operand so dropping the
+    // select stays poison-safe.
+    if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) {
+      Use *Y = nullptr;
+      bool IsAnd = match(FalseVal, m_Zero());
+      Value *Op1 = IsAnd ? TrueVal : FalseVal;
+      if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) {
+        auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr");
+        InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser()));
+        replaceUse(*Y, FI);
+        return replaceInstUsesWith(SI, Op1);
+      }
+    }
   }
 
   // Selecting between two integer or vector splat integer constants?
diff --git a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov-not.ll b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov-not.ll
index f5f93b0..170b072 100644
--- a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov-not.ll
+++ b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov-not.ll
@@ -5,12 +5,11 @@ declare { i4, i1 } @llvm.smul.with.overflow.i4(i4, i4) #1
 
 define i1 @t0_umul(i4 %size, i4 %nmemb) {
 ; CHECK-LABEL: @t0_umul(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i4 [[SIZE:%.*]], 0
-; CHECK-NEXT:    [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
+; CHECK-NEXT:    [[NMEMB_FR:%.*]] = freeze i4 [[NMEMB:%.*]]
+; CHECK-NEXT:    [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB_FR]])
 ; CHECK-NEXT:    [[SMUL_OV:%.*]] = extractvalue { i4, i1 } [[SMUL]], 1
 ; CHECK-NEXT:    [[PHITMP:%.*]] = xor i1 [[SMUL_OV]], true
-; CHECK-NEXT:    [[OR:%.*]] = select i1 [[CMP]], i1 true, i1 [[PHITMP]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    ret i1 [[PHITMP]]
 ;
   %cmp = icmp eq i4 %size, 0
   %smul = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 %size, i4 %nmemb)
diff --git a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov.ll b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov.ll
index c2fb2b7..7d4744c 100644
--- a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov.ll
+++ b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov.ll
@@ -20,11 +20,10 @@ define i1 @t0_smul(i4 %size, i4 %nmemb) {
 
 define i1 @t1_commutative(i4 %size, i4 %nmemb) {
 ; CHECK-LABEL: @t1_commutative(
-; CHECK-NEXT: 
[[CMP:%.*]] = icmp ne i4 [[SIZE:%.*]], 0 -; CHECK-NEXT: [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]]) +; CHECK-NEXT: [[NMEMB_FR:%.*]] = freeze i4 [[NMEMB:%.*]] +; CHECK-NEXT: [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB_FR]]) ; CHECK-NEXT: [[SMUL_OV:%.*]] = extractvalue { i4, i1 } [[SMUL]], 1 -; CHECK-NEXT: [[AND:%.*]] = select i1 [[CMP]], i1 [[SMUL_OV]], i1 false -; CHECK-NEXT: ret i1 [[AND]] +; CHECK-NEXT: ret i1 [[SMUL_OV]] ; %cmp = icmp ne i4 %size, 0 %smul = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 %size, i4 %nmemb) diff --git a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov-not.ll b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov-not.ll index 5da6e9b..8468aa7 100644 --- a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov-not.ll +++ b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov-not.ll @@ -5,12 +5,11 @@ declare { i4, i1 } @llvm.umul.with.overflow.i4(i4, i4) #1 define i1 @t0_umul(i4 %size, i4 %nmemb) { ; CHECK-LABEL: @t0_umul( -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i4 [[SIZE:%.*]], 0 -; CHECK-NEXT: [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]]) +; CHECK-NEXT: [[NMEMB_FR:%.*]] = freeze i4 [[NMEMB:%.*]] +; CHECK-NEXT: [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB_FR]]) ; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i4, i1 } [[UMUL]], 1 ; CHECK-NEXT: [[PHITMP:%.*]] = xor i1 [[UMUL_OV]], true -; CHECK-NEXT: [[OR:%.*]] = select i1 [[CMP]], i1 true, i1 [[PHITMP]] -; CHECK-NEXT: ret i1 [[OR]] +; CHECK-NEXT: ret i1 [[PHITMP]] ; %cmp = icmp eq i4 %size, 0 %umul = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 %size, i4 %nmemb) diff --git a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov.ll b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov.ll index 55c5014..19d33f3 100644 --- 
a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov.ll +++ b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov.ll @@ -20,11 +20,10 @@ define i1 @t0_umul(i4 %size, i4 %nmemb) { define i1 @t1_commutative(i4 %size, i4 %nmemb) { ; CHECK-LABEL: @t1_commutative( -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i4 [[SIZE:%.*]], 0 -; CHECK-NEXT: [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]]) +; CHECK-NEXT: [[NMEMB_FR:%.*]] = freeze i4 [[NMEMB:%.*]] +; CHECK-NEXT: [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB_FR]]) ; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i4, i1 } [[UMUL]], 1 -; CHECK-NEXT: [[AND:%.*]] = select i1 [[CMP]], i1 [[UMUL_OV]], i1 false -; CHECK-NEXT: ret i1 [[AND]] +; CHECK-NEXT: ret i1 [[UMUL_OV]] ; %cmp = icmp ne i4 %size, 0 %umul = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 %size, i4 %nmemb) -- 2.7.4