[IRBuilder] Migrate all binops to folding API
authorNikita Popov <npopov@redhat.com>
Thu, 30 Jun 2022 10:52:31 +0000 (12:52 +0200)
committerNikita Popov <npopov@redhat.com>
Thu, 30 Jun 2022 14:41:17 +0000 (16:41 +0200)
Migrate all binops to use FoldXYZ rather than CreateXYZ APIs,
which are compatible with InstSimplifyFolder and fallible constant
folding.

Rather than continuing to add one method for every single operator,
add a generic FoldBinOp (plus variants for nowrap, exact and fmf
operators), which we would need anyway for CreateBinaryOp.

This change is not NFC because IRBuilder with InstSimplifyFolder
may perform more folding. However, this patch changes SCEVExpander
to not use the folder in InsertBinOp to minimize practical impact
and keep this change as close to NFC as possible.

llvm/include/llvm/Analysis/InstSimplifyFolder.h
llvm/include/llvm/Analysis/TargetFolder.h
llvm/include/llvm/IR/ConstantFolder.h
llvm/include/llvm/IR/IRBuilder.h
llvm/include/llvm/IR/IRBuilderFolder.h
llvm/include/llvm/IR/NoFolder.h
llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

index a67424f..33b6f12 100644 (file)
@@ -46,33 +46,25 @@ public:
   // Return an existing value or a constant if the operation can be simplified.
   // Otherwise return nullptr.
   //===--------------------------------------------------------------------===//
-  Value *FoldAdd(Value *LHS, Value *RHS, bool HasNUW = false,
-                 bool HasNSW = false) const override {
-    return simplifyAddInst(LHS, RHS, HasNUW, HasNSW, SQ);
-  }
-
-  Value *FoldAnd(Value *LHS, Value *RHS) const override {
-    return simplifyAndInst(LHS, RHS, SQ);
-  }
 
-  Value *FoldOr(Value *LHS, Value *RHS) const override {
-    return simplifyOrInst(LHS, RHS, SQ);
+  Value *FoldBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                   Value *RHS) const override {
+    return simplifyBinOp(Opc, LHS, RHS, SQ);
   }
 
-  Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override {
-    return simplifyUDivInst(LHS, RHS, SQ);
+  Value *FoldExactBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                        bool IsExact) const override {
+    return simplifyBinOp(Opc, LHS, RHS, SQ);
   }
 
-  Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const override {
-    return simplifySDivInst(LHS, RHS, SQ);
+  Value *FoldNoWrapBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                         bool HasNUW, bool HasNSW) const override {
+    return simplifyBinOp(Opc, LHS, RHS, SQ);
   }
 
