[GlobalISel] Implement fewerElements legalization for vector reductions.
authorAmara Emerson <amara@apple.com>
Sun, 21 Feb 2021 22:17:03 +0000 (14:17 -0800)
committerAmara Emerson <amara@apple.com>
Tue, 30 Mar 2021 18:19:21 +0000 (11:19 -0700)
This patch adds 3 methods, one for power-of-2 vectors which use tree
reductions using vector ops, before a final reduction op. For non-pow-2
types it generates multiple narrow reductions and combines the values with
scalar ops.

Differential Revision: https://reviews.llvm.org/D97163

llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h
llvm/include/llvm/CodeGen/GlobalISel/Utils.h
llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
llvm/test/CodeGen/AArch64/GlobalISel/legalize-reduce-add.mir
llvm/test/CodeGen/AArch64/GlobalISel/legalize-reduce-fadd.mir
llvm/test/CodeGen/AArch64/arm64-vabs.ll

index 200d660..d276fab 100644 (file)
@@ -249,6 +249,10 @@ private:
 
   void changeOpcode(MachineInstr &MI, unsigned NewOpcode);
 
+  LegalizeResult tryNarrowPow2Reduction(MachineInstr &MI, Register SrcReg,
+                                        LLT SrcTy, LLT NarrowTy,
+                                        unsigned ScalarOpc);
+
 public:
   /// Return the alignment to use for a stack temporary object with the given
   /// type.
@@ -319,6 +323,9 @@ public:
   LegalizeResult narrowScalarShiftByConstant(MachineInstr &MI, const APInt &Amt,
                                              LLT HalfTy, LLT ShiftAmtTy);
 
+  LegalizeResult fewerElementsVectorReductions(MachineInstr &MI,
+                                               unsigned TypeIdx, LLT NarrowTy);
+
   LegalizeResult narrowScalarShift(MachineInstr &MI, unsigned TypeIdx, LLT Ty);
   LegalizeResult narrowScalarAddSub(MachineInstr &MI, unsigned TypeIdx,
                                     LLT NarrowTy);
index ddf7835..19a5589 100644 (file)
@@ -44,6 +44,39 @@ class TargetRegisterClass;
 class ConstantFP;
 class APFloat;
 
+// Convenience macros for dealing with vector reduction opcodes.
+#define GISEL_VECREDUCE_CASES_ALL                                              \
+  case TargetOpcode::G_VECREDUCE_SEQ_FADD:                                     \
+  case TargetOpcode::G_VECREDUCE_SEQ_FMUL:                                     \
+  case TargetOpcode::G_VECREDUCE_FADD:                                         \
+  case TargetOpcode::G_VECREDUCE_FMUL:                                         \
+  case TargetOpcode::G_VECREDUCE_FMAX:                                         \
+  case TargetOpcode::G_VECREDUCE_FMIN:                                         \
+  case TargetOpcode::G_VECREDUCE_ADD:                                          \
+  case TargetOpcode::G_VECREDUCE_MUL:                                          \
+  case TargetOpcode::G_VECREDUCE_AND:                                          \
+  case TargetOpcode::G_VECREDUCE_OR:                                           \
+  case TargetOpcode::G_VECREDUCE_XOR:                                          \
+  case TargetOpcode::G_VECREDUCE_SMAX:                                         \
+  case TargetOpcode::G_VECREDUCE_SMIN:                                         \
+  case TargetOpcode::G_VECREDUCE_UMAX:                                         \
+  case TargetOpcode::G_VECREDUCE_UMIN:
+
+#define GISEL_VECREDUCE_CASES_NONSEQ                                           \
+  case TargetOpcode::G_VECREDUCE_FADD:                                         \
+  case TargetOpcode::G_VECREDUCE_FMUL:                                         \
+  case TargetOpcode::G_VECREDUCE_FMAX:                                         \
+  case TargetOpcode::G_VECREDUCE_FMIN:                                         \
+  case TargetOpcode::G_VECREDUCE_ADD:                                          \
+  case TargetOpcode::G_VECREDUCE_MUL:                                          \
+  case TargetOpcode::G_VECREDUCE_AND:                                          \
+  case TargetOpcode::G_VECREDUCE_OR:                                           \
+  case TargetOpcode::G_VECREDUCE_XOR:                                          \
+  case TargetOpcode::G_VECREDUCE_SMAX:                                         \
+  case TargetOpcode::G_VECREDUCE_SMIN:                                         \
+  case TargetOpcode::G_VECREDUCE_UMAX:                                         \
+  case TargetOpcode::G_VECREDUCE_UMIN:
+
 /// Try to constrain Reg to the specified register class. If this fails,
 /// create a new virtual register in the correct class.
 ///
