// Check integral scalar types.
const bool HasExtMOrZmmul =
Subtarget.hasStdExtM() || Subtarget.hasStdExtZmmul();
- if (VT.isScalarInteger()) {
- // Omit the optimization if the sub target has the M extension and the data
- // size exceeds XLen.
- if (HasExtMOrZmmul && VT.getSizeInBits() > Subtarget.getXLen())
- return false;
- if (auto *ConstNode = dyn_cast<ConstantSDNode>(C.getNode())) {
- // Break the MUL to a SLLI and an ADD/SUB.
- const APInt &Imm = ConstNode->getAPIntValue();
- if ((Imm + 1).isPowerOf2() || (Imm - 1).isPowerOf2() ||
- (1 - Imm).isPowerOf2() || (-1 - Imm).isPowerOf2())
- return true;
- // Optimize the MUL to (SH*ADD x, (SLLI x, bits)) if Imm is not simm12.
- if (Subtarget.hasStdExtZba() && !Imm.isSignedIntN(12) &&
- ((Imm - 2).isPowerOf2() || (Imm - 4).isPowerOf2() ||
- (Imm - 8).isPowerOf2()))
+ if (!VT.isScalarInteger())
+ return false;
+
+ // Omit the optimization if the sub target has the M extension and the data
+ // size exceeds XLen.
+ if (HasExtMOrZmmul && VT.getSizeInBits() > Subtarget.getXLen())
+ return false;
+
+ if (auto *ConstNode = dyn_cast<ConstantSDNode>(C.getNode())) {
+ // Break the MUL to a SLLI and an ADD/SUB.
+ const APInt &Imm = ConstNode->getAPIntValue();
+ if ((Imm + 1).isPowerOf2() || (Imm - 1).isPowerOf2() ||
+ (1 - Imm).isPowerOf2() || (-1 - Imm).isPowerOf2())
+ return true;
+
+ // Optimize the MUL to (SH*ADD x, (SLLI x, bits)) if Imm is not simm12.
+ if (Subtarget.hasStdExtZba() && !Imm.isSignedIntN(12) &&
+ ((Imm - 2).isPowerOf2() || (Imm - 4).isPowerOf2() ||
+ (Imm - 8).isPowerOf2()))
+ return true;
+
+ // Break the MUL to two SLLI instructions and an ADD/SUB, if Imm needs
+ // a pair of LUI/ADDI.
+ if (!Imm.isSignedIntN(12) && Imm.countr_zero() < 12 &&
+ ConstNode->hasOneUse()) {
+ APInt ImmS = Imm.ashr(Imm.countr_zero());
+ if ((ImmS + 1).isPowerOf2() || (ImmS - 1).isPowerOf2() ||
+ (1 - ImmS).isPowerOf2())
return true;
- // Break the MUL to two SLLI instructions and an ADD/SUB, if Imm needs
- // a pair of LUI/ADDI.
- if (!Imm.isSignedIntN(12) && Imm.countr_zero() < 12 &&
- ConstNode->hasOneUse()) {
- APInt ImmS = Imm.ashr(Imm.countr_zero());
- if ((ImmS + 1).isPowerOf2() || (ImmS - 1).isPowerOf2() ||
- (1 - ImmS).isPowerOf2())
- return true;
- }
}
}