-  Value *FoldURem(Value *LHS, Value *RHS) const override {
-    return simplifyURemInst(LHS, RHS, SQ);
-  }
-
-  Value *FoldSRem(Value *LHS, Value *RHS) const override {
-    return simplifySRemInst(LHS, RHS, SQ);
+  Value *FoldBinOpFMF(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                      FastMathFlags FMF) const override {
+    return simplifyBinOp(Opc, LHS, RHS, FMF, SQ);
   }
 
   Value *FoldICmp(CmpInst::Predicate P, Value *LHS, Value *RHS) const override {
@@ -116,54 +108,6 @@ public:
   }
 
   //===--------------------------------------------------------------------===//
-  // Binary Operators
-  //===--------------------------------------------------------------------===//
-
-  Value *CreateFAdd(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateFAdd(LHS, RHS);
-  }
-  Value *CreateSub(Constant *LHS, Constant *RHS, bool HasNUW = false,
-                   bool HasNSW = false) const override {
-    return ConstFolder.CreateSub(LHS, RHS, HasNUW, HasNSW);
-  }
-  Value *CreateFSub(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateFSub(LHS, RHS);
-  }
-  Value *CreateMul(Constant *LHS, Constant *RHS, bool HasNUW = false,
-                   bool HasNSW = false) const override {
-    return ConstFolder.CreateMul(LHS, RHS, HasNUW, HasNSW);
-  }
-  Value *CreateFMul(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateFMul(LHS, RHS);
-  }
-  Value *CreateFDiv(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateFDiv(LHS, RHS);
-  }
-  Value *CreateFRem(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateFRem(LHS, RHS);
-  }
-  Value *CreateShl(Constant *LHS, Constant *RHS, bool HasNUW = false,
-                   bool HasNSW = false) const override {
-    return ConstFolder.CreateShl(LHS, RHS, HasNUW, HasNSW);
-  }
-  Value *CreateLShr(Constant *LHS, Constant *RHS,
-                    bool isExact = false) const override {
-    return ConstFolder.CreateLShr(LHS, RHS, isExact);
-  }
-  Value *CreateAShr(Constant *LHS, Constant *RHS,
-                    bool isExact = false) const override {
-    return ConstFolder.CreateAShr(LHS, RHS, isExact);
-  }
-  Value *CreateXor(Constant *LHS, Constant *RHS) const override {
-    return ConstFolder.CreateXor(LHS, RHS);
-  }
-
-  Value *CreateBinOp(Instruction::BinaryOps Opc, Constant *LHS,
-                     Constant *RHS) const override {
-    return ConstFolder.CreateBinOp(Opc, LHS, RHS);
-  }
-
-  //===--------------------------------------------------------------------===//
   // Unary Operators
   //===--------------------------------------------------------------------===//
 
index 1187e9c..a360be5 100644 (file)
@@ -22,6 +22,7 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/IRBuilderFolder.h"
+#include "llvm/IR/Operator.h"
 
 namespace llvm {
 
@@ -49,63 +50,45 @@ public:
   // Return an existing value or a constant if the operation can be simplified.
   // Otherwise return nullptr.
   //===--------------------------------------------------------------------===//
-  Value *FoldAdd(Value *LHS, Value *RHS, bool HasNUW = false,
-                 bool HasNSW = false) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return Fold(ConstantExpr::getAdd(LC, RC, HasNUW, HasNSW));
-    return nullptr;
-  }
-
-  Value *FoldAnd(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return Fold(ConstantExpr::getAnd(LC, RC));
-    return nullptr;
-  }
 
-  Value *FoldOr(Value *LHS, Value *RHS) const override {
+  Value *FoldBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                   Value *RHS) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
     if (LC && RC)
-      return Fold(ConstantExpr::getOr(LC, RC));
+      return Fold(ConstantExpr::get(Opc, LC, RC));
     return nullptr;
   }
 
-  Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldExactBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                        bool IsExact) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
     if (LC && RC)
-      return Fold(ConstantExpr::getUDiv(LC, RC, IsExact));
+      return Fold(ConstantExpr::get(
+          Opc, LC, RC, IsExact ? PossiblyExactOperator::IsExact : 0));
     return nullptr;
   }
 
-  Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldNoWrapBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                         bool HasNUW, bool HasNSW) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return Fold(ConstantExpr::getSDiv(LC, RC, IsExact));
-    return nullptr;
-  }
-
-  Value *FoldURem(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return Fold(ConstantExpr::getURem(LC, RC));
+    if (LC && RC) {
+      unsigned Flags = 0;
+      if (HasNUW)
+        Flags |= OverflowingBinaryOperator::NoUnsignedWrap;
+      if (HasNSW)
+        Flags |= OverflowingBinaryOperator::NoSignedWrap;
+      return Fold(ConstantExpr::get(Opc, LC, RC, Flags));
+    }
     return nullptr;
   }
 
-  Value *FoldSRem(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return Fold(ConstantExpr::getSRem(LC, RC));
-    return nullptr;
+  Value *FoldBinOpFMF(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                      FastMathFlags FMF) const override {
+    return FoldBinOp(Opc, LHS, RHS);
   }
-
   Value *FoldICmp(CmpInst::Predicate P, Value *LHS, Value *RHS) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
@@ -182,54 +165,6 @@ public:
   }
 
   //===--------------------------------------------------------------------===//
-  // Binary Operators
-  //===--------------------------------------------------------------------===//
-
-  Constant *CreateFAdd(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getFAdd(LHS, RHS));
-  }
-  Constant *CreateSub(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return Fold(ConstantExpr::getSub(LHS, RHS, HasNUW, HasNSW));
-  }
-  Constant *CreateFSub(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getFSub(LHS, RHS));
-  }
-  Constant *CreateMul(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return Fold(ConstantExpr::getMul(LHS, RHS, HasNUW, HasNSW));
-  }
-  Constant *CreateFMul(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getFMul(LHS, RHS));
-  }
-  Constant *CreateFDiv(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getFDiv(LHS, RHS));
-  }
-  Constant *CreateFRem(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getFRem(LHS, RHS));
-  }
-  Constant *CreateShl(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return Fold(ConstantExpr::getShl(LHS, RHS, HasNUW, HasNSW));
-  }
-  Constant *CreateLShr(Constant *LHS, Constant *RHS,
-                       bool isExact = false) const override {
-    return Fold(ConstantExpr::getLShr(LHS, RHS, isExact));
-  }
-  Constant *CreateAShr(Constant *LHS, Constant *RHS,
-                       bool isExact = false) const override {
-    return Fold(ConstantExpr::getAShr(LHS, RHS, isExact));
-  }
-  Constant *CreateXor(Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::getXor(LHS, RHS));
-  }
-
-  Constant *CreateBinOp(Instruction::BinaryOps Opc,
-                        Constant *LHS, Constant *RHS) const override {
-    return Fold(ConstantExpr::get(Opc, LHS, RHS));
-  }
-
-  //===--------------------------------------------------------------------===//
   // Unary Operators
   //===--------------------------------------------------------------------===//
 
