[CodeGen] Enable processing of interconnected complex number operations
authorIgor Kirillov <igor.kirillov@arm.com>
Mon, 27 Mar 2023 16:32:40 +0000 (16:32 +0000)
committerIgor Kirillov <igor.kirillov@arm.com>
Tue, 18 Apr 2023 13:05:49 +0000 (13:05 +0000)
With this patch, ComplexDeinterleavingPass now has the ability to handle
any number of interconnected operations involving complex numbers.
For example, the patch enables the processing of code like the following:

for (int i = 0; i < 1000; ++i) {
    a[i] =  w[i] * v[i];
    b[i] =  w[i] * u[i];
}

This code has multiple arrays containing complex numbers and a common
subexpression `w` that appears in two expressions.

Differential Revision: https://reviews.llvm.org/D146988

llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll

index ff0c5d5..3cfe935 100644 (file)
@@ -137,19 +137,12 @@ public:
   Instruction *Real;
   Instruction *Imag;
 
-  // Instructions that should only exist within this node, there should be no
-  // users of these instructions outside the node. An example of these would be
-  // the multiply instructions of a partial multiply operation.
-  SmallVector<Instruction *> InternalInstructions;
   ComplexDeinterleavingRotation Rotation;
   SmallVector<RawNodePtr> Operands;
   Value *ReplacementNode = nullptr;
 
-  void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
 
-  bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);
-
   void dump() { dump(dbgs()); }
   void dump(raw_ostream &OS) {
     auto PrintValue = [&](Value *V) {
@@ -181,12 +174,6 @@ public:
       OS << "    - ";
       PrintNodeRef(Op);
     }
-    OS << "  InternalInstructions:\n";
-    for (const auto &I : InternalInstructions) {
-      OS << "    - \"";
-      I->print(OS, true);
-      OS << "\"\n";
-    }
   }
 };
 
@@ -194,14 +181,22 @@ class ComplexDeinterleavingGraph {
 public:
   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
-  explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
+  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
+                                      const TargetLibraryInfo *TLI)
+      : TL(TL), TLI(TLI) {}
 
 private:
   const TargetLowering *TL = nullptr;
-  Instruction *RootValue = nullptr;
-  NodePtr RootNode;
+  const TargetLibraryInfo *TLI = nullptr;
   SmallVector<NodePtr> CompositeNodes;
-  SmallPtrSet<Instruction *, 16> AllInstructions;
+
+  SmallPtrSet<Instruction *, 16> FinalInstructions;
+
+  /// Root instructions are instructions from which complex computation starts
+  std::map<Instruction *, NodePtr> RootToNode;
+
+  /// Topologically sorted root instructions
+  SmallVector<Instruction *, 1> OrderedRoots;
 
   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                                Instruction *R, Instruction *I) {
@@ -211,10 +206,6 @@ private:
 
   NodePtr submitCompositeNode(NodePtr Node) {
     CompositeNodes.push_back(Node);
-    AllInstructions.insert(Node->Real);
-    AllInstructions.insert(Node->Imag);
-    for (auto *I : Node->InternalInstructions)
-      AllInstructions.insert(I);
     return Node;
   }
 
@@ -271,6 +262,10 @@ public:
   /// current graph.
   bool identifyNodes(Instruction *RootI);
 
+  /// Check that every instruction, from the roots to the leaves, has internal
+  /// uses.
+  bool checkNodes();
+
   /// Perform the actual replacement of the underlying instruction graph.
   void replaceNodes();
 };
@@ -368,9 +363,7 @@ static bool isDeinterleavingMask(ArrayRef<int> Mask) {
 }
 
 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
-  bool Changed = false;
-
-  SmallVector<Instruction *> DeadInstrRoots;
+  ComplexDeinterleavingGraph Graph(TL, TLI);
 
   for (auto &I : *B) {
     auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
@@ -382,22 +375,15 @@ bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
     if (!isInterleavingMask(SVI->getShuffleMask()))
       continue;
 
-    ComplexDeinterleavingGraph Graph(TL);
-    if (!Graph.identifyNodes(SVI))
-      continue;
-
-    Graph.replaceNodes();
-    DeadInstrRoots.push_back(SVI);
-    Changed = true;
+    Graph.identifyNodes(SVI);
   }
 
