[RISCV][ISel] Fold extensions when all the users can consume them
authorQuentin Colombet <quentin.colombet@gmail.com>
Mon, 3 Oct 2022 23:28:34 +0000 (23:28 +0000)
committerQuentin Colombet <quentin.colombet@gmail.com>
Wed, 5 Oct 2022 20:49:21 +0000 (20:49 +0000)
This patch allows the combines that fold extensions in binary operations
to have more than one use.
The approach here is pretty conservative: if all the users of an
extension can fold the extension, then the folding is done, otherwise we
don't fold.
This is the first step towards avoiding the one-use limitation.

As a result, we make a decision to fold/don't fold for a web of
instructions. An instruction is part of the web of instructions as soon
as it consumes an extension that needs to be folded for all its users.

Because of how SDISel works a web of instructions can be visited over
and over. More precisely, if the folding happens, it happens for the
whole web and that's the end of it, but if the folding fails, the whole
web may be revisited when another member of the web is visited.

To avoid a compile time explosion in pathological cases, we bail out
earlier for webs that are bigger than a given threshold (arbitrarily set
at 18 for now.) This size can be changed using
`--riscv-lower-ext-max-web-size=<maxWebSize>`.

At the current time, I didn't see a better scheme for that. Assuming we
want to stick with doing that in SDISel.

Differential Revision: https://reviews.llvm.org/D133739

llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll [new file with mode: 0644]
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwmul.ll

index f4f74eb..4a8bc31 100644 (file)
@@ -46,6 +46,12 @@ using namespace llvm;
 
 STATISTIC(NumTailCalls, "Number of tail calls");
 
+static cl::opt<unsigned> ExtensionMaxWebSize(
+    DEBUG_TYPE "-ext-max-web-size", cl::Hidden,
+    cl::desc("Give the maximum size (in number of nodes) of the web of "
+             "instructions that we will consider for VW expansion"),
+    cl::init(18));
+
 static cl::opt<bool>
     AllowSplatInVW_W(DEBUG_TYPE "-form-vw-w-with-splat", cl::Hidden,
                      cl::desc("Allow the formation of VW_W operations (e.g., "
@@ -8547,9 +8553,9 @@ struct CombineResult {
   /// Root of the combine.
   SDNode *Root;
   /// LHS of the TargetOpcode.
-  const NodeExtensionHelper &LHS;
+  NodeExtensionHelper LHS;
   /// RHS of the TargetOpcode.
-  const NodeExtensionHelper &RHS;
+  NodeExtensionHelper RHS;
 
   CombineResult(unsigned TargetOpcode, SDNode *Root,
                 const NodeExtensionHelper &LHS, Optional<bool> SExtLHS,
@@ -8728,31 +8734,83 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
 
   assert(NodeExtensionHelper::isSupportedRoot(N) &&
          "Shouldn't have called this method");
+  SmallVector<SDNode *> Worklist;
+  SmallSet<SDNode *, 8> Inserted;
+  Worklist.push_back(N);
+  Inserted.insert(N);
+  SmallVector<CombineResult> CombinesToApply;
+
+  while (!Worklist.empty()) {
+    SDNode *Root = Worklist.pop_back_val();
+    if (!NodeExtensionHelper::isSupportedRoot(Root))
+      return SDValue();
 
-  NodeExtensionHelper LHS(N, 0, DAG);
-  NodeExtensionHelper RHS(N, 1, DAG);
-
-  if (LHS.needToPromoteOtherUsers() && !LHS.OrigOperand.hasOneUse())
-    return SDValue();
-
-  if (RHS.needToPromoteOtherUsers() && !RHS.OrigOperand.hasOneUse())
-    return SDValue();
+    NodeExtensionHelper LHS(N, 0, DAG);
+    NodeExtensionHelper RHS(N, 1, DAG);
+    auto AppendUsersIfNeeded = [&Worklist,
+                                &Inserted](const NodeExtensionHelper &Op) {
+      if (Op.needToPromoteOtherUsers()) {
+        for (SDNode *TheUse : Op.OrigOperand->uses()) {
+          if (Inserted.insert(TheUse).second)
+            Worklist.push_back(TheUse);
+        }
+      }
+    };
+    AppendUsersIfNeeded(LHS);
+    AppendUsersIfNeeded(RHS);
 
-  SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
-      NodeExtensionHelper::getSupportedFoldings(N);
+    // Control the compile time by limiting the number of node we look at in
+    // total.
+    if (Inserted.size() > ExtensionMaxWebSize)
+      return SDValue();
 
-  assert(!FoldingStrategies.empty() && "Nothing to be folded");
-  for (int Attempt = 0; Attempt != 1 + NodeExtensionHelper::isCommutative(N);
-       ++Attempt) {
-    for (NodeExtensionHelper::CombineToTry FoldingStrategy :
-         FoldingStrategies) {
-      Optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
-      if (Res)
-        return Res->materialize(DAG);
+    SmallVector<NodeExtensionHelper::CombineToTry> FoldingStrategies =
+        NodeExtensionHelper::getSupportedFoldings(N);
+
+    assert(!FoldingStrategies.empty() && "Nothing to be folded");
+    bool Matched = false;
+    for (int Attempt = 0;
+         (Attempt != 1 + NodeExtensionHelper::isCommutative(N)) && !Matched;
+         ++Attempt) {
+
+      for (NodeExtensionHelper::CombineToTry FoldingStrategy :
+           FoldingStrategies) {
+        Optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
+        if (Res) {
+          Matched = true;
+          CombinesToApply.push_back(*Res);
+          break;
+        }
+      }
+      std::swap(LHS, RHS);
     }
-    std::swap(LHS, RHS);
+    // Right now we do an all or nothing approach.
+    if (!Matched)
+      return SDValue();
   }
-  return SDValue();
+  // Store the value for the replacement of the input node separately.
+  SDValue InputRootReplacement;
+  // We do the RAUW after we materialize all the combines, because some replaced
+  // nodes may be feeding some of the yet-to-be-replaced nodes. Put differently,
+  // some of these nodes may appear in the NodeExtensionHelpers of some of the
+  // yet-to-be-visited CombinesToApply roots.
+  SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
+  ValuesToReplace.reserve(CombinesToApply.size());
+  for (CombineResult Res : CombinesToApply) {
+    SDValue NewValue = Res.materialize(DAG);
+    if (!InputRootReplacement) {
+      assert(Res.Root == N &&
+             "First element is expected to be the current node");
+      InputRootReplacement = NewValue;
+    } else {
+      ValuesToReplace.emplace_back(SDValue(Res.Root, 0), NewValue);
+    }
+  }
+  for (std::pair<SDValue, SDValue> OldNewValues : ValuesToReplace) {
+    DAG.ReplaceAllUsesOfValueWith(OldNewValues.first, OldNewValues.second);
+    DCI.AddToWorklist(OldNewValues.second.getNode());
+  }
+  return InputRootReplacement;
 }
 
 // Fold
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vw-web-simplification.ll
new file mode 100644 (file)
index 0000000..4fdf737
--- /dev/null
@@ -0,0 +1,60 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=2 | FileCheck %s --check-prefixes=NO_FOLDING
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING
+; Check that the default value enables the web folding and
+; that it is bigger than 3.
+; RUN: llc -mtriple=riscv32 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+; RUN: llc -mtriple=riscv64 -mattr=+v -riscv-v-vector-bits-min=128 -verify-machineinstrs %s -o - | FileCheck %s --check-prefixes=FOLDING
+
+
+; Check that the add/sub/mul operations are all promoted into their
+; vw counterpart when the folding of the web size is increased to 3.
+; We need the web size to be at least 3 for the folding to happen, because
+; %c has 3 uses.
+define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) {
+; NO_FOLDING-LABEL: vwmul_v2i16_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e16, mf4, ta, mu
+; NO_FOLDING-NEXT:    vle8.v v8, (a0)
+; NO_FOLDING-NEXT:    vle8.v v9, (a1)
+; NO_FOLDING-NEXT:    vle8.v v10, (a2)
+; NO_FOLDING-NEXT:    vsext.vf2 v11, v8
+; NO_FOLDING-NEXT:    vsext.vf2 v8, v9
+; NO_FOLDING-NEXT:    vsext.vf2 v9, v10
+; NO_FOLDING-NEXT:    vmul.vv v8, v11, v8
+; NO_FOLDING-NEXT:    vadd.vv v10, v11, v9
+; NO_FOLDING-NEXT:    vsub.vv v9, v11, v9
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v10
+; NO_FOLDING-NEXT:    vor.vv v8, v8, v9
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vwmul_v2i16_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e8, mf8, ta, mu
+; FOLDING-NEXT:    vle8.v v8, (a0)
+; FOLDING-NEXT:    vle8.v v9, (a1)
+; FOLDING-NEXT:    vle8.v v10, (a2)
+; FOLDING-NEXT:    vwmul.vv v11, v8, v9
+; FOLDING-NEXT:    vwadd.vv v9, v8, v10
+; FOLDING-NEXT:    vwsub.vv v12, v8, v10
+; FOLDING-NEXT:    vsetvli zero, zero, e16, mf4, ta, mu
+; FOLDING-NEXT:    vor.vv v8, v11, v9
+; FOLDING-NEXT:    vor.vv v8, v8, v12
+; FOLDING-NEXT:    ret
+  %a = load <2 x i8>, <2 x i8>* %x
+  %b = load <2 x i8>, <2 x i8>* %y
+  %b2 = load <2 x i8>, <2 x i8>* %z
+  %c = sext <2 x i8> %a to <2 x i16>
+  %d = sext <2 x i8> %b to <2 x i16>
+  %d2 = sext <2 x i8> %b2 to <2 x i16>
+  %e = mul <2 x i16> %c, %d
+  %f = add <2 x i16> %c, %d2
+  %g = sub <2 x i16> %c, %d2
+  %h = or <2 x i16> %e, %f
+  %i = or <2 x i16> %h, %g
+  ret <2 x i16> %i
+}
index 8862e33..0026164 100644 (file)
@@ -21,16 +21,14 @@ define <2 x i16> @vwmul_v2i16(<2 x i8>* %x, <2 x i8>* %y) {
 define <2 x i16> @vwmul_v2i16_multiple_users(<2 x i8>* %x, <2 x i8>* %y, <2 x i8> *%z) {
 ; CHECK-LABEL: vwmul_v2i16_multiple_users:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, mu
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, mu
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
 ; CHECK-NEXT:    vle8.v v10, (a2)
-; CHECK-NEXT:    vsext.vf2 v11, v8
-; CHECK-NEXT:    vsext.vf2 v8, v9
-; CHECK-NEXT:    vsext.vf2 v9, v10
-; CHECK-NEXT:    vmul.vv v8, v11, v8
-; CHECK-NEXT:    vmul.vv v9, v11, v9
-; CHECK-NEXT:    vor.vv v8, v8, v9
+; CHECK-NEXT:    vwmul.vv v11, v8, v9
+; CHECK-NEXT:    vwmul.vv v9, v8, v10
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf4, ta, mu
+; CHECK-NEXT:    vor.vv v8, v11, v9
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, <2 x i8>* %x
   %b = load <2 x i8>, <2 x i8>* %y