index 6f5661d..9cf68dc 100644 (file)
@@ -22,6 +22,7 @@
 #include "llvm/IR/ConstantFold.h"
 #include "llvm/IR/IRBuilderFolder.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/Operator.h"
 
 namespace llvm {
 
@@ -38,61 +39,44 @@ public:
   // Return an existing value or a constant if the operation can be simplified.
   // Otherwise return nullptr.
   //===--------------------------------------------------------------------===//
-  Value *FoldAdd(Value *LHS, Value *RHS, bool HasNUW = false,
-                 bool HasNSW = false) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return ConstantExpr::getAdd(LC, RC, HasNUW, HasNSW);
-    return nullptr;
-  }
-
-  Value *FoldAnd(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return ConstantExpr::getAnd(LC, RC);
-    return nullptr;
-  }
-
-  Value *FoldOr(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return ConstantExpr::getOr(LC, RC);
-    return nullptr;
-  }
 
-  Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                   Value *RHS) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
     if (LC && RC)
-      return ConstantExpr::getUDiv(LC, RC, IsExact);
+      return ConstantExpr::get(Opc, LC, RC);
     return nullptr;
   }
 
-  Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldExactBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                        bool IsExact) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
     if (LC && RC)
-      return ConstantExpr::getSDiv(LC, RC, IsExact);
+      return ConstantExpr::get(Opc, LC, RC,
+                               IsExact ? PossiblyExactOperator::IsExact : 0);
     return nullptr;
   }
 
-  Value *FoldURem(Value *LHS, Value *RHS) const override {
+  Value *FoldNoWrapBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                         bool HasNUW, bool HasNSW) const override {
     auto *LC = dyn_cast<Constant>(LHS);
     auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return ConstantExpr::getURem(LC, RC);
+    if (LC && RC) {
+      unsigned Flags = 0;
+      if (HasNUW)
+        Flags |= OverflowingBinaryOperator::NoUnsignedWrap;
+      if (HasNSW)
+        Flags |= OverflowingBinaryOperator::NoSignedWrap;
+      return ConstantExpr::get(Opc, LC, RC, Flags);
+    }
     return nullptr;
   }
 
-  Value *FoldSRem(Value *LHS, Value *RHS) const override {
-    auto *LC = dyn_cast<Constant>(LHS);
-    auto *RC = dyn_cast<Constant>(RHS);
-    if (LC && RC)
-      return ConstantExpr::getSRem(LC, RC);
-    return nullptr;
+  Value *FoldBinOpFMF(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                      FastMathFlags FMF) const override {
+    return FoldBinOp(Opc, LHS, RHS);
   }
 
   Value *FoldICmp(CmpInst::Predicate P, Value *LHS, Value *RHS) const override {
@@ -171,68 +155,6 @@ public:
   }
 
   //===--------------------------------------------------------------------===//
-  // Binary Operators
-  //===--------------------------------------------------------------------===//
-
-  Constant *CreateFAdd(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getFAdd(LHS, RHS);
-  }
-
-  Constant *CreateSub(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return ConstantExpr::getSub(LHS, RHS, HasNUW, HasNSW);
-  }
-
-  Constant *CreateFSub(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getFSub(LHS, RHS);
-  }
-
-  Constant *CreateMul(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return ConstantExpr::getMul(LHS, RHS, HasNUW, HasNSW);
-  }
-
-  Constant *CreateFMul(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getFMul(LHS, RHS);
-  }
-
-  Constant *CreateFDiv(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getFDiv(LHS, RHS);
-  }
-
-  Constant *CreateFRem(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getFRem(LHS, RHS);
-  }
-
-  Constant *CreateShl(Constant *LHS, Constant *RHS,
-                      bool HasNUW = false, bool HasNSW = false) const override {
-    return ConstantExpr::getShl(LHS, RHS, HasNUW, HasNSW);
-  }
-
-  Constant *CreateLShr(Constant *LHS, Constant *RHS,
-                       bool isExact = false) const override {
-    return ConstantExpr::getLShr(LHS, RHS, isExact);
-  }
-
-  Constant *CreateAShr(Constant *LHS, Constant *RHS,
-                       bool isExact = false) const override {
-    return ConstantExpr::getAShr(LHS, RHS, isExact);
-  }
-
-  Constant *CreateOr(Constant *LHS, Constant *RHS) const {
-    return ConstantExpr::getOr(LHS, RHS);
-  }
-
-  Constant *CreateXor(Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::getXor(LHS, RHS);
-  }
-
-  Constant *CreateBinOp(Instruction::BinaryOps Opc,
-                        Constant *LHS, Constant *RHS) const override {
-    return ConstantExpr::get(Opc, LHS, RHS);
-  }
-
-  //===--------------------------------------------------------------------===//
   // Unary Operators
   //===--------------------------------------------------------------------===//
 
index 3267d99..902d945 100644 (file)
@@ -1158,13 +1158,6 @@ private:
     return I;
   }
 