-  for (const auto &I : DeadInstrRoots) {
-    if (!I || I->getParent() == nullptr)
-      continue;
-    llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
+  if (Graph.checkNodes()) {
+    Graph.replaceNodes();
+    return true;
   }
 
-  return Changed;
+  return false;
 }
 
 ComplexDeinterleavingGraph::NodePtr
@@ -511,7 +497,6 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
   Node->Rotation = Rotation;
   Node->addOperand(CommonNode);
   Node->addOperand(UncommonNode);
-  Node->InternalInstructions.append(FNegs);
   return submitCompositeNode(Node);
 }
 
@@ -627,8 +612,6 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
 
   NodePtr Node = prepareCompositeNode(
       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
-  Node->addInstruction(RealMulI);
-  Node->addInstruction(ImagMulI);
   Node->Rotation = Rotation;
   Node->addOperand(CommonRes);
   Node->addOperand(UncommonRes);
@@ -846,6 +829,8 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
         prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle,
                              RealShuffle, ImagShuffle);
     PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
+    FinalInstructions.insert(RealShuffle);
+    FinalInstructions.insert(ImagShuffle);
     return submitCompositeNode(PlaceholderNode);
   }
   if (RealShuffle || ImagShuffle) {
@@ -881,9 +866,7 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
     return false;
 
-  RootValue = RootI;
-  AllInstructions.insert(RootI);
-  RootNode = identifyNode(Real, Imag);
+  auto RootNode = identifyNode(Real, Imag);
 
   LLVM_DEBUG({
     Function *F = RootI->getFunction();
@@ -894,14 +877,86 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
     dbgs() << "\n";
   });
 
-  // Check all instructions have internal uses
-  for (const auto &Node : CompositeNodes) {
-    if (!Node->hasAllInternalUses(AllInstructions)) {
-      LLVM_DEBUG(dbgs() << "  - Invalid internal uses\n");
-      return false;
+  if (RootNode) {
+    RootToNode[RootI] = RootNode;
+    OrderedRoots.push_back(RootI);
+    return true;
+  }
+
+  return false;
+}
+
+bool ComplexDeinterleavingGraph::checkNodes() {
+  // Collect all instructions from roots to leaves
+  SmallPtrSet<Instruction *, 16> AllInstructions;
+  SmallVector<Instruction *, 8> Worklist;
+  for (auto *I : OrderedRoots)
+    Worklist.push_back(I);
+
+  // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
+  // chains
+  while (!Worklist.empty()) {
+    auto *I = Worklist.back();
+    Worklist.pop_back();
+
+    if (!AllInstructions.insert(I).second)
+      continue;
+
+    for (Value *Op : I->operands()) {
+      if (auto *OpI = dyn_cast<Instruction>(Op)) {
+        if (!FinalInstructions.count(I))
+          Worklist.emplace_back(OpI);
+      }
     }
   }
-  return RootNode != nullptr;
+
+  // Find instructions that have users outside of chain
+  SmallVector<Instruction *, 2> OuterInstructions;
+  for (auto *I : AllInstructions) {
+    // Skip root nodes
+    if (RootToNode.count(I))
+      continue;
+
+    for (User *U : I->users()) {
+      if (AllInstructions.count(cast<Instruction>(U)))
+        continue;
+
+      // Found an instruction that is not used by XCMLA/XCADD chain
+      Worklist.emplace_back(I);
+      break;
+    }
+  }
+
+  // If any instructions are found to be used outside, find and remove roots
+  // that somehow connect to those instructions.
+  SmallPtrSet<Instruction *, 16> Visited;
+  while (!Worklist.empty()) {
+    auto *I = Worklist.back();
+    Worklist.pop_back();
+    if (!Visited.insert(I).second)
+      continue;
+
+    // Found an impacted root node. Removing it from the nodes to be
+    // deinterleaved
+    if (RootToNode.count(I)) {
+      LLVM_DEBUG(dbgs() << "Instruction " << *I
+                        << " could be deinterleaved but its chain of complex "
+                           "operations have an outside user\n");
+      RootToNode.erase(I);
+    }
+
+    if (!AllInstructions.count(I) || FinalInstructions.count(I))
+      continue;
+
+    for (User *U : I->users())
+      Worklist.emplace_back(cast<Instruction>(U));
+
+    for (Value *Op : I->operands()) {
+      if (auto *OpI = dyn_cast<Instruction>(Op))
+        Worklist.emplace_back(OpI);
+    }
+  }
+  return !RootToNode.empty();
 }
 
 static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node,
@@ -958,29 +1013,21 @@ Value *ComplexDeinterleavingGraph::replaceNode(
 }
 
 void ComplexDeinterleavingGraph::replaceNodes() {
-  Value *R = replaceNode(RootNode.get());
-  assert(R && "Unable to find replacement for RootValue");
-  RootValue->replaceAllUsesWith(R);
-}
-
-bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
-    SmallPtrSet<Instruction *, 16> &AllInstructions) {
-  if (Operation == ComplexDeinterleavingOperation::Shuffle)
-    return true;
+  SmallVector<Instruction *, 16> DeadInstrRoots;
+  for (auto *RootInstruction : OrderedRoots) {
+    // Check if this potential root went through check process and we can
+    // deinterleave it
+    if (!RootToNode.count(RootInstruction))
+      continue;
 
-  for (auto *User : Real->users()) {
-    if (!AllInstructions.contains(cast<Instruction>(User)))
-      return false;
+    IRBuilder<> Builder(RootInstruction);
+    auto RootNode = RootToNode[RootInstruction];
+    Value *R = replaceNode(RootNode.get());
+    assert(R && "Unable to find replacement for RootInstruction");
+    DeadInstrRoots.push_back(RootInstruction);
+    RootInstruction->replaceAllUsesWith(R);
   }
-  for (auto *User : Imag->users()) {
-    if (!AllInstructions.contains(cast<Instruction>(User)))
-      return false;
-  }
-  for (auto *I : InternalInstructions) {
-    for (auto *User : I->users()) {
-      if (!AllInstructions.contains(cast<Instruction>(User)))
-        return false;
-    }
-  }
-  return true;
+
+  for (auto *I : DeadInstrRoots)
+    RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
 }
index fe3d306..4d84636 100644 (file)
@@ -2,30 +2,20 @@
 ; RUN: llc < %s --mattr=+complxnum,+neon -o - | FileCheck %s
 
 target triple = "aarch64-arm-none-eabi"
-; Expected to not transform
+; Expected to transform
 ;   *p = (a * b);
 ;   return (a * b) * a;
 define <4 x float> @mul_triangle(<4 x float> %a, <4 x float> %b, ptr %p) {
 ; CHECK-LABEL: mul_triangle:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ext v2.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    ext v3.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT:    zip2 v4.2s, v0.2s, v2.2s
-; CHECK-NEXT:    zip1 v0.2s, v0.2s, v2.2s
-; CHECK-NEXT:    zip2 v5.2s, v1.2s, v3.2s
-; CHECK-NEXT:    zip1 v1.2s, v1.2s, v3.2s
-; CHECK-NEXT:    fmul v6.2s, v5.2s, v4.2s
-; CHECK-NEXT:    fneg v2.2s, v6.2s
-; CHECK-NEXT:    fmla v2.2s, v0.2s, v1.2s
-; CHECK-NEXT:    fmul v3.2s, v4.2s, v1.2s
-; CHECK-NEXT:    fmla v3.2s, v0.2s, v5.2s
-; CHECK-NEXT:    fmul v1.2s, v3.2s, v4.2s
-; CHECK-NEXT:    fmul v5.2s, v3.2s, v0.2s
-; CHECK-NEXT:    st2 { v2.2s, v3.2s }, [x0]
-; CHECK-NEXT:    fneg v1.2s, v1.2s
-; CHECK-NEXT:    fmla v5.2s, v4.2s, v2.2s
-; CHECK-NEXT:    fmla v1.2s, v0.2s, v2.2s
-; CHECK-NEXT:    zip1 v0.4s, v1.4s, v5.4s
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-NEXT:    fcmla v3.4s, v1.4s, v0.4s, #0
+; CHECK-NEXT:    fcmla v3.4s, v1.4s, v0.4s, #90
+; CHECK-NEXT:    fcmla v2.4s, v0.4s, v3.4s, #0
+; CHECK-NEXT:    str q3, [x0]
+; CHECK-NEXT:    fcmla v2.4s, v0.4s, v3.4s, #90
+; CHECK-NEXT:    mov v0.16b, v2.16b
 ; CHECK-NEXT:    ret
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>