[AArch64] Combine vector shift instructions in SelectionDAG
authorAndrew Savonichev <andrew.savonichev@gmail.com>
Tue, 18 May 2021 12:54:43 +0000 (15:54 +0300)
committerAndrew Savonichev <andrew.savonichev@gmail.com>
Thu, 20 May 2021 07:50:13 +0000 (10:50 +0300)
bswap.v2i16 + sitofp in LLVM IR generate a sequence of:

  - REV32 + USHR for bswap.v2i16
  - SHL + SSHR + SCVTF for sext to v2i32 and scvt

The shift instructions are excessive as noted in PR24820, and they can
be optimized to just SSHR.

Differential Revision: https://reviews.llvm.org/D102333

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/test/CodeGen/AArch64/aarch64-bswap-ext.ll [new file with mode: 0644]

index 218ab13..61bfc3c 100644 (file)
@@ -14639,6 +14639,29 @@ static SDValue performGLD1Combine(SDNode *N, SelectionDAG &DAG) {
   return SDValue();
 }
 
+/// Optimize a vector shift instruction and its operand if shifted out
+/// bits are not used.
+static SDValue performVectorShiftCombine(SDNode *N,
+                                         const AArch64TargetLowering &TLI,
+                                         TargetLowering::DAGCombinerInfo &DCI) {
+  assert(N->getOpcode() == AArch64ISD::VASHR ||
+         N->getOpcode() == AArch64ISD::VLSHR);
+
+  SDValue Op = N->getOperand(0);
+  unsigned OpScalarSize = Op.getScalarValueSizeInBits();
+
+  unsigned ShiftImm = N->getConstantOperandVal(1);
+  assert(OpScalarSize > ShiftImm && "Invalid shift imm");
+
+  APInt ShiftedOutBits = APInt::getLowBitsSet(OpScalarSize, ShiftImm);
+  APInt DemandedMask = ~ShiftedOutBits;
+
+  if (TLI.SimplifyDemandedBits(Op, DemandedMask, DCI))
+    return SDValue(N, 0);
+
+  return SDValue();
+}
+
 /// Target-specific DAG combine function for post-increment LD1 (lane) and
 /// post-increment LD1R.
 static SDValue performPostLD1Combine(SDNode *N,
@@ -16115,6 +16138,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO:
   case AArch64ISD::GLD1S_IMM_MERGE_ZERO:
     return performGLD1Combine(N, DAG);
+  case AArch64ISD::VASHR:
+  case AArch64ISD::VLSHR:
+    return performVectorShiftCombine(N, *this, DCI);
   case ISD::INSERT_VECTOR_ELT:
     return performInsertVectorEltCombine(N, DCI);
   case ISD::EXTRACT_VECTOR_ELT:
@@ -17805,3 +17831,47 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
 bool AArch64TargetLowering::isAllActivePredicate(SDValue N) const {
   return ::isAllActivePredicate(N);
 }
+
+bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
+    SDValue Op, const APInt &OriginalDemandedBits,
+    const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
+    unsigned Depth) const {
+
+  unsigned Opc = Op.getOpcode();
+  switch (Opc) {
+  case AArch64ISD::VSHL: {
+    // Match (VSHL (VLSHR Val X) X)
+    SDValue ShiftL = Op;
+    SDValue ShiftR = Op->getOperand(0);
+    if (ShiftR->getOpcode() != AArch64ISD::VLSHR)
+      return false;
+
+    if (!ShiftL.hasOneUse() || !ShiftR.hasOneUse())
+      return false;
+
+    unsigned ShiftLBits = ShiftL->getConstantOperandVal(1);
+    unsigned ShiftRBits = ShiftR->getConstantOperandVal(1);
+
+    // Other cases can be handled as well, but this is not
+    // implemented.
+    if (ShiftRBits != ShiftLBits)
+      return false;
+
+    unsigned ScalarSize = Op.getScalarValueSizeInBits();
+    assert(ScalarSize > ShiftLBits && "Invalid shift imm");
+
+    APInt ZeroBits = APInt::getLowBitsSet(ScalarSize, ShiftLBits);
+    APInt UnusedBits = ~OriginalDemandedBits;
+
+    if ((ZeroBits & UnusedBits) != ZeroBits)
+      return false;
+
+    // All bits that are zeroed by (VSHL (VLSHR Val X) X) are not
+    // used - simplify to just Val.
+    return TLO.CombineTo(Op, ShiftR->getOperand(0));
+  }
+  }
+
+  return TargetLowering::SimplifyDemandedBitsForTargetNode(
+      Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth);
+}
index 22bfec3..2539c14 100644 (file)
@@ -1068,6 +1068,13 @@ private:
   bool shouldLocalize(const MachineInstr &MI,
                       const TargetTransformInfo *TTI) const override;
 
+  bool SimplifyDemandedBitsForTargetNode(SDValue Op,
+                                         const APInt &OriginalDemandedBits,
+                                         const APInt &OriginalDemandedElts,
+                                         KnownBits &Known,
+                                         TargetLoweringOpt &TLO,
+                                         unsigned Depth) const override;
+
   // Normally SVE is only used for byte size vectors that do not fit within a
   // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
   // used for 64bit and 128bit vectors as well.
diff --git a/llvm/test/CodeGen/AArch64/aarch64-bswap-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-bswap-ext.ll
new file mode 100644 (file)
index 0000000..fdd5ce6
--- /dev/null
@@ -0,0 +1,27 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-unknown-linux-gnu < %s | FileCheck %s
+
+define <2 x i32> @test1(<2 x i16> %v2i16) {
+; CHECK-LABEL: test1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rev32 v0.8b, v0.8b
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    ret
+  %v2i16_rev = call <2 x i16> @llvm.bswap.v2i16(<2 x i16> %v2i16)
+  %v2i32 = sext <2 x i16> %v2i16_rev to <2 x i32>
+  ret <2 x i32> %v2i32
+}
+
+define <2 x float> @test2(<2 x i16> %v2i16) {
+; CHECK-LABEL: test2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rev32 v0.8b, v0.8b
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    scvtf v0.2s, v0.2s
+; CHECK-NEXT:    ret
+  %v2i16_rev = call <2 x i16> @llvm.bswap.v2i16(<2 x i16> %v2i16)
+  %v2f32 = sitofp <2 x i16> %v2i16_rev to <2 x float>
+  ret <2 x float> %v2f32
+}
+
+declare <2 x i16> @llvm.bswap.v2i16(<2 x i16>) nounwind readnone