-  Value *foldConstant(Instruction::BinaryOps Opc, Value *L,
-                      Value *R, const Twine &Name) const {
-    auto *LC = dyn_cast<Constant>(L);
-    auto *RC = dyn_cast<Constant>(R);
-    return (LC && RC) ? Insert(Folder.CreateBinOp(Opc, LC, RC), Name) : nullptr;
-  }
-
   Value *getConstrainedFPRounding(Optional<RoundingMode> Rounding) {
     RoundingMode UseRounding = DefaultConstrainedRounding;
 
@@ -1206,10 +1199,11 @@ private:
 public:
   Value *CreateAdd(Value *LHS, Value *RHS, const Twine &Name = "",
                    bool HasNUW = false, bool HasNSW = false) {
-    if (auto *V = Folder.FoldAdd(LHS, RHS, HasNUW, HasNSW))
+    if (Value *V =
+            Folder.FoldNoWrapBinOp(Instruction::Add, LHS, RHS, HasNUW, HasNSW))
       return V;
-    return CreateInsertNUWNSWBinOp(Instruction::Add, LHS, RHS, Name,
-                                   HasNUW, HasNSW);
+    return CreateInsertNUWNSWBinOp(Instruction::Add, LHS, RHS, Name, HasNUW,
+                                   HasNSW);
   }
 
   Value *CreateNSWAdd(Value *LHS, Value *RHS, const Twine &Name = "") {
@@ -1222,11 +1216,11 @@ public:
 
   Value *CreateSub(Value *LHS, Value *RHS, const Twine &Name = "",
                    bool HasNUW = false, bool HasNSW = false) {
-    if (auto *LC = dyn_cast<Constant>(LHS))
-      if (auto *RC = dyn_cast<Constant>(RHS))
-        return Insert(Folder.CreateSub(LC, RC, HasNUW, HasNSW), Name);
-    return CreateInsertNUWNSWBinOp(Instruction::Sub, LHS, RHS, Name,
-                                   HasNUW, HasNSW);
+    if (Value *V =
+            Folder.FoldNoWrapBinOp(Instruction::Sub, LHS, RHS, HasNUW, HasNSW))
+      return V;
+    return CreateInsertNUWNSWBinOp(Instruction::Sub, LHS, RHS, Name, HasNUW,
+                                   HasNSW);
   }
 
   Value *CreateNSWSub(Value *LHS, Value *RHS, const Twine &Name = "") {
@@ -1239,11 +1233,11 @@ public:
 
   Value *CreateMul(Value *LHS, Value *RHS, const Twine &Name = "",
                    bool HasNUW = false, bool HasNSW = false) {
-    if (auto *LC = dyn_cast<Constant>(LHS))
-      if (auto *RC = dyn_cast<Constant>(RHS))
-        return Insert(Folder.CreateMul(LC, RC, HasNUW, HasNSW), Name);
-    return CreateInsertNUWNSWBinOp(Instruction::Mul, LHS, RHS, Name,
-                                   HasNUW, HasNSW);
+    if (Value *V =
+            Folder.FoldNoWrapBinOp(Instruction::Mul, LHS, RHS, HasNUW, HasNSW))
+      return V;
+    return CreateInsertNUWNSWBinOp(Instruction::Mul, LHS, RHS, Name, HasNUW,
+                                   HasNSW);
   }
 
   Value *CreateNSWMul(Value *LHS, Value *RHS, const Twine &Name = "") {
@@ -1256,7 +1250,7 @@ public:
 
   Value *CreateUDiv(Value *LHS, Value *RHS, const Twine &Name = "",
                     bool isExact = false) {
-    if (Value *V = Folder.FoldUDiv(LHS, RHS, isExact))
+    if (Value *V = Folder.FoldExactBinOp(Instruction::UDiv, LHS, RHS, isExact))
       return V;
     if (!isExact)
       return Insert(BinaryOperator::CreateUDiv(LHS, RHS), Name);
@@ -1269,7 +1263,7 @@ public:
 
   Value *CreateSDiv(Value *LHS, Value *RHS, const Twine &Name = "",
                     bool isExact = false) {
-    if (Value *V = Folder.FoldSDiv(LHS, RHS, isExact))
+    if (Value *V = Folder.FoldExactBinOp(Instruction::SDiv, LHS, RHS, isExact))
       return V;
     if (!isExact)
       return Insert(BinaryOperator::CreateSDiv(LHS, RHS), Name);
@@ -1281,22 +1275,22 @@ public:
   }
 
   Value *CreateURem(Value *LHS, Value *RHS, const Twine &Name = "") {
-    if (Value *V = Folder.FoldURem(LHS, RHS))
+    if (Value *V = Folder.FoldBinOp(Instruction::URem, LHS, RHS))
       return V;
     return Insert(BinaryOperator::CreateURem(LHS, RHS), Name);
   }
 
   Value *CreateSRem(Value *LHS, Value *RHS, const Twine &Name = "") {
-    if (Value *V = Folder.FoldSRem(LHS, RHS))
+    if (Value *V = Folder.FoldBinOp(Instruction::SRem, LHS, RHS))
       return V;
     return Insert(BinaryOperator::CreateSRem(LHS, RHS), Name);
   }
 
   Value *CreateShl(Value *LHS, Value *RHS, const Twine &Name = "",
                    bool HasNUW = false, bool HasNSW = false) {
-    if (auto *LC = dyn_cast<Constant>(LHS))
-      if (auto *RC = dyn_cast<Constant>(RHS))
-        return Insert(Folder.CreateShl(LC, RC, HasNUW, HasNSW), Name);
+    if (Value *V =
+            Folder.FoldNoWrapBinOp(Instruction::Shl, LHS, RHS, HasNUW, HasNSW))
+      return V;
     return CreateInsertNUWNSWBinOp(Instruction::Shl, LHS, RHS, Name,
                                    HasNUW, HasNSW);
   }
@@ -1315,9 +1309,8 @@ public:
 
   Value *CreateLShr(Value *LHS, Value *RHS, const Twine &Name = "",
                     bool isExact = false) {
-    if (auto *LC = dyn_cast<Constant>(LHS))
-      if (auto *RC = dyn_cast<Constant>(RHS))
-        return Insert(Folder.CreateLShr(LC, RC, isExact), Name);
+    if (Value *V = Folder.FoldExactBinOp(Instruction::LShr, LHS, RHS, isExact))
+      return V;
     if (!isExact)
       return Insert(BinaryOperator::CreateLShr(LHS, RHS), Name);
     return Insert(BinaryOperator::CreateExactLShr(LHS, RHS), Name);
@@ -1335,9 +1328,8 @@ public:
 
   Value *CreateAShr(Value *LHS, Value *RHS, const Twine &Name = "",
                     bool isExact = false) {
-    if (auto *LC = dyn_cast<Constant>(LHS))
-      if (auto *RC = dyn_cast<Constant>(RHS))
-        return Insert(Folder.CreateAShr(LC, RC, isExact), Name);
+    if (Value *V = Folder.FoldExactBinOp(Instruction::AShr, LHS, RHS, isExact))
+      return V;
     if (!isExact)
       return Insert(BinaryOperator::CreateAShr(LHS, RHS), Name);
     return Insert(BinaryOperator::CreateExactAShr(LHS, RHS), Name);
@@ -1354,7 +1346,7 @@ public:
   }
 
   Value *CreateAnd(Value *LHS, Value *RHS, const Twine &Name = "") {
-    if (auto *V = Folder.FoldAnd(LHS, RHS))
+    if (auto *V = Folder.FoldBinOp(Instruction::And, LHS, RHS))
       return V;
     return Insert(BinaryOperator::CreateAnd(LHS, RHS), Name);
   }
@@ -1376,7 +1368,7 @@ public:
   }
 
   Value *CreateOr(Value *LHS, Value *RHS, const Twine &Name = "") {
-    if (auto *V = Folder.FoldOr(LHS, RHS))
+    if (auto *V = Folder.FoldBinOp(Instruction::Or, LHS, RHS))
       return V;
     return Insert(BinaryOperator::CreateOr(LHS, RHS), Name);
   }
@@ -1398,7 +1390,8 @@ public:
   }
 
   Value *CreateXor(Value *LHS, Value *RHS, const Twine &Name = "") {
-    if (Value *V = foldConstant(Instruction::Xor, LHS, RHS, Name)) return V;
+    if (Value *V = Folder.FoldBinOp(Instruction::Xor, LHS, RHS))
+      return V;
     return Insert(BinaryOperator::CreateXor(LHS, RHS), Name);
   }
 
@@ -1416,7 +1409,8 @@ public:
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fadd,
                                       L, R, nullptr, Name, FPMD);
 
-    if (Value *V = foldConstant(Instruction::FAdd, L, R, Name)) return V;
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FAdd, L, R, FMF))
+      return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), FPMD, FMF);
     return Insert(I, Name);
   }
