[DAGCombiner][SVE] Ensure MGATHER/MSCATTER addressing mode combines preserve index...
authorPaul Walker <paul.walker@arm.com>
Tue, 5 Apr 2022 16:49:01 +0000 (17:49 +0100)
committerPaul Walker <paul.walker@arm.com>
Fri, 29 Apr 2022 11:35:16 +0000 (12:35 +0100)
refineUniformBase and selectGatherScatterAddrMode both attempt the
transformation:

  base(0) + index(A+splat(B)) => base(B) + index(A)

However, this is only safe when index is not implicitly scaled.

Differential Revision: https://reviews.llvm.org/D123222

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll

index 71618eb2bd7c9f9366ced7af254516567fe7fd03..181ff00184b36baa76f11abbd8792bf1d560c387 100644 (file)
@@ -10426,14 +10426,19 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
       TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
 }
 
-bool refineUniformBase(SDValue &BasePtr, SDValue &Index, SelectionDAG &DAG) {
+bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
+                       SelectionDAG &DAG) {
   if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
     return false;
 
+  // Only perform the transformation when existing operands can be reused.
+  if (IndexIsScaled)
+    return false;
+
   // For now we check only the LHS of the add.
   SDValue LHS = Index.getOperand(0);
   SDValue SplatVal = DAG.getSplatValue(LHS);
-  if (!SplatVal)
+  if (!SplatVal || SplatVal.getValueType() != BasePtr.getValueType())
     return false;
 
   BasePtr = SplatVal;
@@ -10481,7 +10486,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return Chain;
 
-  if (refineUniformBase(BasePtr, Index, DAG)) {
+  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedScatter(
         DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
@@ -10576,7 +10581,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return CombineTo(N, PassThru, MGT->getChain());
 
-  if (refineUniformBase(BasePtr, Index, DAG)) {
+  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
                                MGT->getMemoryVT(), DL, Ops,
index a5dadd112c9eaab9c7c27d653fde3c0015886d59..3f6a36fe49f4b2330fe63694118b1b2d78097cb6 100644 (file)
@@ -4656,10 +4656,10 @@ bool getGatherScatterIndexIsExtended(SDValue Index) {
 // VECTOR + IMMEDIATE:
 //    getelementptr nullptr, <vscale x N x T> (splat(#x)) + %indices)
 // -> getelementptr #x, <vscale x N x T> %indices
-void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, EVT MemVT,
-                                 unsigned &Opcode, bool IsGather,
-                                 SelectionDAG &DAG) {
-  if (!isNullConstant(BasePtr))
+void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index,
+                                 bool IsScaled, EVT MemVT, unsigned &Opcode,
+                                 bool IsGather, SelectionDAG &DAG) {
+  if (!isNullConstant(BasePtr) || IsScaled)
     return;
 
   // FIXME: This will not match for fixed vector type codegen as the nodes in
@@ -4789,7 +4789,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
     Index = Index.getOperand(0);
 
   unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend);
-  selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
+  selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
                               /*isGather=*/true, DAG);
 
   if (ExtType == ISD::SEXTLOAD)
@@ -4898,7 +4898,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
     Index = Index.getOperand(0);
 
   unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend);
-  selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
+  selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
                               /*isGather=*/false, DAG);
 
   if (IsFixedLength) {
index d06cc313ba539843d752a471f57c0d1d5f77a54a..4fdf4a106dbc0fd0874921d232e5a192cc1ff278 100644 (file)
@@ -343,12 +343,13 @@ define <vscale x 2 x i64> @masked_gather_nxv2i64_const_with_vec_offsets(<vscale
   ret <vscale x 2 x i64> %data
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; %scalar_offset when it's used to calculate %ptrs.
 define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg) #0 {
 ; CHECK-LABEL: masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0, z0.d, lsl #3]
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    mov z1.d, x0
+; CHECK-NEXT:    add z0.d, z0.d, z1.d
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
   %scalar_offset.splat = shufflevector <vscale x 2 x i64> %scalar_offset.ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
@@ -358,12 +359,11 @@ define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with_vec_plus_scalar_offse
   ret <vscale x 2 x i64> %data
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; constant scalar offset (i.e. i64 1)  when it's used to calculate %ptrs.
 define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg) #0 {
 ; CHECK-LABEL: masked_gather_nxv2i64_null_with__vec_plus_imm_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, #1
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    add z0.d, z0.d, #1 // =0x1
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0
@@ -425,12 +425,13 @@ define void @masked_scatter_nxv2i64_const_with_vec_offsets(<vscale x 2 x i64> %v
   ret void
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; %scalar_offset when it's used to calculate %ptrs.
 define void @masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
 ; CHECK-LABEL: masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    st1d { z1.d }, p0, [x0, z0.d, lsl #3]
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    mov z2.d, x0
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    st1d { z1.d }, p0, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
   %scalar_offset.splat = shufflevector <vscale x 2 x i64> %scalar_offset.ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
@@ -440,12 +441,11 @@ define void @masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x
   ret void
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; constant scalar offset (i.e. i64 1)  when it's used to calculate %ptrs.
 define void @masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
 ; CHECK-LABEL: masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, #1
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    add z0.d, z0.d, #1 // =0x1
 ; CHECK-NEXT:    st1d { z1.d }, p0, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0