Instruction *Real;
Instruction *Imag;
- // Instructions that should only exist within this node, there should be no
- // users of these instructions outside the node. An example of these would be
- // the multiply instructions of a partial multiply operation.
- SmallVector<Instruction *> InternalInstructions;
ComplexDeinterleavingRotation Rotation;
SmallVector<RawNodePtr> Operands;
Value *ReplacementNode = nullptr;
- void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
- bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);
-
void dump() { dump(dbgs()); }
void dump(raw_ostream &OS) {
auto PrintValue = [&](Value *V) {
OS << " - ";
PrintNodeRef(Op);
}
- OS << " InternalInstructions:\n";
- for (const auto &I : InternalInstructions) {
- OS << " - \"";
- I->print(OS, true);
- OS << "\"\n";
- }
}
};
public:
using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
- explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
+ explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
+ const TargetLibraryInfo *TLI)
+ : TL(TL), TLI(TLI) {}
private:
const TargetLowering *TL = nullptr;
- Instruction *RootValue = nullptr;
- NodePtr RootNode;
+ const TargetLibraryInfo *TLI = nullptr;
SmallVector<NodePtr> CompositeNodes;
- SmallPtrSet<Instruction *, 16> AllInstructions;
+
+ SmallPtrSet<Instruction *, 16> FinalInstructions;
+
+ /// Root instructions are instructions from which complex computation starts
+ std::map<Instruction *, NodePtr> RootToNode;
+
+ /// Topologically sorted root instructions
+ SmallVector<Instruction *, 1> OrderedRoots;
NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
Instruction *R, Instruction *I) {
NodePtr submitCompositeNode(NodePtr Node) {
CompositeNodes.push_back(Node);
- AllInstructions.insert(Node->Real);
- AllInstructions.insert(Node->Imag);
- for (auto *I : Node->InternalInstructions)
- AllInstructions.insert(I);
return Node;
}
/// current graph.
bool identifyNodes(Instruction *RootI);
+ /// Check that every instruction, from the roots to the leaves, has internal
+ /// uses.
+ bool checkNodes();
+
/// Perform the actual replacement of the underlying instruction graph.
void replaceNodes();
};
}
bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
- bool Changed = false;
-
- SmallVector<Instruction *> DeadInstrRoots;
+ ComplexDeinterleavingGraph Graph(TL, TLI);
for (auto &I : *B) {
auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
if (!isInterleavingMask(SVI->getShuffleMask()))
continue;
- ComplexDeinterleavingGraph Graph(TL);
- if (!Graph.identifyNodes(SVI))
- continue;
-
- Graph.replaceNodes();
- DeadInstrRoots.push_back(SVI);
- Changed = true;
+ Graph.identifyNodes(SVI);
}
- for (const auto &I : DeadInstrRoots) {
- if (!I || I->getParent() == nullptr)
- continue;
- llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
+ if (Graph.checkNodes()) {
+ Graph.replaceNodes();
+ return true;
}
- return Changed;
+ return false;
}
ComplexDeinterleavingGraph::NodePtr
Node->Rotation = Rotation;
Node->addOperand(CommonNode);
Node->addOperand(UncommonNode);
- Node->InternalInstructions.append(FNegs);
return submitCompositeNode(Node);
}
NodePtr Node = prepareCompositeNode(
ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
- Node->addInstruction(RealMulI);
- Node->addInstruction(ImagMulI);
Node->Rotation = Rotation;
Node->addOperand(CommonRes);
Node->addOperand(UncommonRes);
prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle,
RealShuffle, ImagShuffle);
PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
+ FinalInstructions.insert(RealShuffle);
+ FinalInstructions.insert(ImagShuffle);
return submitCompositeNode(PlaceholderNode);
}
if (RealShuffle || ImagShuffle) {
if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
return false;
- RootValue = RootI;
- AllInstructions.insert(RootI);
- RootNode = identifyNode(Real, Imag);
+ auto RootNode = identifyNode(Real, Imag);
LLVM_DEBUG({
Function *F = RootI->getFunction();
dbgs() << "\n";
});
- // Check all instructions have internal uses
- for (const auto &Node : CompositeNodes) {
- if (!Node->hasAllInternalUses(AllInstructions)) {
- LLVM_DEBUG(dbgs() << " - Invalid internal uses\n");
- return false;
+ if (RootNode) {
+ RootToNode[RootI] = RootNode;
+ OrderedRoots.push_back(RootI);
+ return true;
+ }
+
+ return false;
+}
+
+bool ComplexDeinterleavingGraph::checkNodes() {
+ // Collect all instructions from roots to leaves
+ SmallPtrSet<Instruction *, 16> AllInstructions;
+ SmallVector<Instruction *, 8> Worklist;
+ for (auto *I : OrderedRoots)
+ Worklist.push_back(I);
+
+ // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
+ // chains
+ while (!Worklist.empty()) {
+ auto *I = Worklist.back();
+ Worklist.pop_back();
+
+ if (!AllInstructions.insert(I).second)
+ continue;
+
+ for (Value *Op : I->operands()) {
+ if (auto *OpI = dyn_cast<Instruction>(Op)) {
+ if (!FinalInstructions.count(I))
+ Worklist.emplace_back(OpI);
+ }
}
}
- return RootNode != nullptr;
+
+ // Find instructions that have users outside of chain
+ SmallVector<Instruction *, 2> OuterInstructions;
+ for (auto *I : AllInstructions) {
+ // Skip root nodes
+ if (RootToNode.count(I))
+ continue;
+
+ for (User *U : I->users()) {
+ if (AllInstructions.count(cast<Instruction>(U)))
+ continue;
+
+ // Found an instruction that is not used by XCMLA/XCADD chain
+ Worklist.emplace_back(I);
+ break;
+ }
+ }
+
+ // If any instructions are found to be used outside, find and remove roots
+ // that somehow connect to those instructions.
+ SmallPtrSet<Instruction *, 16> Visited;
+ while (!Worklist.empty()) {
+ auto *I = Worklist.back();
+ Worklist.pop_back();
+ if (!Visited.insert(I).second)
+ continue;
+
+ // Found an impacted root node. Removing it from the nodes to be
+ // deinterleaved
+ if (RootToNode.count(I)) {
+ LLVM_DEBUG(dbgs() << "Instruction " << *I
+ << " could be deinterleaved but its chain of complex "
+ "operations have an outside user\n");
+ RootToNode.erase(I);
+ }
+
+ if (!AllInstructions.count(I) || FinalInstructions.count(I))
+ continue;
+
+ for (User *U : I->users())
+ Worklist.emplace_back(cast<Instruction>(U));
+
+ for (Value *Op : I->operands()) {
+ if (auto *OpI = dyn_cast<Instruction>(Op))
+ Worklist.emplace_back(OpI);
+ }
+ }
+ return !RootToNode.empty();
}
static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node,
}
void ComplexDeinterleavingGraph::replaceNodes() {
- Value *R = replaceNode(RootNode.get());
- assert(R && "Unable to find replacement for RootValue");
- RootValue->replaceAllUsesWith(R);
-}
-
-bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
- SmallPtrSet<Instruction *, 16> &AllInstructions) {
- if (Operation == ComplexDeinterleavingOperation::Shuffle)
- return true;
+ SmallVector<Instruction *, 16> DeadInstrRoots;
+ for (auto *RootInstruction : OrderedRoots) {
+ // Check if this potential root went through check process and we can
+ // deinterleave it
+ if (!RootToNode.count(RootInstruction))
+ continue;
- for (auto *User : Real->users()) {
- if (!AllInstructions.contains(cast<Instruction>(User)))
- return false;
+ IRBuilder<> Builder(RootInstruction);
+ auto RootNode = RootToNode[RootInstruction];
+ Value *R = replaceNode(RootNode.get());
+ assert(R && "Unable to find replacement for RootInstruction");
+ DeadInstrRoots.push_back(RootInstruction);
+ RootInstruction->replaceAllUsesWith(R);
}
- for (auto *User : Imag->users()) {
- if (!AllInstructions.contains(cast<Instruction>(User)))
- return false;
- }
- for (auto *I : InternalInstructions) {
- for (auto *User : I->users()) {
- if (!AllInstructions.contains(cast<Instruction>(User)))
- return false;
- }
- }
- return true;
+
+ for (auto *I : DeadInstrRoots)
+ RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
}
; RUN: llc < %s --mattr=+complxnum,+neon -o - | FileCheck %s
target triple = "aarch64-arm-none-eabi"
-; Expected to not transform
+; Expected to transform
; *p = (a * b);
; return (a * b) * a;
define <4 x float> @mul_triangle(<4 x float> %a, <4 x float> %b, ptr %p) {
; CHECK-LABEL: mul_triangle:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: ext v2.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT: ext v3.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT: zip2 v4.2s, v0.2s, v2.2s
-; CHECK-NEXT: zip1 v0.2s, v0.2s, v2.2s
-; CHECK-NEXT: zip2 v5.2s, v1.2s, v3.2s
-; CHECK-NEXT: zip1 v1.2s, v1.2s, v3.2s
-; CHECK-NEXT: fmul v6.2s, v5.2s, v4.2s
-; CHECK-NEXT: fneg v2.2s, v6.2s
-; CHECK-NEXT: fmla v2.2s, v0.2s, v1.2s
-; CHECK-NEXT: fmul v3.2s, v4.2s, v1.2s
-; CHECK-NEXT: fmla v3.2s, v0.2s, v5.2s
-; CHECK-NEXT: fmul v1.2s, v3.2s, v4.2s
-; CHECK-NEXT: fmul v5.2s, v3.2s, v0.2s
-; CHECK-NEXT: st2 { v2.2s, v3.2s }, [x0]
-; CHECK-NEXT: fneg v1.2s, v1.2s
-; CHECK-NEXT: fmla v5.2s, v4.2s, v2.2s
-; CHECK-NEXT: fmla v1.2s, v0.2s, v2.2s
-; CHECK-NEXT: zip1 v0.4s, v1.4s, v5.4s
+; CHECK-NEXT: movi v3.2d, #0000000000000000
+; CHECK-NEXT: movi v2.2d, #0000000000000000
+; CHECK-NEXT: fcmla v3.4s, v1.4s, v0.4s, #0
+; CHECK-NEXT: fcmla v3.4s, v1.4s, v0.4s, #90
+; CHECK-NEXT: fcmla v2.4s, v0.4s, v3.4s, #0
+; CHECK-NEXT: str q3, [x0]
+; CHECK-NEXT: fcmla v2.4s, v0.4s, v3.4s, #90
+; CHECK-NEXT: mov v0.16b, v2.16b
; CHECK-NEXT: ret
entry:
%strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>