@@ -1429,9 +1423,10 @@ public:
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fadd,
                                       L, R, FMFSource, Name);
 
-    if (Value *V = foldConstant(Instruction::FAdd, L, R, Name)) return V;
-    Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), nullptr,
-                                FMFSource->getFastMathFlags());
+    FastMathFlags FMF = FMFSource->getFastMathFlags();
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FAdd, L, R, FMF))
+      return V;
+    Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), nullptr, FMF);
     return Insert(I, Name);
   }
 
@@ -1441,7 +1436,8 @@ public:
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fsub,
                                       L, R, nullptr, Name, FPMD);
 
-    if (Value *V = foldConstant(Instruction::FSub, L, R, Name)) return V;
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FSub, L, R, FMF))
+      return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), FPMD, FMF);
     return Insert(I, Name);
   }
@@ -1454,9 +1450,10 @@ public:
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fsub,
                                       L, R, FMFSource, Name);
 
-    if (Value *V = foldConstant(Instruction::FSub, L, R, Name)) return V;
-    Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), nullptr,
-                                FMFSource->getFastMathFlags());
+    FastMathFlags FMF = FMFSource->getFastMathFlags();
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FSub, L, R, FMF))
+      return V;
+    Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), nullptr, FMF);
     return Insert(I, Name);
   }
 
