From bfe2f5d38bb14bf7ce4f44d3de558fbc076bdc1a Mon Sep 17 00:00:00 2001 From: Noah Goldstein Date: Tue, 18 Apr 2023 16:34:02 -0500 Subject: [PATCH] [InstCombine] Fix buggy `(mul X, Y)` -> `(shl X, Log2(Y))` transform PR62175 Bug was because we recognized patterns like `(shl 4, Z)` as a power of 2 we could take Log2 of (`2 + Z`), but doing `(shl X, (2 + Z))` can cause a poison shift. https://alive2.llvm.org/ce/z/yuJm_k The fix is to verify that `Log2(Y)` will be a non-poisonous shift amount. We can do this with: `nsw` flag: - https://alive2.llvm.org/ce/z/yyyJBr - https://alive2.llvm.org/ce/z/YgubD_ `nuw` flag: - https://alive2.llvm.org/ce/z/-4mpyV - https://alive2.llvm.org/ce/z/a6ik6r Prove `Y != 0`: - https://alive2.llvm.org/ce/z/ced4su - https://alive2.llvm.org/ce/z/X-JJHb Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D148609 --- .../InstCombine/InstCombineMulDivRem.cpp | 55 ++++++++++++++-------- llvm/test/Transforms/InstCombine/div-shift.ll | 25 ++++++++++ llvm/test/Transforms/InstCombine/mul-pow2.ll | 34 +++++++++++++ 3 files changed, 95 insertions(+), 19 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 9d0e171..19f2d2f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -186,7 +186,7 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, } static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, - bool DoFold); + bool AssumeNonZero, bool DoFold); Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -486,15 +486,19 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { // (shl Op1, Log2(Op0)) // if Log2(Op1) folds away -> // (shl Op0, Log2(Op1)) - if (takeLog2(Builder, Op0, /*Depth*/ 0, /*DoFold*/ false)) { - Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*DoFold*/ true); + if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ true); BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res); // We can only propegate nuw flag. Shl->setHasNoUnsignedWrap(HasNUW); return Shl; } - if (takeLog2(Builder, Op1, /*Depth*/ 0, /*DoFold*/ false)) { - Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*DoFold*/ true); + if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ true); BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res); // We can only propegate nuw flag. Shl->setHasNoUnsignedWrap(HasNUW); @@ -1181,7 +1185,7 @@ static const unsigned MaxDepth = 6; // actual instructions, otherwise return a non-null dummy value. Return nullptr // on failure. static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, - bool DoFold) { + bool AssumeNonZero, bool DoFold) { auto IfFold = [DoFold](function_ref Fn) { if (!DoFold) return reinterpret_cast(-1); @@ -1207,14 +1211,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // FIXME: Require one use? Value *X, *Y; if (match(Op, m_ZExt(m_Value(X)))) - if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) + if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); }); // log2(X << Y) -> log2(X) + Y // FIXME: Require one use unless X is 1? - if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) - if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) - return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); + if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) { + auto *BO = cast(Op); + // nuw will be set if the `shl` is trivially non-zero. + if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap()) + if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) + return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); + } // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y) // FIXME: missed optimization: if one of the hands of select is/contains @@ -1222,8 +1230,10 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // FIXME: can both hands contain undef? // FIXME: Require one use? if (SelectInst *SI = dyn_cast(Op)) - if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold)) - if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold)) + if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, + AssumeNonZero, DoFold)) + if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, + AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateSelect(SI->getOperand(0), LogX, LogY); }); @@ -1231,13 +1241,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // log2(umin(X, Y)) -> umin(log2(X), log2(Y)) // log2(umax(X, Y)) -> umax(log2(X), log2(Y)) auto *MinMax = dyn_cast(Op); - if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) - if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold)) - if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold)) + if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) { + // Use AssumeNonZero as false here. Otherwise we can hit case where + // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow). + if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, + /*AssumeNonZero*/ false, DoFold)) + if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, + /*AssumeNonZero*/ false, DoFold)) return IfFold([&]() { - return Builder.CreateBinaryIntrinsic( - MinMax->getIntrinsicID(), LogX, LogY); + return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX, + LogY); }); + } return nullptr; } @@ -1357,8 +1372,10 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { } // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away. - if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) { - Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true); + if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, + /*AssumeNonZero*/ true, /*DoFold*/ true); return replaceInstUsesWith( I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact())); } diff --git a/llvm/test/Transforms/InstCombine/div-shift.ll b/llvm/test/Transforms/InstCombine/div-shift.ll index efdaa0c..76c5328 100644 --- a/llvm/test/Transforms/InstCombine/div-shift.ll +++ b/llvm/test/Transforms/InstCombine/div-shift.ll @@ -1000,3 +1000,28 @@ define i8 @udiv_shl_nuw_divisor(i8 %x, i8 %y, i8 %z) { %d = udiv i8 %x, %s ret i8 %d } + +define i8 @udiv_fail_shl_overflow(i8 %x, i8 %y) { +; CHECK-LABEL: @udiv_fail_shl_overflow( +; CHECK-NEXT: [[SHL:%.*]] = shl i8 2, [[Y:%.*]] +; CHECK-NEXT: [[MIN:%.*]] = call i8 @llvm.umax.i8(i8 [[SHL]], i8 1) +; CHECK-NEXT: [[MUL:%.*]] = udiv i8 [[X:%.*]], [[MIN]] +; CHECK-NEXT: ret i8 [[MUL]] +; + %shl = shl i8 2, %y + %min = call i8 @llvm.umax.i8(i8 %shl, i8 1) + %mul = udiv i8 %x, %min + ret i8 %mul +} + +define i8 @udiv_shl_no_overflow(i8 %x, i8 %y) { +; CHECK-LABEL: @udiv_shl_no_overflow( +; CHECK-NEXT: [[TMP1:%.*]] = add i8 [[Y:%.*]], 1 +; CHECK-NEXT: [[MUL1:%.*]] = lshr i8 [[X:%.*]], [[TMP1]] +; CHECK-NEXT: ret i8 [[MUL1]] +; + %shl = shl nuw i8 2, %y + %min = call i8 @llvm.umax.i8(i8 %shl, i8 1) + %mul = udiv i8 %x, %min + ret i8 %mul +} diff --git a/llvm/test/Transforms/InstCombine/mul-pow2.ll b/llvm/test/Transforms/InstCombine/mul-pow2.ll index 5617c74..c16fd71 100644 --- a/llvm/test/Transforms/InstCombine/mul-pow2.ll +++ b/llvm/test/Transforms/InstCombine/mul-pow2.ll @@ -102,3 +102,37 @@ define <2 x i8> @mul_x_selectp2_vec(<2 x i8> %xx, i1 %c) { %r = mul <2 x i8> %x, %s ret <2 x i8> %r } + + +define i8 @shl_add_log_may_cause_poison_pr62175_fail(i8 %x, i8 %y) { +; CHECK-LABEL: @shl_add_log_may_cause_poison_pr62175_fail( +; CHECK-NEXT: [[SHL:%.*]] = shl i8 4, [[X:%.*]] +; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[SHL]], [[Y:%.*]] +; CHECK-NEXT: ret i8 [[MUL]] +; + %shl = shl i8 4, %x + %mul = mul i8 %y, %shl + ret i8 %mul +} + +define i8 @shl_add_log_may_cause_poison_pr62175_with_nuw(i8 %x, i8 %y) { +; CHECK-LABEL: @shl_add_log_may_cause_poison_pr62175_with_nuw( +; CHECK-NEXT: [[TMP1:%.*]] = add i8 [[X:%.*]], 2 +; CHECK-NEXT: [[MUL:%.*]] = shl i8 [[Y:%.*]], [[TMP1]] +; CHECK-NEXT: ret i8 [[MUL]] +; + %shl = shl nuw i8 4, %x + %mul = mul i8 %y, %shl + ret i8 %mul +} + +define i8 @shl_add_log_may_cause_poison_pr62175_with_nsw(i8 %x, i8 %y) { +; CHECK-LABEL: @shl_add_log_may_cause_poison_pr62175_with_nsw( +; CHECK-NEXT: [[TMP1:%.*]] = add i8 [[X:%.*]], 2 +; CHECK-NEXT: [[MUL:%.*]] = shl i8 [[Y:%.*]], [[TMP1]] +; CHECK-NEXT: ret i8 [[MUL]] +; + %shl = shl nsw i8 4, %x + %mul = mul i8 %y, %shl + ret i8 %mul +} -- 2.7.4