[RISCVGatherScatterLowering] Minor code cleanup [NFC]
authorPhilip Reames <preames@rivosinc.com>
Fri, 12 May 2023 17:30:05 +0000 (10:30 -0700)
committerPhilip Reames <listmail@philipreames.com>
Fri, 12 May 2023 17:33:29 +0000 (10:33 -0700)
Use a switch to avoid repeat tests on the opcode, and factor out some common code out of another switch.  (Not all branches had both common bits, but the one which didn't left the values unchanged and their starting value is the respective operand - so storing it back is a nop.)

llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp

index e844e1b11271b97c5da6991b94763d6ec420fbdb..0174aa547a1b85a3651ee65197e1a10804403d77 100644 (file)
@@ -225,20 +225,24 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
   if (!BO)
     return false;
 
-  if (BO->getOpcode() != Instruction::Add &&
-      BO->getOpcode() != Instruction::Or &&
-      BO->getOpcode() != Instruction::Mul &&
-      BO->getOpcode() != Instruction::Shl)
-    return false;
-
-  // Only support shift by constant.
-  if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
-    return false;
-
-  // We need to be able to treat Or as Add.
-  if (BO->getOpcode() == Instruction::Or &&
-      !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
+  switch (BO->getOpcode()) {
+  default:
     return false;
+  case Instruction::Or:
+    // We need to be able to treat Or as Add.
+    if (!haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
+      return false;
+    break;
+  case Instruction::Add:
+    break;
+  case Instruction::Shl:
+    // Only support shift by constant.
+    if (!isa<Constant>(BO->getOperand(1)))
+      return false;
+    break;
+  case Instruction::Mul:
+    break;
+  }
 
   // We should have one operand in the loop and one splat.
   Value *OtherOp;
@@ -291,7 +295,6 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
       Start = SplatOp;
     else
       Start = Builder.CreateAdd(Start, SplatOp, "start");
-    BasePtr->setIncomingValue(StartBlock, Start);
     break;
   }
   case Instruction::Mul: {
@@ -306,8 +309,6 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
       Stride = SplatOp;
     else
       Stride = Builder.CreateMul(Stride, SplatOp, "stride");
-    Inc->setOperand(StepIndex, Step);
-    BasePtr->setIncomingValue(StartBlock, Start);
     break;
   }
   case Instruction::Shl: {
@@ -316,12 +317,12 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
       Start = Builder.CreateShl(Start, SplatOp, "start");
     Step = Builder.CreateShl(Step, SplatOp, "step");
     Stride = Builder.CreateShl(Stride, SplatOp, "stride");
-    Inc->setOperand(StepIndex, Step);
-    BasePtr->setIncomingValue(StartBlock, Start);
     break;
   }
   }
 
+  Inc->setOperand(StepIndex, Step);
+  BasePtr->setIncomingValue(StartBlock, Start);
   return true;
 }