@@ -1466,7 +1463,8 @@ public:
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fmul,
                                       L, R, nullptr, Name, FPMD);
 
-    if (Value *V = foldConstant(Instruction::FMul, L, R, Name)) return V;
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FMul, L, R, FMF))
+      return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), FPMD, FMF);
     return Insert(I, Name);
   }
@@ -1479,9 +1477,10 @@ public:
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fmul,
                                       L, R, FMFSource, Name);
 
-    if (Value *V = foldConstant(Instruction::FMul, L, R, Name)) return V;
-    Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), nullptr,
-                                FMFSource->getFastMathFlags());
+    FastMathFlags FMF = FMFSource->getFastMathFlags();
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FMul, L, R, FMF))
+      return V;
+    Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), nullptr, FMF);
     return Insert(I, Name);
   }
 
@@ -1491,7 +1490,8 @@ public:
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fdiv,
                                       L, R, nullptr, Name, FPMD);
 
-    if (Value *V = foldConstant(Instruction::FDiv, L, R, Name)) return V;
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FDiv, L, R, FMF))
+      return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), FPMD, FMF);
     return Insert(I, Name);
   }
@@ -1504,9 +1504,9 @@ public:
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fdiv,
                                       L, R, FMFSource, Name);
 
-    if (Value *V = foldConstant(Instruction::FDiv, L, R, Name)) return V;
-    Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), nullptr,
-                                FMFSource->getFastMathFlags());
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FDiv, L, R, FMF))
+      return V;
+    Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), nullptr, FMF);
     return Insert(I, Name);
   }
 
@@ -1516,7 +1516,7 @@ public:
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_frem,
                                       L, R, nullptr, Name, FPMD);
 
-    if (Value *V = foldConstant(Instruction::FRem, L, R, Name)) return V;
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FRem, L, R, FMF)) return V;
     Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), FPMD, FMF);
     return Insert(I, Name);
   }