index 7680f61..9eb4c80 100644 (file)
@@ -17,6 +17,7 @@
 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
+#include "llvm/CodeGen/GlobalISel/Utils.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/TargetFrameLowering.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
@@ -4207,11 +4208,139 @@ LegalizerHelper::fewerElementsVector(MachineInstr &MI, unsigned TypeIdx,
     return reduceLoadStoreWidth(MI, TypeIdx, NarrowTy);
   case G_SEXT_INREG:
     return fewerElementsVectorSextInReg(MI, TypeIdx, NarrowTy);
+  GISEL_VECREDUCE_CASES_NONSEQ
+    return fewerElementsVectorReductions(MI, TypeIdx, NarrowTy);
   default:
     return UnableToLegalize;
   }
 }
 
+LegalizerHelper::LegalizeResult LegalizerHelper::fewerElementsVectorReductions(
+    MachineInstr &MI, unsigned int TypeIdx, LLT NarrowTy) {
+  unsigned Opc = MI.getOpcode();
+  assert(Opc != TargetOpcode::G_VECREDUCE_SEQ_FADD &&
+         Opc != TargetOpcode::G_VECREDUCE_SEQ_FMUL &&
+         "Sequential reductions not expected");
+
+  if (TypeIdx != 1)
+    return UnableToLegalize;
+
+  // The semantics of the normal non-sequential reductions allow us to freely
+  // re-associate the operation.
+  Register SrcReg = MI.getOperand(1).getReg();
+  LLT SrcTy = MRI.getType(SrcReg);
+  Register DstReg = MI.getOperand(0).getReg();
+  LLT DstTy = MRI.getType(DstReg);
+
+  if (SrcTy.getNumElements() % NarrowTy.getNumElements() != 0)
+    return UnableToLegalize;
+
+  SmallVector<Register> SplitSrcs;
+  const unsigned NumParts = SrcTy.getNumElements() / NarrowTy.getNumElements();
+  extractParts(SrcReg, NarrowTy, NumParts, SplitSrcs);
+  SmallVector<Register> PartialReductions;
+  for (unsigned Part = 0; Part < NumParts; ++Part) {
+    PartialReductions.push_back(
+        MIRBuilder.buildInstr(Opc, {DstTy}, {SplitSrcs[Part]}).getReg(0));
+  }
+
+  unsigned ScalarOpc;
+  switch (Opc) {
+  case TargetOpcode::G_VECREDUCE_FADD:
+    ScalarOpc = TargetOpcode::G_FADD;
+    break;
+  case TargetOpcode::G_VECREDUCE_FMUL:
+    ScalarOpc = TargetOpcode::G_FMUL;
+    break;
+  case TargetOpcode::G_VECREDUCE_FMAX:
+    ScalarOpc = TargetOpcode::G_FMAXNUM;
+    break;
+  case TargetOpcode::G_VECREDUCE_FMIN:
+    ScalarOpc = TargetOpcode::G_FMINNUM;
+    break;
+  case TargetOpcode::G_VECREDUCE_ADD:
+    ScalarOpc = TargetOpcode::G_ADD;
+    break;
+  case TargetOpcode::G_VECREDUCE_MUL:
+    ScalarOpc = TargetOpcode::G_MUL;
+    break;
+  case TargetOpcode::G_VECREDUCE_AND:
+    ScalarOpc = TargetOpcode::G_AND;
+    break;
+  case TargetOpcode::G_VECREDUCE_OR:
+    ScalarOpc = TargetOpcode::G_OR;
+    break;
+  case TargetOpcode::G_VECREDUCE_XOR:
+    ScalarOpc = TargetOpcode::G_XOR;
+    break;
+  case TargetOpcode::G_VECREDUCE_SMAX:
+    ScalarOpc = TargetOpcode::G_SMAX;
+    break;
+  case TargetOpcode::G_VECREDUCE_SMIN:
+    ScalarOpc = TargetOpcode::G_SMIN;
+    break;
+  case TargetOpcode::G_VECREDUCE_UMAX:
+    ScalarOpc = TargetOpcode::G_UMAX;
+    break;
+  case TargetOpcode::G_VECREDUCE_UMIN:
+    ScalarOpc = TargetOpcode::G_UMIN;
+    break;
+  default:
+    LLVM_DEBUG(dbgs() << "Can't legalize: unknown reduction kind.\n");
+    return UnableToLegalize;
+  }
+
+  // If the types involved are powers of 2, we can generate intermediate vector
+  // ops, before generating a final reduction operation.
+  if (isPowerOf2_32(SrcTy.getNumElements()) &&
+      isPowerOf2_32(NarrowTy.getNumElements())) {
+    return tryNarrowPow2Reduction(MI, SrcReg, SrcTy, NarrowTy, ScalarOpc);
+  }
+
+  Register Acc = PartialReductions[0];
+  for (unsigned Part = 1; Part < NumParts; ++Part) {
+    if (Part == NumParts - 1) {
+      MIRBuilder.buildInstr(ScalarOpc, {DstReg},
+                            {Acc, PartialReductions[Part]});
+    } else {
+      Acc = MIRBuilder
+                .buildInstr(ScalarOpc, {DstTy}, {Acc, PartialReductions[Part]})
+                .getReg(0);
+    }
+  }
+  MI.eraseFromParent();
+  return Legalized;
+}
+
+LegalizerHelper::LegalizeResult
+LegalizerHelper::tryNarrowPow2Reduction(MachineInstr &MI, Register SrcReg,
+                                        LLT SrcTy, LLT NarrowTy,
+                                        unsigned ScalarOpc) {
+  SmallVector<Register> SplitSrcs;
+  // Split the sources into NarrowTy size pieces.
+  extractParts(SrcReg, NarrowTy,
+               SrcTy.getNumElements() / NarrowTy.getNumElements(), SplitSrcs);
+  // We're going to do a tree reduction using vector operations until we have
+  // one NarrowTy size value left.
+  while (SplitSrcs.size() > 1) {
+    SmallVector<Register> PartialRdxs;
+    for (unsigned Idx = 0; Idx < SplitSrcs.size()-1; Idx += 2) {
+      Register LHS = SplitSrcs[Idx];
+      Register RHS = SplitSrcs[Idx + 1];
+      // Create the intermediate vector op.
+      Register Res =
+          MIRBuilder.buildInstr(ScalarOpc, {NarrowTy}, {LHS, RHS}).getReg(0);
+      PartialRdxs.push_back(Res);
+    }
+    SplitSrcs = std::move(PartialRdxs);
+  }
+  // Finally generate the requested NarrowTy based reduction.
+  Observer.changingInstr(MI);
+  MI.getOperand(1).setReg(SplitSrcs[0]);
+  Observer.changedInstr(MI);
+  return Legalized;
+}
+
 LegalizerHelper::LegalizeResult
 LegalizerHelper::narrowScalarShiftByConstant(MachineInstr &MI, const APInt &Amt,
                                              const LLT HalfTy, const LLT AmtTy) {
index 9f39160..07067c3 100644 (file)
@@ -691,11 +691,15 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
   getActionDefinitionsBuilder(G_VECREDUCE_FADD)
       // We only have FADDP to do reduction-like operations. Lower the rest.
       .legalFor({{s32, v2s32}, {s64, v2s64}})
+      .clampMaxNumElements(1, s64, 2)
+      .clampMaxNumElements(1, s32, 2)
       .lower();
 
   getActionDefinitionsBuilder(G_VECREDUCE_ADD)
       .legalFor(
           {{s8, v16s8}, {s16, v8s16}, {s32, v4s32}, {s32, v2s32}, {s64, v2s64}})
+      .clampMaxNumElements(1, s64, 2)
+      .clampMaxNumElements(1, s32, 4)
       .lower();
 
   getActionDefinitionsBuilder({G_UADDSAT, G_USUBSAT})
index 2d83db4..eba3a38 100644 (file)
@@ -109,3 +109,65 @@ body:             |
     RET_ReallyLR implicit $w0
 
 ...
+---
+name:            test_v8i64
+alignment:       4
+tracksRegLiveness: true
+body:             |
+  bb.1:
+    liveins: $q0, $q1, $q2, $q3
+    ; This is a power-of-2 legalization, so use a tree reduction.
+    ; CHECK-LABEL: name: test_v8i64
+    ; CHECK: liveins: $q0, $q1, $q2, $q3
+    ; CHECK: [[COPY:%[0-9]+]]:_(<2 x s64>) = COPY $q0
+    ; CHECK: [[COPY1:%[0-9]+]]:_(<2 x s64>) = COPY $q1
+    ; CHECK: [[COPY2:%[0-9]+]]:_(<2 x s64>) = COPY $q2
+    ; CHECK: [[COPY3:%[0-9]+]]:_(<2 x s64>) = COPY $q3
+    ; CHECK: [[ADD:%[0-9]+]]:_(<2 x s64>) = G_ADD [[COPY]], [[COPY1]]
+    ; CHECK: [[ADD1:%[0-9]+]]:_(<2 x s64>) = G_ADD [[COPY2]], [[COPY3]]
+    ; CHECK: [[ADD2:%[0-9]+]]:_(<2 x s64>) = G_ADD [[ADD]], [[ADD1]]
+    ; CHECK: [[VECREDUCE_ADD:%[0-9]+]]:_(s64) = G_VECREDUCE_ADD [[ADD2]](<2 x s64>)
+    ; CHECK: $x0 = COPY [[VECREDUCE_ADD]](s64)
+    ; CHECK: RET_ReallyLR implicit $x0
+    %0:_(<2 x s64>) = COPY $q0
+    %1:_(<2 x s64>) = COPY $q1
+    %2:_(<2 x s64>) = COPY $q2
+    %3:_(<2 x s64>) = COPY $q3
+    %4:_(<4 x s64>) = G_CONCAT_VECTORS %0(<2 x s64>), %1(<2 x s64>)
+    %5:_(<4 x s64>) = G_CONCAT_VECTORS %2(<2 x s64>), %3(<2 x s64>)
+    %6:_(<8 x s64>) = G_CONCAT_VECTORS %4(<4 x s64>), %5(<4 x s64>)
+    %7:_(s64) = G_VECREDUCE_ADD %6(<8 x s64>)
+    $x0 = COPY %7(s64)
+    RET_ReallyLR implicit $x0
+
+...
+---
+name:            test_v6i64
+alignment:       4
+tracksRegLiveness: true
+body:             |
+  bb.1:
+    liveins: $q0, $q1, $q2, $q3
+    ; This is a non-power-of-2 legalization, generate multiple vector reductions
+    ; and combine them with scalar ops.
+    ; CHECK-LABEL: name: test_v6i64
+    ; CHECK: liveins: $q0, $q1, $q2, $q3
+    ; CHECK: [[COPY:%[0-9]+]]:_(<2 x s64>) = COPY $q0
+    ; CHECK: [[COPY1:%[0-9]+]]:_(<2 x s64>) = COPY $q1
+    ; CHECK: [[COPY2:%[0-9]+]]:_(<2 x s64>) = COPY $q2
+    ; CHECK: [[VECREDUCE_ADD:%[0-9]+]]:_(s64) = G_VECREDUCE_ADD [[COPY]](<2 x s64>)
+    ; CHECK: [[VECREDUCE_ADD1:%[0-9]+]]:_(s64) = G_VECREDUCE_ADD [[COPY1]](<2 x s64>)
+    ; CHECK: [[VECREDUCE_ADD2:%[0-9]+]]:_(s64) = G_VECREDUCE_ADD [[COPY2]](<2 x s64>)
+    ; CHECK: [[ADD:%[0-9]+]]:_(s64) = G_ADD [[VECREDUCE_ADD]], [[VECREDUCE_ADD1]]
+    ; CHECK: [[ADD1:%[0-9]+]]:_(s64) = G_ADD [[ADD]], [[VECREDUCE_ADD2]]
+    ; CHECK: $x0 = COPY [[ADD1]](s64)
+    ; CHECK: RET_ReallyLR implicit $x0
+    %0:_(<2 x s64>) = COPY $q0
+    %1:_(<2 x s64>) = COPY $q1
+    %2:_(<2 x s64>) = COPY $q2
+    %3:_(<6 x s64>) = G_CONCAT_VECTORS %0(<2 x s64>), %1(<2 x s64>), %2(<2 x s64>)
+    %4:_(s64) = G_VECREDUCE_ADD %3(<6 x s64>)
+    $x0 = COPY %4(s64)
+    RET_ReallyLR implicit $x0
+
+...
index 9750ac8..091f0e2 100644 (file)
@@ -39,3 +39,35 @@ body:             |
     RET_ReallyLR implicit $x0
 
 ...
+---
+name:            fadd_v8s64
+alignment:       4
+tracksRegLiveness: true
+body:             |
+  bb.1:
+    liveins: $q0, $q1, $q2, $q3
+    ; This is a power-of-2 legalization, so use a tree reduction.
+    ; CHECK-LABEL: name: fadd_v8s64
+    ; CHECK: liveins: $q0, $q1, $q2, $q3
+    ; CHECK: [[COPY:%[0-9]+]]:_(<2 x s64>) = COPY $q0
+    ; CHECK: [[COPY1:%[0-9]+]]:_(<2 x s64>) = COPY $q1
+    ; CHECK: [[COPY2:%[0-9]+]]:_(<2 x s64>) = COPY $q2
+    ; CHECK: [[COPY3:%[0-9]+]]:_(<2 x s64>) = COPY $q3
+    ; CHECK: [[FADD:%[0-9]+]]:_(<2 x s64>) = G_FADD [[COPY]], [[COPY1]]
+    ; CHECK: [[FADD1:%[0-9]+]]:_(<2 x s64>) = G_FADD [[COPY2]], [[COPY3]]
+    ; CHECK: [[FADD2:%[0-9]+]]:_(<2 x s64>) = G_FADD [[FADD]], [[FADD1]]
+    ; CHECK: [[VECREDUCE_FADD:%[0-9]+]]:_(s64) = G_VECREDUCE_FADD [[FADD2]](<2 x s64>)
+    ; CHECK: $x0 = COPY [[VECREDUCE_FADD]](s64)
+    ; CHECK: RET_ReallyLR implicit $x0
+    %0:_(<2 x s64>) = COPY $q0
+    %1:_(<2 x s64>) = COPY $q1
+    %2:_(<2 x s64>) = COPY $q2
+    %3:_(<2 x s64>) = COPY $q3
+    %4:_(<4 x s64>) = G_CONCAT_VECTORS %0(<2 x s64>), %1(<2 x s64>)
+    %5:_(<4 x s64>) = G_CONCAT_VECTORS %2(<2 x s64>), %3(<2 x s64>)
+    %6:_(<8 x s64>) = G_CONCAT_VECTORS %4(<4 x s64>), %5(<4 x s64>)
+    %7:_(s64) = G_VECREDUCE_FADD %6(<8 x s64>)
+    $x0 = COPY %7(s64)
+    RET_ReallyLR implicit $x0
+
+...
index 954e724..f2ba768 100644 (file)
@@ -1,7 +1,6 @@
 ; RUN: llc < %s -mtriple=arm64-eabi -aarch64-neon-syntax=apple | FileCheck -check-prefixes=CHECK,DAG %s
 ; RUN: llc < %s -global-isel -global-isel-abort=2 -pass-remarks-missed=gisel* -mtriple=arm64-eabi -aarch64-neon-syntax=apple 2>&1 | FileCheck %s --check-prefixes=FALLBACK,CHECK,GISEL
 
-; FALLBACK-NOT: remark:{{.*}} G_ZEXT
 ; FALLBACK-NOT: remark:{{.*}} sabdl8h
 define <8 x i16> @sabdl8h(<8 x i8>* %A, <8 x i8>* %B) nounwind {
 ;CHECK-LABEL: sabdl8h: