From c5c2de287e5ff1803a10d94b0ab17b579442726d Mon Sep 17 00:00:00 2001
From: Quentin Colombet
Date: Mon, 3 Oct 2022 23:28:34 +0000
Subject: [PATCH] [RISCV][ISel] Fold extensions when all the users can consume
 them

This patch allows the combines that fold extensions in binary operations
to have more than one use.

The approach here is pretty conservative: if all the users of an
extension can fold the extension, then the folding is done; otherwise we
don't fold.

This is the first step towards avoiding the one-use limitation.

As a result, we make a fold/don't-fold decision for a whole web of
instructions. An instruction is part of the web of instructions as soon
as it consumes an extension that needs to be folded for all its users.

Because of how SDISel works, a web of instructions can be visited over
and over. More precisely, if the folding happens, it happens for the
whole web and that's the end of it, but if the folding fails, the whole
web may be revisited when another member of the web is visited.

To avoid a compile-time explosion in pathological cases, we bail out
earlier for webs that are bigger than a given threshold (arbitrarily set
at 18 for now). This size can be changed with
`--riscv-lower-ext-max-web-size=`.

At the current time, I didn't see a better scheme for that, assuming we
want to stick with doing this in SDISel.

Differential Revision: https://reviews.llvm.org/D133739
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp          | 102 ++++++++++++++-----
 .../rvv/fixed-vectors-vw-web-simplification.ll       |  60 ++++++++++++
 llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll   |  12 +--
 3 files changed, 145 insertions(+), 29 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index f4f74eb..4a8bc31 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -46,6 +46,12 @@ using namespace llvm;
 
 STATISTIC(NumTailCalls, "Number of tail calls");
 
+static cl::opt<unsigned> ExtensionMaxWebSize(
+    DEBUG_TYPE "-ext-max-web-size", cl::Hidden,
+    cl::desc("Give the maximum size (in number of nodes) of the web of "
+             "instructions that we will consider for VW expansion"),
+    cl::init(18));
+
 static cl::opt<bool>
     AllowSplatInVW_W(DEBUG_TYPE "-form-vw-w-with-splat", cl::Hidden,
                      cl::desc("Allow the formation of VW_W operations (e.g., "
@@ -8547,9 +8553,9 @@ struct CombineResult {
   /// Root of the combine.
   SDNode *Root;
   /// LHS of the TargetOpcode.
-  const NodeExtensionHelper &LHS;
+  NodeExtensionHelper LHS;
   /// RHS of the TargetOpcode.
-  const NodeExtensionHelper &RHS;
+  NodeExtensionHelper RHS;
 
   CombineResult(unsigned TargetOpcode, SDNode *Root,
                 const NodeExtensionHelper &LHS, Optional<bool> SExtLHS,
@@ -8728,31 +8734,83 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   assert(NodeExtensionHelper::isSupportedRoot(N) &&
          "Shouldn't have called this method");
+  SmallVector<SDNode *> Worklist;
+  SmallSet<SDNode *, 8> Inserted;
+  Worklist.push_back(N);
+  Inserted.insert(N);
+  SmallVector<CombineResult> CombinesToApply;
+
+  while (!Worklist.empty()) {
+    SDNode *Root = Worklist.pop_back_val();
+    if (!NodeExtensionHelper::isSupportedRoot(Root))
+      return SDValue();
 
-  NodeExtensionHelper LHS(N, 0, DAG);
-  NodeExtensionHelper RHS(N, 1, DAG);
-
-  if (LHS.needToPromoteOtherUsers() && !LHS.OrigOperand.hasOneUse())
-    return SDValue();
-
-  if (RHS.needToPromoteOtherUsers() && !RHS.OrigOperand.hasOneUse())
-    return SDValue();
+    NodeExtensionHelper LHS(N, 0, DAG);
+    NodeExtensionHelper RHS(N, 1, DAG);
+    auto AppendUsersIfNeeded = [&Worklist,
+                                &Inserted](const NodeExtensionHelper &Op) {
+      if (Op.needToPromoteOtherUsers()) {
+        for (SDNode *TheUse : Op.OrigOperand->uses()) {
+          if (Inserted.insert(TheUse).second)
+            Worklist.push_back(TheUse);
+        }
+      }
+    };
+    AppendUsersIfNeeded(LHS);
+    AppendUsersIfNeeded(RHS);
 
-  SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
-      NodeExtensionHelper::getSupportedFoldings(N);
+    // Control the compile time by limiting the number of nodes we look at in
+    // total.
+    if (Inserted.size() > ExtensionMaxWebSize)
+      return SDValue();
 
-  assert(!FoldingStrategies.empty() && "Nothing to be folded");
-  for (int Attempt = 0; Attempt != 1 + NodeExtensionHelper::isCommutative(N);
-       ++Attempt) {
-    for (NodeExtensionHelper::CombineToTry FoldingStrategy :
-         FoldingStrategies) {
-      Optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
-      if (Res)
-        return Res->materialize(DAG);
+    SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
+        NodeExtensionHelper::getSupportedFoldings(N);
+
+    assert(!FoldingStrategies.empty() && "Nothing to be folded");
+    bool Matched = false;
+    for (int Attempt = 0;
+         (Attempt != 1 + NodeExtensionHelper::isCommutative(N)) && !Matched;
+         ++Attempt) {
+
+      for (NodeExtensionHelper::CombineToTry FoldingStrategy :
+           FoldingStrategies) {
+        Optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
+        if (Res) {
+          Matched = true;
+          CombinesToApply.push_back(*Res);
+          break;
+        }
+      }
+      std::swap(LHS, RHS);
     }
-    std::swap(LHS, RHS);
+    // Right now we do an all or nothing approach.
+    if (!Matched)
+      return SDValue();
   }
 
-  return SDValue();
+  // Store the value for the replacement of the input node separately.
+  SDValue InputRootReplacement;
+  // We do the RAUW after we materialize all the combines, because some replaced
+  // nodes may be feeding some of the yet-to-be-replaced nodes. Put differently,
+  // some of these nodes may appear in the NodeExtensionHelpers of some of the
+  // yet-to-be-visited CombinesToApply roots.
+  SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
+  ValuesToReplace.reserve(CombinesToApply.size());
+  for (CombineResult Res : CombinesToApply) {
+    SDValue NewValue = Res.materialize(DAG);
+    if (!InputRootReplacement) {
+      assert(Res.Root == N &&
+             "First element is expected to be the current node");
+      InputRootReplacement = NewValue;
+    } else {
+      ValuesToReplace.emplace_back(SDValue(Res.Root, 0), NewValue);
+    }
+  }
+  for (std::pair<SDValue, SDValue> OldNewValues : ValuesToReplace) {
+    DAG.ReplaceAllUsesOfValueWith(OldNewValues.first, OldNewValues.second);
+    DCI.AddToWorklist(OldNewValues.second.getNode());
+  }
+  return InputRootReplacement;
 }
 
 // Fold
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
new file mode 100644
index 0000000..4fdf737
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
@@ -0,0 +1,60 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+
+; Check that the add/sub/mul operations are all promoted into their
+; vw counterparts when the web size is increased to 3.
+; We need the web size to be at least 3 for the folding to happen, because
+; %c has 3 uses.
+define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) {
+; NO_FOLDING-LABEL: vwmul_v2i16_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e16, mf4, ta, mu
+; NO_FOLDING-NEXT:    vle8.v v8, (a0)
+; NO_FOLDING-NEXT:    vle8.v v9, (a1)
+; NO_FOLDING-NEXT:    vle8.v v10, (a2)
+; NO_FOLDING-NEXT:    vsext.vf2 v11, v8
+; NO_FOLDING-NEXT:    vsext.vf2 v8, v9
+; NO_FOLDING-NEXT:    vsext.vf2 v9, v10
+; NO_FOLDING-NEXT:    vmul.vv v8, v11, v8
+; NO_FOLDING-NEXT:    vadd.vv v10, v11, v9
+; NO_FOLDING-NEXT:    vsub.vv v9, v11, v9
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v10
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v9
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vwmul_v2i16_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e8, mf8, ta, mu
+; FOLDING-NEXT:    vle8.v v8, (a0)
+; FOLDING-NEXT:    vle8.v v9, (a1)
+; FOLDING-NEXT:    vle8.v v10, (a2)
+; FOLDING-NEXT:    vwmul.vv v11, v8, v9
+; FOLDING-NEXT:    vwadd.vv v9, v8, v10
+; FOLDING-NEXT:    vwsub.vv v12, v8, v10
+; FOLDING-NEXT:    vsetvli zero, zero, e16, mf4, ta, mu
+; FOLDING-NEXT:    vor.vv v8, v11, v9
+; FOLDING-NEXT:    vor.vv v8, v8, v12
+; FOLDING-NEXT:    ret
+  %a = load <2 x i8>, <2 x i8>* %x
+  %b = load <2 x i8>, <2 x i8>* %y
+  %b2 = load <2 x i8>, <2 x i8>* %z
+  %c = sext <2 x i8> %a to <2 x i16>
+  %d = sext <2 x i8> %b to <2 x i16>
+  %d2 = sext <2 x i8> %b2 to <2 x i16>
+  %e = mul <2 x i16> %c, %d
+  %f = add <2 x i16> %c, %d2
+  %g = sub <2 x i16> %c, %d2
+  %h = or <2 x i16> %e, %f
+  %i = or <2 x i16> %h, %g
+  ret <2 x i16> %i
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
index 8862e33..0026164 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll
@@ -21,16 +21,14 @@ define <2 x i16> @vwmul_v2i16(<2 x i8>* %x, <2 x i8>* %y) {
 define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) {
 ; CHECK-LABEL: vwmul_v2i16_multiple_users:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, mu
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, mu
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
 ; CHECK-NEXT:    vle8.v v10, (a2)
-; CHECK-NEXT:    vsext.vf2 v11, v8
-; CHECK-NEXT:    vsext.vf2 v8, v9
-; CHECK-NEXT:    vsext.vf2 v9, v10
-; CHECK-NEXT:    vmul.vv v8, v11, v8
-; CHECK-NEXT:    vmul.vv v9, v11, v9
-; CHECK-NEXT:    vor.vv v8, v8, v9
+; CHECK-NEXT:    vwmul.vv v11, v8, v9
+; CHECK-NEXT:    vwmul.vv v9, v8, v10
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf4, ta, mu
+; CHECK-NEXT:    vor.vv v8, v11, v9
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, <2 x i8>* %x
   %b = load <2 x i8>, <2 x i8>* %y
-- 
2.7.4
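
Editor's note, appended after the patch and not part of it: the commit message
describes an all-or-nothing decision over a "web" of instructions. The
following standalone C++ sketch illustrates that walk on a toy node graph.
Every name in it (ToyNode, CanUseWideningForm, webCanFold) is hypothetical and
invented purely for illustration; the real combine operates on SDNode and
NodeExtensionHelper exactly as shown in the diff above.

// Hypothetical, self-contained model of the all-or-nothing "web" walk.
// None of these names exist in LLVM; ToyNode stands in for SDNode.
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

struct ToyNode {
  // Stands in for "isSupportedRoot() and has a VW form (vwadd, vwmul, ...)
  // that can consume the extension for free".
  bool CanUseWideningForm = false;
  // Extension nodes (sext/zext) this node consumes.
  std::vector<ToyNode *> ExtOperands;
  // For extension nodes: their consumers (models SDNode::uses()).
  std::vector<ToyNode *> Users;
};

// Returns true if every member of the web reachable from Root can fold its
// extensions, giving up once the web outgrows MaxWebSize (mirrors
// --riscv-lower-ext-max-web-size).
static bool webCanFold(ToyNode *Root, std::size_t MaxWebSize) {
  std::vector<ToyNode *> Worklist{Root};
  std::set<ToyNode *> Inserted{Root};
  while (!Worklist.empty()) {
    ToyNode *N = Worklist.back();
    Worklist.pop_back();
    // A single member that cannot consume its extension poisons the whole
    // web: the decision is all-or-nothing.
    if (!N->CanUseWideningForm)
      return false;
    // Folding an extension requires *all* of its users to fold it, so every
    // user of every extension operand becomes a member of the web.
    for (ToyNode *Ext : N->ExtOperands)
      for (ToyNode *User : Ext->Users)
        if (Inserted.insert(User).second)
          Worklist.push_back(User);
    // Compile-time guard: bail out on oversized webs.
    if (Inserted.size() > MaxWebSize)
      return false;
  }
  return true;
}

int main() {
  // Shape of the new test: one sign extension (%c) feeding mul, add, and
  // sub, all of which have widening forms.
  ToyNode Sext, Mul, Add, Sub;
  Sext.Users = {&Mul, &Add, &Sub};
  Mul.CanUseWideningForm = Add.CanUseWideningForm = true;
  Sub.CanUseWideningForm = true;
  Mul.ExtOperands = Add.ExtOperands = Sub.ExtOperands = {&Sext};
  assert(webCanFold(&Mul, 18));  // Default-sized threshold: the web folds.
  assert(!webCanFold(&Mul, 2));  // Web has 3 members: bails out at 2.
  return 0;
}

With the test's shape the web has three members, which matches the RUN lines
above: --riscv-lower-ext-max-web-size=2 bails out, while 3 and the default of
18 fold. In the real combine, a successful walk additionally records one
CombineResult per web member and materializes them together, doing the RAUW
only after all results exist, because earlier replacements may feed
yet-to-be-replaced roots.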