@@ -1529,16 +1529,16 @@ public:
       return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_frem,
                                       L, R, FMFSource, Name);
 
-    if (Value *V = foldConstant(Instruction::FRem, L, R, Name)) return V;
-    Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), nullptr,
-                                FMFSource->getFastMathFlags());
+    FastMathFlags FMF = FMFSource->getFastMathFlags();
+    if (Value *V = Folder.FoldBinOpFMF(Instruction::FRem, L, R, FMF)) return V;
+    Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), nullptr, FMF);
     return Insert(I, Name);
   }
 
   Value *CreateBinOp(Instruction::BinaryOps Opc,
                      Value *LHS, Value *RHS, const Twine &Name = "",
                      MDNode *FPMathTag = nullptr) {
-    if (Value *V = foldConstant(Opc, LHS, RHS, Name)) return V;
+    if (Value *V = Folder.FoldBinOp(Opc, LHS, RHS)) return V;
     Instruction *BinOp = BinaryOperator::Create(Opc, LHS, RHS);
     if (isa<FPMathOperator>(BinOp))
       setFPAttrs(BinOp, FPMathTag, FMF);
index 38e150e..1cc5993 100644 (file)
@@ -31,20 +31,19 @@ public:
   // Return an existing value or a constant if the operation can be simplified.
   // Otherwise return nullptr.
   //===--------------------------------------------------------------------===//
-  virtual Value *FoldAdd(Value *LHS, Value *RHS, bool HasNUW = false,
-                         bool HasNSW = false) const = 0;
 
-  virtual Value *FoldAnd(Value *LHS, Value *RHS) const = 0;
+  virtual Value *FoldBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                           Value *RHS) const = 0;
 
-  virtual Value *FoldOr(Value *LHS, Value *RHS) const = 0;
+  virtual Value *FoldExactBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                                Value *RHS, bool IsExact) const = 0;
 
-  virtual Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const = 0;
+  virtual Value *FoldNoWrapBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                                 Value *RHS, bool HasNUW,
+                                 bool HasNSW) const = 0;
 
-  virtual Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const = 0;
-
-  virtual Value *FoldURem(Value *LHS, Value *RHS) const = 0;
-
-  virtual Value *FoldSRem(Value *LHS, Value *RHS) const = 0;
+  virtual Value *FoldBinOpFMF(Instruction::BinaryOps Opc, Value *LHS,
+                              Value *RHS, FastMathFlags FMF) const = 0;
 
   virtual Value *FoldICmp(CmpInst::Predicate P, Value *LHS,
                           Value *RHS) const = 0;
@@ -69,29 +68,6 @@ public:
                                    ArrayRef<int> Mask) const = 0;
 
   //===--------------------------------------------------------------------===//
