[LV] Update generateInstruction to return produced value (NFC).
authorFlorian Hahn <flo@fhahn.com>
Wed, 5 Jul 2023 18:15:55 +0000 (19:15 +0100)
committerFlorian Hahn <flo@fhahn.com>
Wed, 5 Jul 2023 18:53:59 +0000 (19:53 +0100)
Update generateInstruction to return the produced value instead of
setting it for each opcode. This reduces the amount of duplicated code
and is a preparation for D153696.

Reviewed By: Ayal

Differential Revision: https://reviews.llvm.org/D154240

llvm/lib/Transforms/Vectorize/VPlan.h
llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

index c132f0c..7331346 100644 (file)
@@ -847,8 +847,10 @@ private:
   const std::string Name;
 
   /// Utility method serving execute(): generates a single instance of the
-  /// modeled instruction.
-  void generateInstruction(VPTransformState &State, unsigned Part);
+  /// modeled instruction. \returns the generated value for \p Part.
+  /// In some cases an existing value is returned rather than a generated
+  /// one.
+  Value *generateInstruction(VPTransformState &State, unsigned Part);
 
 protected:
   void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); }
index 5a4e8cc..26c309e 100644 (file)
@@ -216,41 +216,32 @@ void VPRecipeBase::moveBefore(VPBasicBlock &BB,
   insertBefore(BB, I);
 }
 
-void VPInstruction::generateInstruction(VPTransformState &State,
-                                        unsigned Part) {
+Value *VPInstruction::generateInstruction(VPTransformState &State,
+                                          unsigned Part) {
   IRBuilderBase &Builder = State.Builder;
   Builder.SetCurrentDebugLocation(DL);
 
   if (Instruction::isBinaryOp(getOpcode())) {
     Value *A = State.get(getOperand(0), Part);
     Value *B = State.get(getOperand(1), Part);
-    Value *V =
-        Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name);
-    State.set(this, V, Part);
-    return;
+    return Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B, Name);
   }
 
   switch (getOpcode()) {
   case VPInstruction::Not: {
     Value *A = State.get(getOperand(0), Part);
-    Value *V = Builder.CreateNot(A, Name);
-    State.set(this, V, Part);
-    break;
+    return Builder.CreateNot(A, Name);
   }
   case VPInstruction::ICmpULE: {
     Value *IV = State.get(getOperand(0), Part);
     Value *TC = State.get(getOperand(1), Part);
-    Value *V = Builder.CreateICmpULE(IV, TC, Name);
-    State.set(this, V, Part);
-    break;
+    return Builder.CreateICmpULE(IV, TC, Name);
   }
   case Instruction::Select: {
     Value *Cond = State.get(getOperand(0), Part);
     Value *Op1 = State.get(getOperand(1), Part);
     Value *Op2 = State.get(getOperand(2), Part);
-    Value *V = Builder.CreateSelect(Cond, Op1, Op2, Name);
-    State.set(this, V, Part);
-    break;
+    return Builder.CreateSelect(Cond, Op1, Op2, Name);
   }
   case VPInstruction::ActiveLaneMask: {
     // Get first lane of vector induction variable.
@@ -260,11 +251,9 @@ void VPInstruction::generateInstruction(VPTransformState &State,
 
     auto *Int1Ty = Type::getInt1Ty(Builder.getContext());
     auto *PredTy = VectorType::get(Int1Ty, State.VF);
-    Instruction *Call = Builder.CreateIntrinsic(
-        Intrinsic::get_active_lane_mask, {PredTy, ScalarTC->getType()},
-        {VIVElem0, ScalarTC}, nullptr, Name);
-    State.set(this, Call, Part);
-    break;
+    return Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
+                                   {PredTy, ScalarTC->getType()},
+                                   {VIVElem0, ScalarTC}, nullptr, Name);
   }
   case VPInstruction::FirstOrderRecurrenceSplice: {
     // Generate code to combine the previous and current values in vector v3.
@@ -282,14 +271,10 @@ void VPInstruction::generateInstruction(VPTransformState &State,
     // For the first part, use the recurrence phi (v1), otherwise v2.
     auto *V1 = State.get(getOperand(0), 0);
     Value *PartMinus1 = Part == 0 ? V1 : State.get(getOperand(1), Part - 1);
-    if (!PartMinus1->getType()->isVectorTy()) {
-      State.set(this, PartMinus1, Part);
-    } else {
-      Value *V2 = State.get(getOperand(1), Part);
-      State.set(this, Builder.CreateVectorSplice(PartMinus1, V2, -1, Name),
-                Part);
-    }
-    break;
+    if (!PartMinus1->getType()->isVectorTy())
+      return PartMinus1;
+    Value *V2 = State.get(getOperand(1), Part);
+    return Builder.CreateVectorSplice(PartMinus1, V2, -1, Name);
   }
   case VPInstruction::CalculateTripCountMinusVF: {
     Value *ScalarTC = State.get(getOperand(0), {0, 0});
@@ -298,13 +283,10 @@ void VPInstruction::generateInstruction(VPTransformState &State,
     Value *Sub = Builder.CreateSub(ScalarTC, Step);
     Value *Cmp = Builder.CreateICmp(CmpInst::Predicate::ICMP_UGT, ScalarTC, Step);
     Value *Zero = ConstantInt::get(ScalarTC->getType(), 0);
-    Value *Sel = Builder.CreateSelect(Cmp, Sub, Zero);
-    State.set(this, Sel, Part);
-    break;
+    return Builder.CreateSelect(Cmp, Sub, Zero);
   }
   case VPInstruction::CanonicalIVIncrement:
   case VPInstruction::CanonicalIVIncrementNUW: {
-    Value *Next = nullptr;
     if (Part == 0) {
       bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementNUW;
       auto *Phi = State.get(getOperand(0), 0);
@@ -312,34 +294,26 @@ void VPInstruction::generateInstruction(VPTransformState &State,
       // elements) times the unroll factor (num of SIMD instructions).
       Value *Step =
           createStepForVF(Builder, Phi->getType(), State.VF, State.UF);
-      Next = Builder.CreateAdd(Phi, Step, Name, IsNUW, false);
-    } else {
-      Next = State.get(this, 0);
+      return Builder.CreateAdd(Phi, Step, Name, IsNUW, false);
     }
-
-    State.set(this, Next, Part);
-    break;
+    return State.get(this, 0);
   }
 
   case VPInstruction::CanonicalIVIncrementForPart:
   case VPInstruction::CanonicalIVIncrementForPartNUW: {
     bool IsNUW = getOpcode() == VPInstruction::CanonicalIVIncrementForPartNUW;
     auto *IV = State.get(getOperand(0), VPIteration(0, 0));
-    if (Part == 0) {
-      State.set(this, IV, Part);
-      break;
-    }
+    if (Part == 0)
+      return IV;
 
     // The canonical IV is incremented by the vectorization factor (num of SIMD
     // elements) times the unroll part.
     Value *Step = createStepForVF(Builder, IV->getType(), State.VF, Part);
-    Value *Next = Builder.CreateAdd(IV, Step, Name, IsNUW, false);
-    State.set(this, Next, Part);
-    break;
+    return Builder.CreateAdd(IV, Step, Name, IsNUW, false);
   }
   case VPInstruction::BranchOnCond: {
     if (Part != 0)
-      break;
+      return nullptr;
 
     Value *Cond = State.get(getOperand(0), VPIteration(Part, 0));
     VPRegionBlock *ParentRegion = getParent()->getParent();
@@ -356,11 +330,11 @@ void VPInstruction::generateInstruction(VPTransformState &State,
 
     CondBr->setSuccessor(0, nullptr);
     Builder.GetInsertBlock()->getTerminator()->eraseFromParent();
-    break;
+    return CondBr;
   }
   case VPInstruction::BranchOnCount: {
     if (Part != 0)
-      break;
+      return nullptr;
     // First create the compare.
     Value *IV = State.get(getOperand(0), Part);
     Value *TC = State.get(getOperand(1), Part);
@@ -380,7 +354,7 @@ void VPInstruction::generateInstruction(VPTransformState &State,
                                               State.CFG.VPBB2IRBB[Header]);
     CondBr->setSuccessor(0, nullptr);
     Builder.GetInsertBlock()->getTerminator()->eraseFromParent();
-    break;
+    return CondBr;
   }
   default:
     llvm_unreachable("Unsupported opcode for instruction");
@@ -391,8 +365,13 @@ void VPInstruction::execute(VPTransformState &State) {
   assert(!State.Instance && "VPInstruction executing an Instance");
   IRBuilderBase::FastMathFlagGuard FMFGuard(State.Builder);
   State.Builder.setFastMathFlags(FMF);
-  for (unsigned Part = 0; Part < State.UF; ++Part)
-    generateInstruction(State, Part);
+  for (unsigned Part = 0; Part < State.UF; ++Part) {
+    Value *GeneratedValue = generateInstruction(State, Part);
+    if (!hasResult())
+      continue;
+    assert(GeneratedValue && "generateInstruction must produce a value");
+    State.set(this, GeneratedValue, Part);
+  }
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)