Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
IRBuilderBase &B, Value *Row, Value *Col,
Value *Ptr, Value *Stride, Value *Tile);
- Value *createTileDPBSSDLoops(BasicBlock *Start, BasicBlock *End,
- IRBuilderBase &B, Value *Row, Value *Col,
- Value *K, Value *Acc, Value *LHS, Value *RHS);
+ template <Intrinsic::ID IntrID>
+ typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
+ IntrID == Intrinsic::x86_tdpbf16ps_internal,
+ Value *>::type
+ createTileDPLoops(BasicBlock *Start, BasicBlock *End, IRBuilderBase &B,
+ Value *Row, Value *Col, Value *K, Value *Acc, Value *LHS,
+ Value *RHS);
template <bool IsTileLoad>
bool lowerTileLoadStore(Instruction *TileLoadStore);
- bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+ template <Intrinsic::ID IntrID>
+ typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
+ IntrID == Intrinsic::x86_tdpbf16ps_internal,
+ bool>::type
+ lowerTileDP(Instruction *TileDP);
bool lowerTileZero(Instruction *TileZero);
};
}
}
-Value *X86LowerAMXIntrinsics::createTileDPBSSDLoops(
- BasicBlock *Start, BasicBlock *End, IRBuilderBase &B, Value *Row,
- Value *Col, Value *K, Value *Acc, Value *LHS, Value *RHS) {
+template <Intrinsic::ID IntrID>
+typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
+ IntrID == Intrinsic::x86_tdpbf16ps_internal,
+ Value *>::type
+X86LowerAMXIntrinsics::createTileDPLoops(BasicBlock *Start, BasicBlock *End,
+ IRBuilderBase &B, Value *Row,
+ Value *Col, Value *K, Value *Acc,
+ Value *LHS, Value *RHS) {
+ std::string IntrinName =
+ IntrID == Intrinsic::x86_tdpbssd_internal ? "tiledpbssd" : "tdpbf16ps";
Loop *RowLoop = nullptr;
Loop *ColLoop = nullptr;
Loop *InnerLoop = nullptr;
}
BasicBlock *RowBody = createLoop(Start, End, Row, B.getInt16(1),
- "tiledpbssd.scalarize.rows", B, RowLoop);
+ IntrinName + ".scalarize.rows", B, RowLoop);
BasicBlock *RowLatch = RowBody->getSingleSuccessor();
BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1),
- "tiledpbssd.scalarize.cols", B, ColLoop);
+ IntrinName + ".scalarize.cols", B, ColLoop);
+
BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
B.SetInsertPoint(ColBody->getTerminator());
BasicBlock *InnerBody =
createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
- "tiledpbssd.scalarize.inner", B, InnerLoop);
+ IntrinName + ".scalarize.inner", B, InnerLoop);
BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
- // tiledpbssd.scalarize.inner.body:
- // calculate idxa, idxb
- // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
- // %elta = extractelement <256 x i32> %veca, i16 %idxa
- // %eltav4i8 = bitcast i32 %elta to <4 x i8>
- // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
- // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
- // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
- // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
- // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
- // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
- // %neweltc = add i32 %elt, %acc
- // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
- // i16 %idxc
-
B.SetInsertPoint(InnerBody->getTerminator());
Value *IdxA =
B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
Value *IdxB =
B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
-
- FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
- FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
- Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
- Value *EltA = B.CreateExtractElement(VecA, IdxA);
- Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
- Value *EltB = B.CreateExtractElement(VecB, IdxB);
- Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
- Value *SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
- Value *SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
- Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
- Value *ResElt = B.CreateAdd(EltC, SubVecR);
- Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+ Value *NewVecC = nullptr;
+
+ if (IntrID == Intrinsic::x86_tdpbssd_internal) {
+ // tiledpbssd.scalarize.inner.body:
+ // calculate idxa, idxb
+ // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
+ // %elta = extractelement <256 x i32> %veca, i16 %idxa
+ // %eltav4i8 = bitcast i32 %elta to <4 x i8>
+ // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
+ // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
+ // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
+ // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
+ // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
+ // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
+ // %neweltc = add i32 %elt, %acc
+ // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
+ // i16 %idxc
+ FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
+ FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
+ Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+ Value *EltA = B.CreateExtractElement(VecA, IdxA);
+ Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
+ Value *EltB = B.CreateExtractElement(VecB, IdxB);
+ Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
+ Value *SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty);
+ Value *SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty);
+ Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB));
+ Value *ResElt = B.CreateAdd(EltC, SubVecR);
+ NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+ } else if (IntrID == Intrinsic::x86_tdpbf16ps_internal) {
+ // tiledpbf16ps.scalarize.inner.body:
+ // calculate idxa, idxb, idxc
+ // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
+ // %eltcf32 = bitcast i32 %eltc to float
+ // %elta = extractelement <256 x i32> %veca, i16 %idxa
+ // %eltav2i16 = bitcast i32 %elta to <2 x i16>
+ // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
+ // %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
+ // %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer, <4
+ // x i32> <i32 2, i32 0, i32 3, i32 1>
+ // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
+ // %shuffleb = shufflevector <2 x i16> %eltb, <2 xi16> zeroinitializer, <4 x
+ // i32> <i32 2, i32 0, i32 3, i32 1>
+ // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
+ // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
+ // %acc = call float
+ // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
+ // %neweltc = bitcast float %acc to i32
+ // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
+ // i16 %idxc
+ // %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
+ // i16 %idxc
+ FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
+ FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
+ Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+ Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
+ Value *EltA = B.CreateExtractElement(VecA, IdxA);
+ Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
+ Value *EltB = B.CreateExtractElement(VecB, IdxB);
+ Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
+ Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
+ int ShuffleMask[4] = {2, 0, 3, 1};
+ auto ShuffleArray = makeArrayRef(ShuffleMask);
+ Value *AV2F32 = B.CreateBitCast(
+ B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty);
+ Value *BV2F32 = B.CreateBitCast(
+ B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty);
+ Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
+ Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
+ NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+ }
// tiledpbssd.scalarize.cols.latch:
// %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
return NewVecD;
}
-bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) {
+template <Intrinsic::ID IntrID>
+typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal ||
+ IntrID == Intrinsic::x86_tdpbf16ps_internal,
+ bool>::type
+X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
Value *M, *N, *K, *C, *A, *B;
- match(TileDPBSSD, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>(
- m_Value(M), m_Value(N), m_Value(K), m_Value(C),
- m_Value(A), m_Value(B)));
- Instruction *InsertI = TileDPBSSD;
- IRBuilder<> PreBuilder(TileDPBSSD);
- PreBuilder.SetInsertPoint(TileDPBSSD);
+ match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
+ m_Value(C), m_Value(A), m_Value(B)));
+ Instruction *InsertI = TileDP;
+ IRBuilder<> PreBuilder(TileDP);
+ PreBuilder.SetInsertPoint(TileDP);
// We visit the loop with (m, n/4, k/4):
// %n_dword = lshr i16 %n, 2
// %k_dword = lshr i16 %k, 2
BasicBlock *Start = InsertI->getParent();
BasicBlock *End =
SplitBlock(InsertI->getParent(), InsertI, &DTU, LI, nullptr, "continue");
- IRBuilder<> Builder(TileDPBSSD);
- Value *ResVec =
- createTileDPBSSDLoops(Start, End, Builder, M, NDWord, KDWord, C, A, B);
+ IRBuilder<> Builder(TileDP);
+ Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, M, NDWord,
+ KDWord, C, A, B);
// we cannot assume there always be bitcast after tiledpbssd. So we need to
// insert one bitcast as required
Builder.SetInsertPoint(End->getFirstNonPHI());
Value *ResAMX =
Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
- // Delete tiledpbssd intrinsic and do some clean-up.
- for (auto UI = TileDPBSSD->use_begin(), UE = TileDPBSSD->use_end();
- UI != UE;) {
+ // Delete TileDP intrinsic and do some clean-up.
+ for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) {
Instruction *I = cast<Instruction>((UI++)->getUser());
Value *Vec;
if (match(I, m_BitCast(m_Value(Vec)))) {
I->eraseFromParent();
}
}
- TileDPBSSD->replaceAllUsesWith(ResAMX);
- TileDPBSSD->eraseFromParent();
+ TileDP->replaceAllUsesWith(ResAMX);
+ TileDP->eraseFromParent();
return true;
}
case Intrinsic::x86_tileloadd64_internal:
case Intrinsic::x86_tilestored64_internal:
case Intrinsic::x86_tilezero_internal:
+ case Intrinsic::x86_tdpbf16ps_internal:
WorkList.push_back(Inst);
break;
default:
for (auto *Inst : WorkList) {
switch (Inst->getIntrinsicID()) {
case Intrinsic::x86_tdpbssd_internal:
- C = lowerTileDPBSSD(Inst) || C;
+ C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
+ break;
+ case Intrinsic::x86_tdpbf16ps_internal:
+ C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
break;
case Intrinsic::x86_tileloadd64_internal:
C = lowerTileLoadStore<true>(Inst) || C;
ret void
}
-define dso_local void @test_amx_dp(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
-; CHECK-LABEL: @test_amx_dp(
+define dso_local void @test_amx_dpbssd(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_dpbssd(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx
; CHECK-NEXT: [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx
ret void
}
+define dso_local void @test_amx_dpbf16ps(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_dpbf16ps(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx
+; CHECK-NEXT: [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx
+; CHECK-NEXT: [[C_AMX:%.*]] = bitcast <256 x i32> [[C:%.*]] to x86_amx
+; CHECK-NEXT: [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2
+; CHECK-NEXT: [[TMP1:%.*]] = lshr i16 [[K:%.*]], 2
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK: tdpbf16ps.scalarize.rows.header:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TDPBF16PS_SCALARIZE_ROWS_STEP:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT: [[VEC_C_PHI_ROW:%.*]] = phi <256 x i32> [ [[C]], [[ENTRY]] ], [ [[TMP21:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT: [[VEC_D_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP23:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK: tdpbf16ps.scalarize.rows.body:
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK: tdpbf16ps.scalarize.cols.header:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TDPBF16PS_SCALARIZE_COLS_STEP:%.*]], [[TDPBF16PS_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT: [[VEC_C_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_ROW]], [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TMP21]], [[TDPBF16PS_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT: [[VEC_D_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_D_PHI_ROW]], [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TMP23]], [[TDPBF16PS_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT: [[TMP2:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT: [[TMP3:%.*]] = add i16 [[TMP2]], [[TDPBF16PS_SCALARIZE_COLS_IV]]
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_COLS_BODY:%.*]]
+; CHECK: tdpbf16ps.scalarize.cols.body:
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_INNER_HEADER:%.*]]
+; CHECK: tdpbf16ps.scalarize.inner.header:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_INNER_IV:%.*]] = phi i16 [ 0, [[TDPBF16PS_SCALARIZE_COLS_BODY]] ], [ [[TDPBF16PS_SCALARIZE_INNER_STEP:%.*]], [[TDPBF16PS_SCALARIZE_INNER_LATCH:%.*]] ]
+; CHECK-NEXT: [[VEC_C_INNER_PHI:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_COL]], [[TDPBF16PS_SCALARIZE_COLS_BODY]] ], [ [[TMP21]], [[TDPBF16PS_SCALARIZE_INNER_LATCH]] ]
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_INNER_BODY:%.*]]
+; CHECK: tdpbf16ps.scalarize.inner.body:
+; CHECK-NEXT: [[TMP4:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT: [[TMP5:%.*]] = add i16 [[TMP4]], [[TDPBF16PS_SCALARIZE_INNER_IV]]
+; CHECK-NEXT: [[TMP6:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_INNER_IV]], 16
+; CHECK-NEXT: [[TMP7:%.*]] = add i16 [[TMP6]], [[TDPBF16PS_SCALARIZE_COLS_IV]]
+; CHECK-NEXT: [[TMP8:%.*]] = extractelement <256 x i32> [[VEC_C_INNER_PHI]], i16 [[TMP3]]
+; CHECK-NEXT: [[TMP9:%.*]] = bitcast i32 [[TMP8]] to float
+; CHECK-NEXT: [[TMP10:%.*]] = extractelement <256 x i32> [[A]], i16 [[TMP5]]
+; CHECK-NEXT: [[TMP11:%.*]] = bitcast i32 [[TMP10]] to <2 x i16>
+; CHECK-NEXT: [[TMP12:%.*]] = extractelement <256 x i32> [[B]], i16 [[TMP7]]
+; CHECK-NEXT: [[TMP13:%.*]] = bitcast i32 [[TMP12]] to <2 x i16>
+; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <2 x i16> [[TMP11]], <2 x i16> zeroinitializer, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
+; CHECK-NEXT: [[TMP15:%.*]] = bitcast <4 x i16> [[TMP14]] to <2 x float>
+; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <2 x i16> [[TMP13]], <2 x i16> zeroinitializer, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
+; CHECK-NEXT: [[TMP17:%.*]] = bitcast <4 x i16> [[TMP16]] to <2 x float>
+; CHECK-NEXT: [[TMP18:%.*]] = fmul <2 x float> [[TMP15]], [[TMP17]]
+; CHECK-NEXT: [[TMP19:%.*]] = call float @llvm.vector.reduce.fadd.v2f32(float [[TMP9]], <2 x float> [[TMP18]])
+; CHECK-NEXT: [[TMP20:%.*]] = bitcast float [[TMP19]] to i32
+; CHECK-NEXT: [[TMP21]] = insertelement <256 x i32> [[VEC_C_INNER_PHI]], i32 [[TMP20]], i16 [[TMP3]]
+; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_INNER_LATCH]]
+; CHECK: tdpbf16ps.scalarize.inner.latch:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_INNER_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_INNER_IV]], 1
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_INNER_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_INNER_STEP]], [[TMP1]]
+; CHECK-NEXT: br i1 [[TDPBF16PS_SCALARIZE_INNER_COND]], label [[TDPBF16PS_SCALARIZE_INNER_HEADER]], label [[TDPBF16PS_SCALARIZE_COLS_LATCH]]
+; CHECK: tdpbf16ps.scalarize.cols.latch:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_COLS_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT: [[TMP22:%.*]] = extractelement <256 x i32> [[TMP21]], i16 [[TMP3]]
+; CHECK-NEXT: [[TMP23]] = insertelement <256 x i32> [[VEC_D_PHI_COL]], i32 [[TMP22]], i16 [[TMP3]]
+; CHECK-NEXT: br i1 [[TDPBF16PS_SCALARIZE_COLS_COND]], label [[TDPBF16PS_SCALARIZE_COLS_HEADER]], label [[TDPBF16PS_SCALARIZE_ROWS_LATCH]]
+; CHECK: tdpbf16ps.scalarize.rows.latch:
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_ROWS_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT: br i1 [[TDPBF16PS_SCALARIZE_ROWS_COND]], label [[TDPBF16PS_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK: continue:
+; CHECK-NEXT: [[TMP24:%.*]] = bitcast <256 x i32> [[TMP23]] to x86_amx
+; CHECK-NEXT: store <256 x i32> [[TMP23]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT: ret void
+;
+entry:
+ %a.amx = bitcast <256 x i32> %a to x86_amx
+ %b.amx = bitcast <256 x i32> %b to x86_amx
+ %c.amx = bitcast <256 x i32> %c to x86_amx
+ %acc = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %row, i16 %col, i16 %k, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx)
+ %vec = bitcast x86_amx %acc to <256 x i32>
+ store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+ ret void
+}
+
define dso_local void @test_amx_store(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr, <256 x i32> %vec) #0 {
; CHECK-LABEL: @test_amx_store(
; CHECK-NEXT: entry:
declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
attributes #0 = { noinline nounwind optnone }