From b9c473307954e62f5f756aa7af315d0ffe707634 Mon Sep 17 00:00:00 2001
From: Philip Reames
Date: Thu, 22 Sep 2022 18:35:00 -0700
Subject: [PATCH] [DAG] Move one-use add of splat to base of scatter/gather

This extends the uniform base transform used with scatter/gather to
support one-use vector adds-of-splats with a non-zero base. This has
the effect of essentially reassociating an add from vector to scalar
domain.

The motivation is to improve the lowering of scatter/gather operations
fed by complex geps.

Differential Revision: https://reviews.llvm.org/D134472
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp   | 30 +++++++++-----
 .../AArch64/sve-gather-scatter-addr-opts.ll     | 48 ++++++++++------------
 llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll | 26 ++++++------
 3 files changed, 53 insertions(+), 51 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2aad144..dbaf954 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10668,23 +10668,33 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
 }
 
 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
-                       SelectionDAG &DAG) {
-  if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
+                       SelectionDAG &DAG, const SDLoc &DL) {
+  if (Index.getOpcode() != ISD::ADD)
     return false;
 
   // Only perform the transformation when existing operands can be reused.
   if (IndexIsScaled)
     return false;
 
+  if (!isNullConstant(BasePtr) && !Index.hasOneUse())
+    return false;
+
+  EVT VT = BasePtr.getValueType();
   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
-      SplatVal && SplatVal.getValueType() == BasePtr.getValueType()) {
-    BasePtr = SplatVal;
+      SplatVal && SplatVal.getValueType() == VT) {
+    if (isNullConstant(BasePtr))
+      BasePtr = SplatVal;
+    else
+      BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
     Index = Index.getOperand(1);
     return true;
   }
   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
-      SplatVal && SplatVal.getValueType() == BasePtr.getValueType()) {
-    BasePtr = SplatVal;
+      SplatVal && SplatVal.getValueType() == VT) {
+    if (isNullConstant(BasePtr))
+      BasePtr = SplatVal;
+    else
+      BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
     Index = Index.getOperand(0);
     return true;
   }
@@ -10739,7 +10749,7 @@ SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return Chain;
 
-  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
+  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
     SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
     return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                             DL, Ops, MSC->getMemOperand(), IndexType);
@@ -10769,7 +10779,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return Chain;
 
-  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
+  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                                 DL, Ops, MSC->getMemOperand(), IndexType,
@@ -10861,7 +10871,7 @@ SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
   ISD::MemIndexType IndexType = MGT->getIndexType();
   SDLoc DL(N);
 
-  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
+  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
     SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
     return DAG.getGatherVP(
         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
@@ -10893,7 +10903,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return CombineTo(N, PassThru, MGT->getChain());
 
-  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
+  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedGather(
         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
index 9257a6a5..bdede03 100644
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -105,18 +105,16 @@ define void @scatter_i8_index_offset_maximum_plus_one(i8* %base, i64 %offset,
   %t0 = insertelement undef, i64 %offset, i32 0
   %t1 = shufflevector %t0, undef, zeroinitializer
@@ -138,18 +136,16 @@ define void @scatter_i8_index_offset_minimum_minus_one(i8* %base, i64 %offset,
 ; CHECK-NEXT:    mov x9, #-2
 ; CHECK-NEXT:    lsr x8, x8, #4
 ; CHECK-NEXT:    movk x9, #64511, lsl #16
-; CHECK-NEXT:    add x10, x0, x1
+; CHECK-NEXT:    add x11, x0, x1
+; CHECK-NEXT:    mov x10, #-33554433
+; CHECK-NEXT:    madd x8, x8, x9, x11
 ; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    mul x8, x8, x9
-; CHECK-NEXT:    mov x9, #-33554433
-; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpklo z2.d, z0.s
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    index z1.d, #0, x10
 ; CHECK-NEXT:    uunpkhi z0.d, z0.s
-; CHECK-NEXT:    index z1.d, #0, x9
-; CHECK-NEXT:    mov z2.d, x8
-; CHECK-NEXT:    st1b { z3.d }, p1, [x10, z1.d]
-; CHECK-NEXT:    add z2.d, z1.d, z2.d
-; CHECK-NEXT:    st1b { z0.d }, p0, [x10, z2.d]
+; CHECK-NEXT:    st1b { z2.d }, p1, [x11, z1.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x8, z1.d]
 ; CHECK-NEXT:    ret
   %t0 = insertelement undef, i64 %offset, i32 0
   %t1 = shufflevector %t0, undef, zeroinitializer
@@ -170,18 +166,16 @@ define void @scatter_i8_index_stride_too_big(i8* %base, i64 %offset,
   %t0 = insertelement undef, i64 %offset, i32 0
   %t1 = shufflevector %t0, undef, zeroinitializer
diff --git a/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll b/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
index cc7b281..7991eb9 100644
--- a/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
@@ -10,25 +10,23 @@ define void @complex_gep(ptr %p, %vec.ind,
 ; RV32-LABEL: complex_gep:
 ; RV32:       # %bb.0:
 ; RV32-NEXT:    vsetvli a1, zero, e32, m1, ta, mu
-; RV32-NEXT:    vmv.v.x v10, a0
-; RV32-NEXT:    vnsrl.wi v11, v8, 0
-; RV32-NEXT:    li a0, 48
-; RV32-NEXT:    vmadd.vx v11, a0, v10
-; RV32-NEXT:    vmv.v.i v8, 0
-; RV32-NEXT:    li a0, 28
-; RV32-NEXT:    vsoxei32.v v8, (a0), v11, v0.t
+; RV32-NEXT:    vnsrl.wi v10, v8, 0
+; RV32-NEXT:    li a1, 48
+; RV32-NEXT:    vmul.vx v8, v10, a1
+; RV32-NEXT:    addi a0, a0, 28
+; RV32-NEXT:    vmv.v.i v9, 0
+; RV32-NEXT:    vsoxei32.v v9, (a0), v8, v0.t
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: complex_gep:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a1, zero, e64, m2, ta, mu
-; RV64-NEXT:    vmv.v.x v10, a0
-; RV64-NEXT:    li a0, 56
-; RV64-NEXT:    vmacc.vx v10, a0, v8
+; RV64-NEXT:    li a1, 56
+; RV64-NEXT:    vsetvli a2, zero, e64, m2, ta, mu
+; RV64-NEXT:    vmul.vx v8, v8, a1
+; RV64-NEXT:    addi a0, a0, 32
 ; RV64-NEXT:    vsetvli zero, zero, e32, m1, ta, mu
-; RV64-NEXT:    vmv.v.i v8, 0
-; RV64-NEXT:    li a0, 32
-; RV64-NEXT:    vsoxei64.v v8, (a0), v10, v0.t
+; RV64-NEXT:    vmv.v.i v10, 0
+; RV64-NEXT:    vsoxei64.v v10, (a0), v8, v0.t
 ; RV64-NEXT:    ret
   %gep = getelementptr inbounds %struct, ptr %p, %vec.ind, i32 5
   call void @llvm.masked.scatter.nxv2i32.nxv2p0( zeroinitializer, %gep, i32 8, %m)
-- 
2.7.4
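
Note, not part of the patch: a hand-written LLVM IR sketch of the kind of complex gep the
commit message refers to. The struct layout, function name, and value names are hypothetical
stand-ins for the complex_gep test above. For a 48-byte struct whose field 5 is an i32 at byte
offset 32, each lane stores to %p + 48*%i + 32; the pre-existing transform already folds the
splat of %p into the scatter's scalar base, and with this patch the remaining one-use splat of
the constant 32 can be folded there as well (one scalar add) instead of being materialized as a
vector splat-add, as the addi against the base register in the RISC-V test diff illustrates.

; Hypothetical illustration only, not taken from the patch's tests.
; %struct.s is 48 bytes; field 5 is an i32 at byte offset 32, so each lane
; stores to %p + 48*%i + 32. The +32 splat is the add this combine is
; expected to reassociate onto the scalar base operand of the scatter node.
%struct.s = type { i64, i64, i64, i32, i32, i32, i32, i32 }

define void @scatter_struct_field(ptr %p, <vscale x 2 x i64> %i, <vscale x 2 x i1> %m) {
  %ptrs = getelementptr inbounds %struct.s, ptr %p, <vscale x 2 x i64> %i, i32 5
  call void @llvm.masked.scatter.nxv2i32.nxv2p0(<vscale x 2 x i32> zeroinitializer, <vscale x 2 x ptr> %ptrs, i32 4, <vscale x 2 x i1> %m)
  ret void
}

declare void @llvm.masked.scatter.nxv2i32.nxv2p0(<vscale x 2 x i32>, <vscale x 2 x ptr>, i32 immarg, <vscale x 2 x i1>)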