++I, ++Pos)
MI = &*I;
}
+ MIRef(MachineInstr *MI)
+ : MI(MI), MBB(MI->getParent()),
+ Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
: MI(MI), MBB(MBB),
Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
bool operator==(const MIRef &RHS) const {
return MI == RHS.MI && MBB == RHS.MBB;
}
+ bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
bool operator<(const MIRef &RHS) const {
return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
}
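(Aside for reviewers: the two members compared above give MIRef a total order, owning block first and then the cached position inside the block, which is what lets the pass keep per-block shape lists sorted with lower_bound later in the patch. A minimal standalone analogue of that ordering, using toy names SimpleRef/Block rather than the LLVM types:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct Block {}; // stands in for MachineBasicBlock

struct SimpleRef {
  const Block *BB = nullptr;
  size_t Pos = 0; // distance from the block's first instruction
  bool operator==(const SimpleRef &RHS) const {
    return BB == RHS.BB && Pos == RHS.Pos;
  }
  bool operator<(const SimpleRef &RHS) const {
    // Compare the owning block first, then the position inside the block.
    return BB < RHS.BB || (BB == RHS.BB && Pos < RHS.Pos);
  }
};

int main() {
  Block B;
  std::vector<SimpleRef> Refs = {{&B, 3}, {&B, 1}, {&B, 2}};
  std::sort(Refs.begin(), Refs.end());
  // Refs from the same block come out in program order.
  assert(Refs.front().Pos == 1 && Refs.back().Pos == 3);
  return 0;
}

End of aside.)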
struct BBInfo {
MIRef FirstAMX;
MIRef LastCall;
- MIRef LastShape;
+ bool HasAMXRegLiveIn = false;
bool TileCfgForbidden = false;
bool NeedTileCfgLiveIn = false;
};
MachineRegisterInfo *MRI;
const MachineLoopInfo *MLI;
SmallSet<MachineInstr *, 8> DefVisited;
- SmallSet<MachineBasicBlock *, 8> ShapeBBs;
DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
+ DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
/// Check if the callee will clobber AMX registers.
bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
/// Collect the shape def information for later use.
void collectShapeInfo(MachineInstr &MI);
+ /// Try to hoist shapes defined below AMX instructions.
+ bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
+ MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
+ auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
+ auto InsertPoint = FirstAMX.MI->getIterator();
+ for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
+ // Do not hoist instructions that access memory.
+ if (I->MI->mayLoadOrStore())
+ return false;
+ for (auto &MO : I->MI->operands()) {
+ if (MO.isDef())
+ continue;
+ // Do not hoist the instruction if any of its source operands is defined
+ // below the first AMX instruction.
+ // TODO: We can handle isMoveImmediate MI here.
+ if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
+ return false;
+ // TODO: Maybe need more checks here.
+ }
+ MBB->insert(InsertPoint, I->MI->removeFromParent());
+ }
+ // We only need to mark the last shape in the BB now.
+ Shapes.clear();
+ Shapes.push_back(MIRef(&*--InsertPoint, MBB));
+ return true;
+ }
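(Aside for reviewers: the bail-outs above are the heart of the hoist: give up if a shape-producing instruction may touch memory, or if it reads a value that is itself defined at or below the first AMX instruction. A self-contained sketch of that check with toy types; ToyInst, canHoistAll and the position fields are illustrative, not the LLVM API:

#include <cstddef>
#include <vector>

struct ToyInst {
  bool MayLoadOrStore = false;   // analogue of MI->mayLoadOrStore()
  std::vector<size_t> SrcDefPos; // positions where its sources are defined
};

// Returns true if every instruction listed in ShapePos could be hoisted to
// HoistPos (the position of the first AMX instruction in the block).
bool canHoistAll(const std::vector<ToyInst> &Block,
                 const std::vector<size_t> &ShapePos, size_t HoistPos) {
  for (size_t P : ShapePos) {
    const ToyInst &I = Block[P];
    if (I.MayLoadOrStore)
      return false; // memory-accessing shape defs are never hoisted
    for (size_t Def : I.SrcDefPos)
      if (Def >= HoistPos)
        return false; // a source would still be undefined at the hoist point
  }
  return true;
}

End of aside.)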
+
public:
X86PreTileConfig() : MachineFunctionPass(ID) {}
void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
MIRef MIR(MI, MBB);
- if (BBVisitedInfo[MBB].LastShape < MIR)
- BBVisitedInfo[MBB].LastShape = MIR;
- ShapeBBs.insert(MBB);
+ auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
+ if (I == ShapeBBs[MBB].end() || *I != MIR)
+ ShapeBBs[MBB].insert(I, MIR);
};
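(Aside for reviewers: the RecordShape lambda keeps each per-block vector in ShapeBBs sorted and duplicate-free, which is what allows lower_bound and .back() to be used on it later. The same pattern with plain integers; insertSorted is an illustrative helper, not part of the patch:

#include <algorithm>
#include <vector>

// Insert Value into an already-sorted vector, skipping duplicates.
void insertSorted(std::vector<int> &Sorted, int Value) {
  auto I = std::lower_bound(Sorted.begin(), Sorted.end(), Value);
  if (I == Sorted.end() || *I != Value)
    Sorted.insert(I, Value);
}

End of aside.)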
SmallVector<Register, 8> WorkList(
else
CfgLiveInBBs.push_back(&MBB);
}
+ if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
+ for (auto *Succ : MBB.successors())
+ if (!isLoopBackEdge(Succ, &MBB))
+ BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
}
// Update NeedTileCfgLiveIn for predecessors.
return false;
// Avoid inserting ldtilecfg before any shape defs.
- SmallVector<MachineBasicBlock *, 8> WorkList(
- make_range(ShapeBBs.begin(), ShapeBBs.end()));
+ SmallVector<MachineBasicBlock *, 8> WorkList;
+ for (auto &I : ShapeBBs) {
+ // TODO: We can hoist shapes across BBs here.
+ if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
+ REPORT_CONFIG_FAIL
+ if (BBVisitedInfo[I.first].FirstAMX &&
+ BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
+ !hoistShapesInBB(I.first, I.second))
+ REPORT_CONFIG_FAIL
+ WorkList.push_back(I.first);
+ }
while (!WorkList.empty()) {
MachineBasicBlock *MBB = WorkList.pop_back_val();
for (auto *Pred : MBB->predecessors()) {
} else {
// Avoid visiting the BB more than once.
VisitedOrInserted.insert(I);
- // We cannot sink it across any AMX instruction.
- if (BBVisitedInfo[I.MBB].FirstAMX)
- REPORT_CONFIG_FAIL;
// Sink the insert point along the chain of successors with
// NeedTileCfgLiveIn = true when not all shape defs reach MBB.
for (auto *Succ : I.MBB->successors())
// A given insert point might be forked when the shape conditions are not met.
for (MIRef I : InsertPoints) {
- // Even MBB is all shapes reachable, we still need to check if there's
- // AMX that intersects with shapes in the same MBB.
- if (BBVisitedInfo[I.MBB].FirstAMX &&
- BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
- REPORT_CONFIG_FAIL;
// Make sure we insert ldtilecfg after the last shape def in MBB.
- if (I < BBVisitedInfo[I.MBB].LastShape)
- I = BBVisitedInfo[I.MBB].LastShape;
+ if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
+ I = ShapeBBs[I.MBB].back();
// The insert point may be sunk into the same MBB more than once. Record it
// to avoid inserting multiple times.
if (VisitedOrInserted.insert(I).second) {
define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b) nounwind {
; Just to make sure shape def is not scheduled across ldtilecfg.
+; CHECK-LABEL: test_shape_sched:
; CHECK: ldtilecfg
; CHECK-NOT: movw
%c1 = bitcast <256 x i32> %c to x86_amx
ret <256 x i32> %res
}
+define <256 x i32> @test_shape_sched2(i16 %m, i16 %n, i16 %k, i8* %c, i8* %a, i8* %b) nounwind {
+; Just to make sure shape def is not scheduled across ldtilecfg.
+; CHECK-LABEL: test_shape_sched2:
+; CHECK: ldtilecfg
+; CHECK-NOT: movw
+ %aa = lshr i16 %k, 2
+ %c1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %c, i64 64)
+ %a1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %k, i8* %a, i64 64)
+ %b1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %aa, i16 %n, i8* %b, i64 64)
+ %t = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k, x86_amx %c1, x86_amx %a1, x86_amx %b1)
+ %res = bitcast x86_amx %t to <256 x i32>
+ ret <256 x i32> %res
+}
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)