[AggressiveInstCombine] convert sqrt libcalls with "nnan" to sqrt intrinsics
authorSanjay Patel <spatel@rotateright.com>
Tue, 26 Jul 2022 19:31:12 +0000 (15:31 -0400)
committerSanjay Patel <spatel@rotateright.com>
Tue, 26 Jul 2022 19:50:14 +0000 (15:50 -0400)
This is an alternate to D129155 that uses TTI.haveFastSqrt() to avoid a
potential miscompile for programs with reads of errno. Moving the transform
to AggressiveInstCombine provides access to TTI.

If a sqrt call has "nnan", that implies that the input argument is never
negative because sqrt of {negative number} --> NAN.
If the argument is never negative and the call can be lowered without a
libcall, then we can assume that errno accesses are unchanged after lowering,
so the call can be translated to the LLVM intrinsic (which is expected to
become inline code).

This affects codegen for targets like x86 that have sqrt instructions, but
still have to conservatively assume that a libcall may be needed to set
errno as shown in issue #52620 and issue #56383.

This patch won't solve those examples - we will need to extend this to use
CannotBeOrderedLessThanZero or similar, enhance that analysis for new
operators, and/or deal with llvm.assume too.

Differential Revision: https://reviews.llvm.org/D129167

llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
llvm/test/Transforms/AggressiveInstCombine/X86/sqrt.ll [new file with mode: 0644]

index 1fd8b88..35adaa3 100644 (file)
@@ -31,6 +31,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BuildLibCalls.h"
 #include "llvm/Transforms/Utils/Local.h"
 
 using namespace llvm;
@@ -427,27 +428,73 @@ static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
   return true;
 }
 
+/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
+/// pessimistic codegen that has to account for setting errno and can enable
+/// vectorization.
+static bool
+foldSqrt(Instruction &I, TargetTransformInfo &TTI, TargetLibraryInfo &TLI) {
+  // Match a call to sqrt mathlib function.
+  auto *Call = dyn_cast<CallInst>(&I);
+  if (!Call)
+    return false;
+
+  Module *M = Call->getModule();
+  LibFunc Func;
+  if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func))
+    return false;
+
+  if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl)
+    return false;
+
+  // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created,
+  // and (3) we would not end up lowering to a libcall anyway (which could
+  // change the value of errno), then:
+  // (1) the operand arg must not be less than -0.0.
+  // (2) errno won't be set.
+  // (3) it is safe to convert this to an intrinsic call.
+  // TODO: Check if the arg is known non-negative.
+  Type *Ty = Call->getType();
+  if (TTI.haveFastSqrt(Ty) && Call->hasNoNaNs()) {
+    IRBuilder<> Builder(&I);
+    IRBuilderBase::FastMathFlagGuard Guard(Builder);
+    Builder.setFastMathFlags(Call->getFastMathFlags());
+
+    Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
+    Value *NewSqrt = Builder.CreateCall(Sqrt, Call->getArgOperand(0), "sqrt");
+    I.replaceAllUsesWith(NewSqrt);
+
+    // Explicitly erase the old call because a call with side effects is not
+    // trivially dead.
+    I.eraseFromParent();
+    return true;
+  }
+
+  return false;
+}
+
 /// This is the entry point for folds that could be implemented in regular
 /// InstCombine, but they are separated because they are not expected to
 /// occur frequently and/or have more than a constant-length pattern match.
 static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
-                                TargetTransformInfo &TTI) {
+                                TargetTransformInfo &TTI,
+                                TargetLibraryInfo &TLI) {
   bool MadeChange = false;
   for (BasicBlock &BB : F) {
     // Ignore unreachable basic blocks.
     if (!DT.isReachableFromEntry(&BB))
       continue;
-    // Do not delete instructions under here and invalidate the iterator.
+
     // Walk the block backwards for efficiency. We're matching a chain of
     // use->defs, so we're more likely to succeed by starting from the bottom.
     // Also, we want to avoid matching partial patterns.
     // TODO: It would be more efficient if we removed dead instructions
     // iteratively in this loop rather than waiting until the end.
-    for (Instruction &I : llvm::reverse(BB)) {
+    for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) {
       MadeChange |= foldAnyOrAllBitsSet(I);
       MadeChange |= foldGuardedFunnelShift(I, DT);
       MadeChange |= tryToRecognizePopCount(I);
       MadeChange |= tryToFPToSat(I, TTI);
+      MadeChange |= foldSqrt(I, TTI, TLI);
     }
   }
 
@@ -467,7 +514,7 @@ static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
   const DataLayout &DL = F.getParent()->getDataLayout();
   TruncInstCombine TIC(AC, TLI, DL, DT);
   MadeChange |= TIC.run(F);
-  MadeChange |= foldUnusualPatterns(F, DT, TTI);
+  MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI);
   return MadeChange;
 }
 
diff --git a/llvm/test/Transforms/AggressiveInstCombine/X86/sqrt.ll b/llvm/test/Transforms/AggressiveInstCombine/X86/sqrt.ll
new file mode 100644 (file)
index 0000000..75dfe0b
--- /dev/null
@@ -0,0 +1,53 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=aggressive-instcombine -mtriple x86_64-- -S | FileCheck %s
+
+declare float @sqrtf(float)
+declare double @sqrt(double)
+declare fp128 @sqrtl(fp128)
+
+; "nnan" implies no setting of errno and the target can lower this to an
+; instruction, so transform to an intrinsic.
+
+define float @sqrt_call_nnan_f32(float %x) {
+; CHECK-LABEL: @sqrt_call_nnan_f32(
+; CHECK-NEXT:    [[SQRT1:%.*]] = call nnan float @llvm.sqrt.f32(float [[X:%.*]])
+; CHECK-NEXT:    ret float [[SQRT1]]
+;
+  %sqrt = call nnan float @sqrtf(float %x)
+  ret float %sqrt
+}
+
+; Verify that other FMF are propagated to the intrinsic call.
+; We don't care about propagating 'tail' because this is not going to be a lowered as a call.
+
+define double @sqrt_call_nnan_f64(double %x) {
+; CHECK-LABEL: @sqrt_call_nnan_f64(
+; CHECK-NEXT:    [[SQRT1:%.*]] = call nnan ninf double @llvm.sqrt.f64(double [[X:%.*]])
+; CHECK-NEXT:    ret double [[SQRT1]]
+;
+  %sqrt = tail call nnan ninf double @sqrt(double %x)
+  ret double %sqrt
+}
+
+; We don't change this because it will be lowered to a call that could
+; theoretically still change errno and affect other accessors of errno.
+
+define fp128 @sqrt_call_nnan_f128(fp128 %x) {
+; CHECK-LABEL: @sqrt_call_nnan_f128(
+; CHECK-NEXT:    [[SQRT:%.*]] = call nnan fp128 @sqrtl(fp128 [[X:%.*]])
+; CHECK-NEXT:    ret fp128 [[SQRT]]
+;
+  %sqrt = call nnan fp128 @sqrtl(fp128 %x)
+  ret fp128 %sqrt
+}
+
+; Don't alter a no-builtin libcall.
+
+define float @sqrt_call_nnan_f32_nobuiltin(float %x) {
+; CHECK-LABEL: @sqrt_call_nnan_f32_nobuiltin(
+; CHECK-NEXT:    [[SQRT:%.*]] = call nnan float @sqrtf(float [[X:%.*]]) #[[ATTR1:[0-9]+]]
+; CHECK-NEXT:    ret float [[SQRT]]
+;
+  %sqrt = call nnan float @sqrtf(float %x) nobuiltin
+  ret float %sqrt
+}