-  // Binary Operators
-  //===--------------------------------------------------------------------===//
-
-  virtual Value *CreateFAdd(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateSub(Constant *LHS, Constant *RHS,
-                           bool HasNUW = false, bool HasNSW = false) const = 0;
-  virtual Value *CreateFSub(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateMul(Constant *LHS, Constant *RHS,
-                           bool HasNUW = false, bool HasNSW = false) const = 0;
-  virtual Value *CreateFMul(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateFDiv(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateFRem(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateShl(Constant *LHS, Constant *RHS,
-                           bool HasNUW = false, bool HasNSW = false) const = 0;
-  virtual Value *CreateLShr(Constant *LHS, Constant *RHS,
-                            bool isExact = false) const = 0;
-  virtual Value *CreateAShr(Constant *LHS, Constant *RHS,
-                            bool isExact = false) const = 0;
-  virtual Value *CreateXor(Constant *LHS, Constant *RHS) const = 0;
-  virtual Value *CreateBinOp(Instruction::BinaryOps Opc,
-                             Constant *LHS, Constant *RHS) const = 0;
-
-  //===--------------------------------------------------------------------===//
   // Unary Operators
   //===--------------------------------------------------------------------===//
 
index 4b9e183..183b5a4 100644 (file)
@@ -43,26 +43,26 @@ public:
   // Return an existing value or a constant if the operation can be simplified.
   // Otherwise return nullptr.
   //===--------------------------------------------------------------------===//
-  Value *FoldAdd(Value *LHS, Value *RHS, bool HasNUW = false,
-                 bool HasNSW = false) const override {
+
+  Value *FoldBinOp(Instruction::BinaryOps Opc, Value *LHS,
+                   Value *RHS) const override {
     return nullptr;
   }
 
-  Value *FoldAnd(Value *LHS, Value *RHS) const override { return nullptr; }
-
-  Value *FoldOr(Value *LHS, Value *RHS) const override { return nullptr; }
-
-  Value *FoldUDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldExactBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                        bool IsExact) const override {
     return nullptr;
   }
 
-  Value *FoldSDiv(Value *LHS, Value *RHS, bool IsExact) const override {
+  Value *FoldNoWrapBinOp(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                         bool HasNUW, bool HasNSW) const override {
     return nullptr;
   }
 
-  Value *FoldURem(Value *LHS, Value *RHS) const override { return nullptr; }
-
-  Value *FoldSRem(Value *LHS, Value *RHS) const override { return nullptr; }
+  Value *FoldBinOpFMF(Instruction::BinaryOps Opc, Value *LHS, Value *RHS,
+                      FastMathFlags FMF) const override {
+    return nullptr;
+  }
 
   Value *FoldICmp(CmpInst::Predicate P, Value *LHS, Value *RHS) const override {
     return nullptr;
@@ -102,79 +102,6 @@ public:
   }
 
   //===--------------------------------------------------------------------===//
-  // Binary Operators
-  //===--------------------------------------------------------------------===//
-
-  Instruction *CreateFAdd(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateFAdd(LHS, RHS);
-  }
-
-  Instruction *CreateSub(Constant *LHS, Constant *RHS,
-                         bool HasNUW = false,
-                         bool HasNSW = false) const override {
-    BinaryOperator *BO = BinaryOperator::CreateSub(LHS, RHS);
-    if (HasNUW) BO->setHasNoUnsignedWrap();
-    if (HasNSW) BO->setHasNoSignedWrap();
-    return BO;
-  }
-
-  Instruction *CreateFSub(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateFSub(LHS, RHS);
-  }
-
-  Instruction *CreateMul(Constant *LHS, Constant *RHS,
-                         bool HasNUW = false,
-                         bool HasNSW = false) const override {
-    BinaryOperator *BO = BinaryOperator::CreateMul(LHS, RHS);
-    if (HasNUW) BO->setHasNoUnsignedWrap();
-    if (HasNSW) BO->setHasNoSignedWrap();
-    return BO;
-  }
-
-  Instruction *CreateFMul(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateFMul(LHS, RHS);
-  }
-
-  Instruction *CreateFDiv(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateFDiv(LHS, RHS);
-  }
-
-  Instruction *CreateFRem(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateFRem(LHS, RHS);
-  }
-
-  Instruction *CreateShl(Constant *LHS, Constant *RHS, bool HasNUW = false,
-                         bool HasNSW = false) const override {
-    BinaryOperator *BO = BinaryOperator::CreateShl(LHS, RHS);
-    if (HasNUW) BO->setHasNoUnsignedWrap();
-    if (HasNSW) BO->setHasNoSignedWrap();
-    return BO;
-  }
-
-  Instruction *CreateLShr(Constant *LHS, Constant *RHS,
-                          bool isExact = false) const override {
-    if (!isExact)
-      return BinaryOperator::CreateLShr(LHS, RHS);
-    return BinaryOperator::CreateExactLShr(LHS, RHS);
-  }
-
-  Instruction *CreateAShr(Constant *LHS, Constant *RHS,
-                          bool isExact = false) const override {
-    if (!isExact)
-      return BinaryOperator::CreateAShr(LHS, RHS);
-    return BinaryOperator::CreateExactAShr(LHS, RHS);
-  }
-
-  Instruction *CreateXor(Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::CreateXor(LHS, RHS);
-  }
-
-  Instruction *CreateBinOp(Instruction::BinaryOps Opc,
-                           Constant *LHS, Constant *RHS) const override {
-    return BinaryOperator::Create(Opc, LHS, RHS);
-  }
-
-  //===--------------------------------------------------------------------===//
   // Unary Operators
   //===--------------------------------------------------------------------===//
 
index a02eec1..401f1ee 100644 (file)
@@ -273,7 +273,9 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode,
   }
 
   // If we haven't found this binop, insert it.
-  Instruction *BO = cast<Instruction>(Builder.CreateBinOp(Opcode, LHS, RHS));
+  // TODO: Use the Builder, which will make CreateBinOp below fold with
+  // InstSimplifyFolder.
+  Instruction *BO = Builder.Insert(BinaryOperator::Create(Opcode, LHS, RHS));
   BO->setDebugLoc(Loc);
   if (Flags & SCEV::FlagNUW)
     BO->setHasNoUnsignedWrap();