From 274ac9d40e79f25ac8c928732875708b5bac8f09 Mon Sep 17 00:00:00 2001
From: Jun Ma
Date: Thu, 1 Apr 2021 19:44:59 +0800
Subject: [PATCH] [AArch64][SVE] Lowering sve.dot to DOT node

Differential Revision: https://reviews.llvm.org/D99699
---
Note: this copy of the patch had its line breaks and its angle-bracketed
tokens stripped. The dyn_cast<> template arguments, the scalable vector types
in the .ll tests and the sve_int_bin_pred_sd<> context arguments have been
restored, and indentation has been reconstructed in LLVM style. If a context
line does not match, apply with `git apply --ignore-whitespace`. The author
e-mail address is not present in this copy.

 llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp     |  3 ++
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp    | 30 +++++++++++++---
 llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td     |  4 +--
 .../CodeGen/AArch64/sve-intrinsics-int-arith.ll    | 40 ++++++++++++++++++++++
 4 files changed, 71 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 943e06e..064f8cb 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -145,6 +145,9 @@ bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) {
     if (auto *Op0 = dyn_cast<ConstantSDNode>(N->getOperand(0))) {
       SplatVal = Op0->getAPIntValue().truncOrSelf(EltSize);
       return true;
+    } else if (auto *Op0 = dyn_cast<ConstantFPSDNode>(N->getOperand(0))) {
+      SplatVal = Op0->getValueAPF().bitcastToAPInt().truncOrSelf(EltSize);
+      return true;
     }
   }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9364bfb..b40fb7e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2153,6 +2153,24 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
 // Lowering Code
 //===----------------------------------------------------------------------===//
 
+/// isZerosVector - Check whether SDNode N is a zero-filled vector.
+static bool isZerosVector(const SDNode *N) {
+  // Look through a bit convert.
+  while (N->getOpcode() == ISD::BITCAST)
+    N = N->getOperand(0).getNode();
+
+  if (ISD::isConstantSplatVectorAllZeros(N))
+    return true;
+
+  if (N->getOpcode() != AArch64ISD::DUP)
+    return false;
+
+  auto Opnd0 = N->getOperand(0);
+  auto *CINT = dyn_cast<ConstantSDNode>(Opnd0);
+  auto *CFP = dyn_cast<ConstantFPSDNode>(Opnd0);
+  return (CINT && CINT->isNullValue()) || (CFP && CFP->isZero());
+}
+
 /// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64
 /// CC
 static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) {
@@ -3924,9 +3942,13 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
                        Op.getOperand(2));
   }
   case Intrinsic::aarch64_neon_sdot:
-  case Intrinsic::aarch64_neon_udot: {
-    unsigned Opcode = IntNo == Intrinsic::aarch64_neon_udot ? AArch64ISD::UDOT
-                                                            : AArch64ISD::SDOT;
+  case Intrinsic::aarch64_neon_udot:
+  case Intrinsic::aarch64_sve_sdot:
+  case Intrinsic::aarch64_sve_udot: {
+    unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot ||
+                       IntNo == Intrinsic::aarch64_sve_udot)
+                          ? AArch64ISD::UDOT
+                          : AArch64ISD::SDOT;
     return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
                        Op.getOperand(2), Op.getOperand(3));
   }
@@ -13340,7 +13362,7 @@ static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) {
   auto isZeroDot = [](SDValue Dot) {
     return (Dot.getOpcode() == AArch64ISD::UDOT ||
             Dot.getOpcode() == AArch64ISD::SDOT) &&
-           ISD::isBuildVectorAllZeros(Dot.getOperand(0).getNode());
+           isZerosVector(Dot.getOperand(0).getNode());
   };
   if (!isZeroDot(Dot))
     std::swap(Dot, A);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index d3a607d..df4e2cd 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -353,8 +353,8 @@ let Predicates = [HasSVE] in {
   defm SDIV_ZPZZ : sve_int_bin_pred_sd<AArch64sdiv_p>;
   defm UDIV_ZPZZ : sve_int_bin_pred_sd<AArch64udiv_p>;
 
-  defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", int_aarch64_sve_sdot>;
-  defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", int_aarch64_sve_udot>;
+  defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
+  defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;
 
   defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>;
   defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>;
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
index fa67d92..0c8c7c2 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll
@@ -114,6 +114,26 @@ define <vscale x 2 x i64> @sdot_i64(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b
   ret <vscale x 2 x i64> %out
 }
 
+define <vscale x 2 x i64> @test_sdot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
+; CHECK-LABEL: test_sdot_i64_zero:
+; CHECK: sdot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+  %vdot1.i = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
+  %ret = add <vscale x 2 x i64> %vdot1.i, %a
+  ret <vscale x 2 x i64> %ret
+}
+
+define <vscale x 4 x i32> @test_sdot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
+; CHECK-LABEL: test_sdot_i32_zero:
+; CHECK: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+entry:
+  %vdot1.i = call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
+  %ret = add <vscale x 4 x i32> %vdot1.i, %a
+  ret <vscale x 4 x i32> %ret
+}
+
 ; SDOT (Indexed)
 
 define <vscale x 4 x i32> @sdot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
@@ -236,6 +256,26 @@ define <vscale x 2 x i64> @udot_i64(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b
   ret <vscale x 2 x i64> %out
 }
 
+define <vscale x 2 x i64> @test_udot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
+; CHECK-LABEL: test_udot_i64_zero:
+; CHECK: udot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+  %vdot1.i = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
+  %ret = add <vscale x 2 x i64> %vdot1.i, %a
+  ret <vscale x 2 x i64> %ret
+}
+
+define <vscale x 4 x i32> @test_udot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
+; CHECK-LABEL: test_udot_i32_zero:
+; CHECK: udot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+entry:
+  %vdot1.i = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
+  %ret = add <vscale x 4 x i32> %vdot1.i, %a
+  ret <vscale x 4 x i32> %ret
+}
+
 ; UDOT (Indexed)
 
 define <vscale x 4 x i32> @udot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
-- 
2.7.4