From 016092d786f226f403fce5b5d0888dfa939b3f21 Mon Sep 17 00:00:00 2001 From: "Wang, Pengfei" Date: Tue, 27 Apr 2021 09:51:46 +0800 Subject: [PATCH] Reapply "[X86][AMX] Try to hoist AMX shapes' def" We request no intersections between AMX instructions and their shapes' def when we insert ldtilecfg. However, this is not always true, resulting not only from users not following the AMX API model, but also from optimizations. This patch adds a mechanism that tries to hoist AMX shapes' def as well. It only hoists shapes inside a BB; we can improve it for cases across BBs in the future. Currently, it only hoists shapes whose sources are all defined above the first AMX instruction. We can improve it for the case where the only source below the AMX instruction is one that moves an immediate value to a register. Reviewed By: xiangzhangllvm Differential Revision: https://reviews.llvm.org/D101067 --- llvm/lib/Target/X86/X86PreTileConfig.cpp | 69 ++++++++++++++++++++++++-------- llvm/test/CodeGen/X86/AMX/amx-sched.ll | 15 +++++++ 2 files changed, 67 insertions(+), 17 deletions(-) diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp index 4ec9e79..9164dfd 100644 --- a/llvm/lib/Target/X86/X86PreTileConfig.cpp +++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp @@ -57,6 +57,9 @@ struct MIRef { ++I, ++Pos) MI = &*I; } + MIRef(MachineInstr *MI) + : MI(MI), MBB(MI->getParent()), + Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} MIRef(MachineInstr *MI, MachineBasicBlock *MBB) : MI(MI), MBB(MBB), Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {} @@ -66,6 +69,7 @@ struct MIRef { bool operator==(const MIRef &RHS) const { return MI == RHS.MI && MBB == RHS.MBB; } + bool operator!=(const MIRef &RHS) const { return !(*this == RHS); } bool operator<(const MIRef &RHS) const { return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos); } @@ -77,7 +81,7 @@ struct BBInfo { MIRef FirstAMX; MIRef LastCall; - MIRef LastShape; + bool HasAMXRegLiveIn = false; 
bool TileCfgForbidden = false; bool NeedTileCfgLiveIn = false; }; @@ -86,8 +90,8 @@ class X86PreTileConfig : public MachineFunctionPass { MachineRegisterInfo *MRI; const MachineLoopInfo *MLI; SmallSet<MIRef, 8> DefVisited; - SmallSet<MachineBasicBlock *, 8> ShapeBBs; DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo; + DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs; /// Check if the callee will clobber AMX registers. bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) { @@ -124,6 +128,32 @@ class X86PreTileConfig : public MachineFunctionPass { /// Collect the shape def information for later use. void collectShapeInfo(MachineInstr &MI); + /// Try to hoist shapes defined below AMX instructions. + bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) { + MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX; + auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX); + auto InsertPoint = FirstAMX.MI->getIterator(); + for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) { + // Do not hoist instructions that access memory. + if (I->MI->mayLoadOrStore()) + return false; + for (auto &MO : I->MI->operands()) { + if (MO.isDef()) + continue; + // Do not hoist instructions if the sources' def is under the AMX instruction. + // TODO: We can handle isMoveImmediate MI here. + if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX) + return false; + // TODO: Maybe need more checks here. + } + MBB->insert(InsertPoint, I->MI->removeFromParent()); + } + // We only need to mark the last shape in the BB now. 
+ Shapes.clear(); + Shapes.push_back(MIRef(&*--InsertPoint, MBB)); + return true; + } + public: X86PreTileConfig() : MachineFunctionPass(ID) {} @@ -165,9 +195,9 @@ INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig", void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) { auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) { MIRef MIR(MI, MBB); - if (BBVisitedInfo[MBB].LastShape < MIR) - BBVisitedInfo[MBB].LastShape = MIR; - ShapeBBs.insert(MBB); + auto I = llvm::lower_bound(ShapeBBs[MBB], MIR); + if (I == ShapeBBs[MBB].end() || *I != MIR) + ShapeBBs[MBB].insert(I, MIR); }; SmallVector<Register, 8> WorkList( @@ -229,6 +259,10 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { else CfgLiveInBBs.push_back(&MBB); } + if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn) + for (auto *Succ : MBB.successors()) + if (!isLoopBackEdge(Succ, &MBB)) + BBVisitedInfo[Succ].HasAMXRegLiveIn = true; } // Update NeedTileCfgLiveIn for predecessors. @@ -252,8 +286,17 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { return false; // Avoid to insert ldtilecfg before any shape defs. - SmallVector<MachineBasicBlock *, 8> WorkList( - make_range(ShapeBBs.begin(), ShapeBBs.end())); + SmallVector<MachineBasicBlock *, 8> WorkList; + for (auto &I : ShapeBBs) { + // TODO: We can hoist shapes across BBs here. + if (BBVisitedInfo[I.first].HasAMXRegLiveIn) + REPORT_CONFIG_FAIL + if (BBVisitedInfo[I.first].FirstAMX && + BBVisitedInfo[I.first].FirstAMX < I.second.back() && + !hoistShapesInBB(I.first, I.second)) + REPORT_CONFIG_FAIL + WorkList.push_back(I.first); + } while (!WorkList.empty()) { MachineBasicBlock *MBB = WorkList.pop_back_val(); for (auto *Pred : MBB->predecessors()) { @@ -282,9 +325,6 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { } else { // Avoid the BB to be multi visited. VisitedOrInserted.insert(I); - // We cannot sink it across any AMX instruction. 
- if (BBVisitedInfo[I.MBB].FirstAMX) - REPORT_CONFIG_FAIL; // Sink the inserting point along the chain with NeedTileCfgLiveIn = // true when MBB isn't all shapes reachable. for (auto *Succ : I.MBB->successors()) @@ -296,14 +336,9 @@ bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) { // A given point might be forked due to shape conditions are not met. for (MIRef I : InsertPoints) { - // Even MBB is all shapes reachable, we still need to check if there's - // AMX that intersects with shapes in the same MBB. - if (BBVisitedInfo[I.MBB].FirstAMX && - BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape) - REPORT_CONFIG_FAIL; // Make sure we insert ldtilecfg after the last shape def in MBB. - if (I < BBVisitedInfo[I.MBB].LastShape) - I = BBVisitedInfo[I.MBB].LastShape; + if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back()) + I = ShapeBBs[I.MBB].back(); // There're chances the MBB is sunk more than once. Record it to avoid // multi insert. if (VisitedOrInserted.insert(I).second) { diff --git a/llvm/test/CodeGen/X86/AMX/amx-sched.ll b/llvm/test/CodeGen/X86/AMX/amx-sched.ll index 7e704cf..790c6c9 100644 --- a/llvm/test/CodeGen/X86/AMX/amx-sched.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-sched.ll @@ -2,6 +2,7 @@ define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b) nounwind { ; Just to make sure shape def is not scheduled across ldtilecfg. +; CHECK-LABEL: test_shape_sched: ; CHECK: ldtilecfg ; CHECK-NOT: movw %c1 = bitcast <256 x i32> %c to x86_amx @@ -12,5 +13,19 @@ define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <25 ret <256 x i32> %res } +define <256 x i32> @test_shape_sched2(i16 %m, i16 %n, i16 %k, i8* %c, i8* %a, i8* %b) nounwind { +; Just to make sure shape def is not scheduled across ldtilecfg. 
+; CHECK-LABEL: test_shape_sched2: +; CHECK: ldtilecfg +; CHECK-NOT: movw + %aa = lshr i16 %k, 2 + %c1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %c, i64 64) + %a1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %k, i8* %a, i64 64) + %b1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %aa, i16 %n, i8* %b, i64 64) + %t = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k, x86_amx %c1, x86_amx %a1, x86_amx %b1) + %res = bitcast x86_amx %t to <256 x i32> + ret <256 x i32> %res +} +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) -- 2.7.4