[X86] Pass to transform AMX intrinsics to scalar operations.
authorLuo, Yuanke <yuanke.luo@intel.com>
Thu, 4 Mar 2021 01:42:06 +0000 (09:42 +0800)
committerBing1 Yu <bing1.yu@intel.com>
Fri, 5 Mar 2021 08:02:02 +0000 (16:02 +0800)
This pass is always added to the pipeline, but it skips functions that are
neither compiled at -O0 nor marked with the optnone attribute. At -O0 the defs
of the shape operands sit right next to the AMX intrinsics that use them, so we
are not able to find a point that post-dominates all the shape defs and
dominates all the AMX intrinsics. To decouple the dependency on the shape, we
transform the AMX intrinsics into scalar operations so that compilation doesn't
fail. In the long term, we should improve the fast register allocator so that
it can allocate AMX registers.
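
As a rough illustration (taken from the added amx-low-intrinsics.ll test,
which covers the simplest case), the tilezero intrinsic plus its bitcast

  %amx = call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
  %vec = bitcast x86_amx %amx to <256 x i32>
  store <256 x i32> %vec, <256 x i32>* %vptr, align 64

is lowered to a plain store of the zero vector

  store <256 x i32> zeroinitializer, <256 x i32>* %vptr, align 64

while tileload/tilestore/tdpbssd are scalarized into nested row/column loops
over the underlying <256 x i32> vector.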

Reviewed By: pengfei

Differential Revision: https://reviews.llvm.org/D93594

12 files changed:
llvm/include/llvm/CodeGen/Passes.h
llvm/lib/Target/X86/CMakeLists.txt
llvm/lib/Target/X86/X86.h
llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp [new file with mode: 0644]
llvm/lib/Target/X86/X86LowerAMXType.cpp
llvm/lib/Target/X86/X86TargetMachine.cpp
llvm/test/CodeGen/X86/AMX/amx-low-intrinsics-no-amx-bitcast.ll [new file with mode: 0644]
llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll [new file with mode: 0644]
llvm/test/CodeGen/X86/AMX/amx-type.ll
llvm/test/CodeGen/X86/O0-pipeline.ll
llvm/test/CodeGen/X86/opt-pipeline.ll
llvm/tools/opt/opt.cpp

index 17ec9a6..bca0c76 100644 (file)
@@ -497,9 +497,13 @@ namespace llvm {
   /// caller saved registers with stack slots.
   extern char &FixupStatepointCallerSavedID;
 
-  /// The pass transform load/store <256 x i32> to AMX load/store intrinsics
+  /// The pass transforms load/store <256 x i32> to AMX load/store intrinsics
   /// or split the data to two <128 x i32>.
   FunctionPass *createX86LowerAMXTypePass();
+
+  /// The pass transforms AMX intrinsics into scalar operations if the function
+  /// has the optnone attribute or the compilation is at -O0.
+  FunctionPass *createX86LowerAMXIntrinsicsPass();
 } // End llvm namespace
 
 #endif
index c2796aa..a02612c 100644 (file)
@@ -34,6 +34,7 @@ set(sources
   X86DiscriminateMemOps.cpp
   X86LowerTileCopy.cpp
   X86LowerAMXType.cpp
+  X86LowerAMXIntrinsics.cpp
   X86TileConfig.cpp
   X86PreTileConfig.cpp
   X86ExpandPseudo.cpp
index 3f38e25..0240dc7 100644 (file)
@@ -175,6 +175,7 @@ void initializeX86PreTileConfigPass(PassRegistry &);
 void initializeX86TileConfigPass(PassRegistry &);
 void initializeX86LowerAMXTypeLegacyPassPass(PassRegistry &);
 void initializeX86LowerTileCopyPass(PassRegistry &);
+void initializeX86LowerAMXIntrinsicsLegacyPassPass(PassRegistry &);
 
 namespace X86AS {
 enum : unsigned {
diff --git a/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp b/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
new file mode 100644 (file)
index 0000000..5b0fd95
--- /dev/null
@@ -0,0 +1,538 @@
+//===-- X86LowerAMXIntrinsics.cpp - X86 Scalarize AMX Intrinsics ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file Pass to transform AMX intrinsics to scalar operations.
+/// This pass is always enabled; it skips itself unless the compilation is at
+/// -O0 or the function has the optnone attribute. With -O0 or optnone, the
+/// defs of the shape operands of the AMX intrinsics stay close to the
+/// intrinsics themselves, so we are not able to find a point that
+/// post-dominates all the shape defs and dominates all the AMX intrinsics.
+/// To decouple the dependency on the shape, we transform the AMX intrinsics
+/// into scalar operations so that compilation doesn't fail. In the long term,
+/// we should improve the fast register allocator to allocate AMX registers.
+//===----------------------------------------------------------------------===//
+//
+#include "X86.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsX86.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "lower-amx-intrinsics"
+
+static bool isV256I32Ty(Type *Ty) {
+  if (auto *FVT = dyn_cast<FixedVectorType>(Ty))
+    return FVT->getNumElements() == 256 &&
+           FVT->getElementType()->isIntegerTy(32);
+  return false;
+}
+
+static BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit,
+                              Value *Bound, Value *Step, StringRef Name,
+                              IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
+                              LoopInfo &LI) {
+  LLVMContext &Ctx = Preheader->getContext();
+  BasicBlock *Header =
+      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
+  BasicBlock *Body =
+      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
+  BasicBlock *Latch =
+      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);
+
+  Type *I16Ty = Type::getInt16Ty(Ctx);
+  BranchInst::Create(Body, Header);
+  BranchInst::Create(Latch, Body);
+  PHINode *IV =
+      PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
+  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
+
+  B.SetInsertPoint(Latch);
+  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
+  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
+  BranchInst::Create(Header, Exit, Cond, Latch);
+  IV->addIncoming(Inc, Latch);
+
+  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
+  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
+  PreheaderBr->setSuccessor(0, Header);
+  DTU.applyUpdatesPermissive({
+      {DominatorTree::Delete, Preheader, Tmp},
+      {DominatorTree::Insert, Header, Body},
+      {DominatorTree::Insert, Body, Latch},
+      {DominatorTree::Insert, Latch, Header},
+      {DominatorTree::Insert, Latch, Exit},
+      {DominatorTree::Insert, Preheader, Header},
+  });
+
+  L->addBasicBlockToLoop(Header, LI);
+  L->addBasicBlockToLoop(Body, LI);
+  L->addBasicBlockToLoop(Latch, LI);
+  return Body;
+}
+
+template <bool IsTileLoad>
+static Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
+                                       IRBuilderBase &B, DomTreeUpdater &DTU,
+                                       LoopInfo &LI, Value *Row, Value *Col,
+                                       Value *Ptr, Value *Stride, Value *Tile) {
+  std::string IntrinName = IsTileLoad ? "tileload" : "tilestore";
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody =
+      createLoop(Start, End, Row, B.getInt16(1), IntrinName + ".scalarize.rows",
+                 B, DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  BasicBlock *ColBody =
+      createLoop(RowBody, RowLatch, Col, B.getInt16(1),
+                 IntrinName + ".scalarize.cols", B, DTU, ColLoop, LI);
+
+  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
+  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColLoopHeader->begin();
+  Type *EltTy = B.getInt32Ty();
+  FixedVectorType *V256I32Ty = FixedVectorType::get(EltTy, 256);
+
+  // Common part for tileload and tilestore
+  // *.scalarize.cols.body:
+  // Calculate %idxmem and %idxvec
+  B.SetInsertPoint(ColBody->getTerminator());
+  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
+  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
+  Value *Offset =
+      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
+  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
+  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
+  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
+  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+  if (IsTileLoad) {
+    // tileload.scalarize.rows.header:
+    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
+    // %tileload.scalarize.rows.latch ]
+    B.SetInsertPoint(RowLoopHeader->getTerminator());
+    Value *VecZero = Constant::getNullValue(V256I32Ty);
+    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
+    VecCPhiRowLoop->addIncoming(VecZero, Start);
+
+    // tileload.scalarize.cols.header:
+    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
+    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
+    B.SetInsertPoint(ColLoopHeader->getTerminator());
+    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
+    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);
+
+    // tileload.scalarize.cols.body:
+    // Calculate %idxmem and %idxvec
+    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
+    // %elt = load i32, i32* %ptr
+    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
+    B.SetInsertPoint(ColBody->getTerminator());
+    Value *Elt = B.CreateLoad(EltTy, EltPtr);
+    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
+    VecPhi->addIncoming(ResVec, ColLoopLatch);
+    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);
+
+    return ResVec;
+  } else {
+    auto *BitCast = cast<BitCastInst>(Tile);
+    Value *Vec = BitCast->getOperand(0);
+    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
+    // tilestore.scalarize.cols.body:
+    // %mul = mul i16 %row.iv, i16 16
+    // %idx = add i16 %mul, i16 %col.iv
+    // %elt = extractelement <256 x i32> %vec, i16 %idx
+    // store i32 %elt, i32* %ptr
+    B.SetInsertPoint(ColBody->getTerminator());
+    Value *Elt = B.CreateExtractElement(Vec, Idx);
+
+    B.CreateStore(Elt, EltPtr);
+    return nullptr;
+  }
+}
+
+static Value *createTileDPBSSDLoops(BasicBlock *Start, BasicBlock *End,
+                                    IRBuilderBase &B, DomTreeUpdater &DTU,
+                                    LoopInfo &LI, Value *Row, Value *Col,
+                                    Value *K, Value *Acc, Value *LHS,
+                                    Value *RHS) {
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  Loop *InnerLoop = LI.AllocateLoop();
+  ColLoop->addChildLoop(InnerLoop);
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody =
+      createLoop(Start, End, Row, B.getInt16(1), "tiledpbssd.scalarize.rows", B,
+                 DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  BasicBlock *ColBody =
+      createLoop(RowBody, RowLatch, Col, B.getInt16(1),
+                 "tiledpbssd.scalarize.cols", B, DTU, ColLoop, LI);
+  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
+
+  B.SetInsertPoint(ColBody->getTerminator());
+  BasicBlock *InnerBody =
+      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
+                 "tiledpbssd.scalarize.inner", B, DTU, InnerLoop, LI);
+
+  BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
+  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColLoopHeader->begin();
+  Value *CurrentInner = &*InnerLoopHeader->begin();
+
+  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
+  auto *BitCastAcc = cast<BitCastInst>(Acc);
+  Value *VecC = BitCastAcc->getOperand(0);
+  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
+  // TODO: otherwise create a bitcast from x86amx to v256i32 by storing the
+  // x86amx value to memory and reloading it as a vector. However, with -O0
+  // that case doesn't happen.
+  auto *BitCastLHS = cast<BitCastInst>(LHS);
+  Value *VecA = BitCastLHS->getOperand(0);
+  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
+  auto *BitCastRHS = cast<BitCastInst>(RHS);
+  Value *VecB = BitCastRHS->getOperand(0);
+  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");
+
+  // tiledpbssd.scalarize.rows.header:
+  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
+  // %tiledpbssd.scalarize.rows.latch ]
+
+  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
+  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
+  B.SetInsertPoint(RowLoopHeader->getTerminator());
+  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
+  VecCPhiRowLoop->addIncoming(VecC, Start);
+  Value *VecZero = Constant::getNullValue(V256I32Ty);
+  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
+  VecDPhiRowLoop->addIncoming(VecZero, Start);
+
+  // tiledpbssd.scalarize.cols.header:
+  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
+  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
+  // %tiledpbssd.scalarize.cols.latch ]
+
+  // %vec.d.phi.col = phi <256 x i32> [
+  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
+  // %tiledpbssd.scalarize.cols.latch ]
+
+  // calculate idxc.
+  B.SetInsertPoint(ColLoopHeader->getTerminator());
+  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
+  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
+  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
+  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
+  Value *IdxC =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+
+  // tiledpbssd.scalarize.inner.header:
+  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
+  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
+  // %tiledpbssd.scalarize.inner.latch ]
+
+  B.SetInsertPoint(InnerLoopHeader->getTerminator());
+  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
+  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
+
+  // tiledpbssd.scalarize.inner.body:
+  // calculate idxa, idxb
+  // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
+  // %elta = extractelement <256 x i32> %veca, i16 %idxa
+  // %eltav4i8 = bitcast i32 %elta to <4 x i8>
+  // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
+  // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
+  // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
+  // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
+  // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
+  // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
+  // %neweltc = add i32 %eltc, %acc
+  // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
+  // i16 %idxc
+
+  B.SetInsertPoint(InnerBody->getTerminator());
+  Value *IdxA =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
+  Value *IdxB =
+      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
+
+  FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
+  FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
+  Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+  Value *EltA = B.CreateExtractElement(VecA, IdxA);
+  Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
+  Value *EltB = B.CreateExtractElement(VecB, IdxB);
+  Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
+  Value *SubVecR = B.CreateAddReduce(B.CreateMul(
+      B.CreateSExt(SubVecA, V4I32Ty), B.CreateSExt(SubVecB, V4I32Ty)));
+  Value *ResElt = B.CreateAdd(EltC, SubVecR);
+  Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+
+  // tiledpbssd.scalarize.cols.latch:
+  // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc
+  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
+  // i16 %idxc
+  B.SetInsertPoint(ColLoopLatch->getTerminator());
+  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
+  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
+
+  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
+  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
+  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
+  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
+  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
+
+  return NewVecD;
+}
+
+namespace {
+class X86LowerAMXIntrinsics {
+  Function &Func;
+
+public:
+  X86LowerAMXIntrinsics(Function &F, DominatorTree *DT, LoopInfo *LI)
+      : Func(F), DT(DT), LI(LI) {}
+  bool visit();
+
+private:
+  DominatorTree *DT;
+  LoopInfo *LI;
+  template <bool IsTileLoad>
+  bool lowerTileLoadStore(Instruction *TileLoadStore);
+  bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+  bool lowerTileZero(Instruction *TileZero);
+};
+
+bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) {
+  Value *M, *N, *K, *C, *A, *B;
+  match(TileDPBSSD, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>(
+                        m_Value(M), m_Value(N), m_Value(K), m_Value(C),
+                        m_Value(A), m_Value(B)));
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Instruction *InsertI = TileDPBSSD;
+  IRBuilder<> PreBuilder(TileDPBSSD);
+  PreBuilder.SetInsertPoint(TileDPBSSD);
+  // We iterate the loop nest with trip counts (m, n/4, k/4):
+  // %n_dword = lshr i16 %n, 2
+  // %k_dword = lshr i16 %k, 2
+  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
+  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
+  BasicBlock *Start = InsertI->getParent();
+  BasicBlock *End =
+      SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileDPBSSD);
+  Value *ResVec = createTileDPBSSDLoops(Start, End, Builder, DTU, *LI, M,
+                                        NDWord, KDWord, C, A, B);
+  // We cannot assume there is always a bitcast after tiledpbssd, so we insert
+  // one bitcast as required.
+  Builder.SetInsertPoint(End->getFirstNonPHI());
+  Value *ResAMX =
+      Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
+  // Delete the tiledpbssd intrinsic and do some clean-up.
+  for (auto UI = TileDPBSSD->use_begin(), UE = TileDPBSSD->use_end();
+       UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    Value *Vec;
+    if (match(I, m_BitCast(m_Value(Vec)))) {
+      I->replaceAllUsesWith(ResVec);
+      I->eraseFromParent();
+    }
+  }
+  TileDPBSSD->replaceAllUsesWith(ResAMX);
+  TileDPBSSD->eraseFromParent();
+  return true;
+}
+
+template <bool IsTileLoad>
+bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
+  Value *M, *N, *Ptr, *Stride, *Tile;
+  if (IsTileLoad)
+    match(TileLoadStore,
+          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
+              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
+  else
+    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
+                             m_Value(M), m_Value(N), m_Value(Ptr),
+                             m_Value(Stride), m_Value(Tile)));
+
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Instruction *InsertI = TileLoadStore;
+  IRBuilder<> PreBuilder(TileLoadStore);
+  PreBuilder.SetInsertPoint(TileLoadStore);
+  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
+  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
+  BasicBlock *Start = InsertI->getParent();
+  BasicBlock *End =
+      SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileLoadStore);
+  Value *ResVec = createTileLoadStoreLoops<IsTileLoad>(
+      Start, End, Builder, DTU, *LI, M, NDWord, Ptr, StrideDWord,
+      IsTileLoad ? nullptr : Tile);
+  if (IsTileLoad) {
+    // We cannot assume there is always a bitcast after tileload, so we insert
+    // one bitcast as required.
+    Builder.SetInsertPoint(End->getFirstNonPHI());
+    Value *ResAMX =
+        Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
+    // Delete the tileloadd64 intrinsic and do some clean-up.
+    for (auto UI = TileLoadStore->use_begin(), UE = TileLoadStore->use_end();
+         UI != UE;) {
+      Instruction *I = cast<Instruction>((UI++)->getUser());
+      Value *Vec;
+      if (match(I, m_BitCast(m_Value(Vec)))) {
+        I->replaceAllUsesWith(ResVec);
+        I->eraseFromParent();
+      }
+    }
+    TileLoadStore->replaceAllUsesWith(ResAMX);
+  }
+  TileLoadStore->eraseFromParent();
+  return true;
+}
+
+bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
+  IRBuilder<> Builder(TileZero);
+  FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
+  Value *VecZero = Constant::getNullValue(V256I32Ty);
+  for (auto UI = TileZero->use_begin(), UE = TileZero->use_end(); UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    Value *Vec;
+    if (match(I, m_BitCast(m_Value(Vec)))) {
+      I->replaceAllUsesWith(VecZero);
+      I->eraseFromParent();
+    }
+  }
+  TileZero->eraseFromParent();
+  return true;
+}
+
+bool X86LowerAMXIntrinsics::visit() {
+  bool C = false;
+  SmallVector<IntrinsicInst *, 8> WorkList;
+  for (BasicBlock *BB : depth_first(&Func)) {
+    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
+      if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
+        switch (Inst->getIntrinsicID()) {
+        case Intrinsic::x86_tdpbssd_internal:
+        case Intrinsic::x86_tileloadd64_internal:
+        case Intrinsic::x86_tilestored64_internal:
+        case Intrinsic::x86_tilezero_internal:
+          WorkList.push_back(Inst);
+          break;
+        default:
+          break;
+        }
+      }
+    }
+  }
+
+  for (auto *Inst : WorkList) {
+    switch (Inst->getIntrinsicID()) {
+    case Intrinsic::x86_tdpbssd_internal:
+      C = lowerTileDPBSSD(Inst) || C;
+      break;
+    case Intrinsic::x86_tileloadd64_internal:
+      C = lowerTileLoadStore<true>(Inst) || C;
+      break;
+    case Intrinsic::x86_tilestored64_internal:
+      C = lowerTileLoadStore<false>(Inst) || C;
+      break;
+    case Intrinsic::x86_tilezero_internal:
+      C = lowerTileZero(Inst) || C;
+      break;
+    default:
+      llvm_unreachable("invalid amx intrinsics!");
+    }
+  }
+
+  return C;
+}
+} // anonymous namespace
+
+namespace {
+
+class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
+    initializeX86LowerAMXIntrinsicsLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
+    if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
+        TM->getOptLevel() != CodeGenOpt::None)
+      return false;
+
+    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+
+    X86LowerAMXIntrinsics LAT(F, &DT, &LI);
+    return LAT.visit();
+  }
+  StringRef getPassName() const override { return "Lower AMX intrinsics"; }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addPreserved<LoopInfoWrapperPass>();
+    AU.addRequired<TargetPassConfig>();
+  }
+};
+
+} // anonymous namespace
+
+static const char PassName[] = "Lower AMX intrinsics";
+char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
+INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
+                    false, false)
+
+FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
+  return new X86LowerAMXIntrinsicsLegacyPass();
+}
index 5e844a0..2150a9d 100644 (file)
@@ -23,6 +23,7 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Function.h"
@@ -33,6 +34,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Target/TargetMachine.h"
 
 using namespace llvm;
 using namespace PatternMatch;
@@ -331,6 +333,10 @@ public:
   }
 
   bool runOnFunction(Function &F) override {
+    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
+    if (F.hasFnAttribute(Attribute::OptimizeNone) ||
+        TM->getOptLevel() == CodeGenOpt::None)
+      return false;
     X86LowerAMXType LAT(F);
     bool C = LAT.visit();
     return C;
@@ -338,6 +344,7 @@ public:
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesCFG();
+    AU.addRequired<TargetPassConfig>();
   }
 };
 
@@ -347,6 +354,7 @@ static const char PassName[] = "Lower AMX type for load/store";
 char X86LowerAMXTypeLegacyPass::ID = 0;
 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                       false)
+INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                     false)
 
index ddbfc6b..0f19acc 100644 (file)
@@ -62,6 +62,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeX86Target() {
   RegisterTargetMachine<X86TargetMachine> Y(getTheX86_64Target());
 
   PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeX86LowerAMXIntrinsicsLegacyPassPass(PR);
   initializeX86LowerAMXTypeLegacyPassPass(PR);
   initializeGlobalISel(PR);
   initializeWinEHStatePassPass(PR);
@@ -411,6 +412,10 @@ TargetPassConfig *X86TargetMachine::createPassConfig(PassManagerBase &PM) {
 
 void X86PassConfig::addIRPasses() {
   addPass(createAtomicExpandPass());
+
+  // We add both passes unconditionally; when they run, each pass skips itself
+  // based on the optimization level and the optnone attribute.
+  addPass(createX86LowerAMXIntrinsicsPass());
   addPass(createX86LowerAMXTypePass());
 
   TargetPassConfig::addIRPasses();
diff --git a/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics-no-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics-no-amx-bitcast.ll
new file mode 100644 (file)
index 0000000..1145ff7
--- /dev/null
@@ -0,0 +1,211 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -mtriple=x86_64 -lower-amx-intrinsics %s -S | FileCheck %s
+
+define dso_local void @test_no_bitcast(i32* %A_mem, i32* %B_mem, i32* %C_mem) local_unnamed_addr #0 {
+; CHECK-LABEL: @test_no_bitcast(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i32* [[C_MEM:%.*]] to i8*
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.rows.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILELOAD_SCALARIZE_ROWS_STEP:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP10:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tileload.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.cols.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TILELOAD_SCALARIZE_COLS_STEP:%.*]], [[TILELOAD_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW]], [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TMP10]], [[TILELOAD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tileload.scalarize.cols.body:
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i16 [[TILELOAD_SCALARIZE_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TILELOAD_SCALARIZE_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast i8* [[TMP0]] to i32*
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i32, i32* [[TMP5]], i64 [[TMP4]]
+; CHECK-NEXT:    [[TMP7:%.*]] = mul i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP8:%.*]] = add i16 [[TMP7]], [[TILELOAD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP9:%.*]] = load i32, i32* [[TMP6]], align 4
+; CHECK-NEXT:    [[TMP10]] = insertelement <256 x i32> [[VEC_PHI]], i32 [[TMP9]], i16 [[TMP8]]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_LATCH]]
+; CHECK:       tileload.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_STEP]] = add i16 [[TILELOAD_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_COLS_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_COLS_COND]], label [[TILELOAD_SCALARIZE_COLS_HEADER]], label [[TILELOAD_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tileload.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_STEP]] = add i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_ROWS_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_ROWS_COND]], label [[TILELOAD_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <256 x i32> [[TMP10]] to x86_amx
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast i32* [[A_MEM:%.*]] to i8*
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_HEADER2:%.*]]
+; CHECK:       tileload.scalarize.rows.header2:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_IV5:%.*]] = phi i16 [ 0, [[CONTINUE]] ], [ [[TILELOAD_SCALARIZE_ROWS_STEP6:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH4:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW14:%.*]] = phi <256 x i32> [ zeroinitializer, [[CONTINUE]] ], [ [[TMP22:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH4]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_BODY3:%.*]]
+; CHECK:       tileload.scalarize.rows.body3:
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_HEADER8:%.*]]
+; CHECK:       tileload.scalarize.cols.header8:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_IV11:%.*]] = phi i16 [ 0, [[TILELOAD_SCALARIZE_ROWS_BODY3]] ], [ [[TILELOAD_SCALARIZE_COLS_STEP12:%.*]], [[TILELOAD_SCALARIZE_COLS_LATCH10:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI15:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW14]], [[TILELOAD_SCALARIZE_ROWS_BODY3]] ], [ [[TMP22]], [[TILELOAD_SCALARIZE_COLS_LATCH10]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_BODY9:%.*]]
+; CHECK:       tileload.scalarize.cols.body9:
+; CHECK-NEXT:    [[TMP13:%.*]] = zext i16 [[TILELOAD_SCALARIZE_ROWS_IV5]] to i64
+; CHECK-NEXT:    [[TMP14:%.*]] = zext i16 [[TILELOAD_SCALARIZE_COLS_IV11]] to i64
+; CHECK-NEXT:    [[TMP15:%.*]] = mul i64 [[TMP13]], 4
+; CHECK-NEXT:    [[TMP16:%.*]] = add i64 [[TMP15]], [[TMP14]]
+; CHECK-NEXT:    [[TMP17:%.*]] = bitcast i8* [[TMP12]] to i32*
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr i32, i32* [[TMP17]], i64 [[TMP16]]
+; CHECK-NEXT:    [[TMP19:%.*]] = mul i16 [[TILELOAD_SCALARIZE_ROWS_IV5]], 16
+; CHECK-NEXT:    [[TMP20:%.*]] = add i16 [[TMP19]], [[TILELOAD_SCALARIZE_COLS_IV11]]
+; CHECK-NEXT:    [[TMP21:%.*]] = load i32, i32* [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP22]] = insertelement <256 x i32> [[VEC_PHI15]], i32 [[TMP21]], i16 [[TMP20]]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_LATCH10]]
+; CHECK:       tileload.scalarize.cols.latch10:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_STEP12]] = add i16 [[TILELOAD_SCALARIZE_COLS_IV11]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_COND13:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_COLS_STEP12]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_COLS_COND13]], label [[TILELOAD_SCALARIZE_COLS_HEADER8]], label [[TILELOAD_SCALARIZE_ROWS_LATCH4]]
+; CHECK:       tileload.scalarize.rows.latch4:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_STEP6]] = add i16 [[TILELOAD_SCALARIZE_ROWS_IV5]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_COND7:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_ROWS_STEP6]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_ROWS_COND7]], label [[TILELOAD_SCALARIZE_ROWS_HEADER2]], label [[CONTINUE1:%.*]]
+; CHECK:       continue1:
+; CHECK-NEXT:    [[TMP23:%.*]] = bitcast <256 x i32> [[TMP22]] to x86_amx
+; CHECK-NEXT:    [[TMP24:%.*]] = bitcast i32* [[B_MEM:%.*]] to i8*
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_HEADER17:%.*]]
+; CHECK:       tileload.scalarize.rows.header17:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_IV20:%.*]] = phi i16 [ 0, [[CONTINUE1]] ], [ [[TILELOAD_SCALARIZE_ROWS_STEP21:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH19:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW29:%.*]] = phi <256 x i32> [ zeroinitializer, [[CONTINUE1]] ], [ [[TMP34:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH19]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_BODY18:%.*]]
+; CHECK:       tileload.scalarize.rows.body18:
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_HEADER23:%.*]]
+; CHECK:       tileload.scalarize.cols.header23:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_IV26:%.*]] = phi i16 [ 0, [[TILELOAD_SCALARIZE_ROWS_BODY18]] ], [ [[TILELOAD_SCALARIZE_COLS_STEP27:%.*]], [[TILELOAD_SCALARIZE_COLS_LATCH25:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI30:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW29]], [[TILELOAD_SCALARIZE_ROWS_BODY18]] ], [ [[TMP34]], [[TILELOAD_SCALARIZE_COLS_LATCH25]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_BODY24:%.*]]
+; CHECK:       tileload.scalarize.cols.body24:
+; CHECK-NEXT:    [[TMP25:%.*]] = zext i16 [[TILELOAD_SCALARIZE_ROWS_IV20]] to i64
+; CHECK-NEXT:    [[TMP26:%.*]] = zext i16 [[TILELOAD_SCALARIZE_COLS_IV26]] to i64
+; CHECK-NEXT:    [[TMP27:%.*]] = mul i64 [[TMP25]], 4
+; CHECK-NEXT:    [[TMP28:%.*]] = add i64 [[TMP27]], [[TMP26]]
+; CHECK-NEXT:    [[TMP29:%.*]] = bitcast i8* [[TMP24]] to i32*
+; CHECK-NEXT:    [[TMP30:%.*]] = getelementptr i32, i32* [[TMP29]], i64 [[TMP28]]
+; CHECK-NEXT:    [[TMP31:%.*]] = mul i16 [[TILELOAD_SCALARIZE_ROWS_IV20]], 16
+; CHECK-NEXT:    [[TMP32:%.*]] = add i16 [[TMP31]], [[TILELOAD_SCALARIZE_COLS_IV26]]
+; CHECK-NEXT:    [[TMP33:%.*]] = load i32, i32* [[TMP30]], align 4
+; CHECK-NEXT:    [[TMP34]] = insertelement <256 x i32> [[VEC_PHI30]], i32 [[TMP33]], i16 [[TMP32]]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_LATCH25]]
+; CHECK:       tileload.scalarize.cols.latch25:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_STEP27]] = add i16 [[TILELOAD_SCALARIZE_COLS_IV26]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_COND28:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_COLS_STEP27]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_COLS_COND28]], label [[TILELOAD_SCALARIZE_COLS_HEADER23]], label [[TILELOAD_SCALARIZE_ROWS_LATCH19]]
+; CHECK:       tileload.scalarize.rows.latch19:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_STEP21]] = add i16 [[TILELOAD_SCALARIZE_ROWS_IV20]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_COND22:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_ROWS_STEP21]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_ROWS_COND22]], label [[TILELOAD_SCALARIZE_ROWS_HEADER17]], label [[CONTINUE16:%.*]]
+; CHECK:       continue16:
+; CHECK-NEXT:    [[TMP35:%.*]] = bitcast <256 x i32> [[TMP34]] to x86_amx
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.rows.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[CONTINUE16]] ], [ [[TILEDPBSSD_SCALARIZE_ROWS_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_PHI_ROW:%.*]] = phi <256 x i32> [ [[TMP10]], [[CONTINUE16]] ], [ [[TMP52:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    [[VEC_D_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[CONTINUE16]] ], [ [[TMP54:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.cols.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TILEDPBSSD_SCALARIZE_COLS_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_ROW]], [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TMP52]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    [[VEC_D_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_D_PHI_ROW]], [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TMP54]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    [[TMP36:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP37:%.*]] = add i16 [[TMP36]], [[TILEDPBSSD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.cols.body:
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.inner.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_IV:%.*]] = phi i16 [ 0, [[TILEDPBSSD_SCALARIZE_COLS_BODY]] ], [ [[TILEDPBSSD_SCALARIZE_INNER_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_INNER_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_INNER_PHI:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_COL]], [[TILEDPBSSD_SCALARIZE_COLS_BODY]] ], [ [[TMP52]], [[TILEDPBSSD_SCALARIZE_INNER_LATCH]] ]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.inner.body:
+; CHECK-NEXT:    [[TMP38:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP39:%.*]] = add i16 [[TMP38]], [[TILEDPBSSD_SCALARIZE_INNER_IV]]
+; CHECK-NEXT:    [[TMP40:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_INNER_IV]], 16
+; CHECK-NEXT:    [[TMP41:%.*]] = add i16 [[TMP40]], [[TILEDPBSSD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP42:%.*]] = extractelement <256 x i32> [[VEC_C_INNER_PHI]], i16 [[TMP37]]
+; CHECK-NEXT:    [[TMP43:%.*]] = extractelement <256 x i32> [[TMP22]], i16 [[TMP39]]
+; CHECK-NEXT:    [[TMP44:%.*]] = bitcast i32 [[TMP43]] to <4 x i8>
+; CHECK-NEXT:    [[TMP45:%.*]] = extractelement <256 x i32> [[TMP34]], i16 [[TMP41]]
+; CHECK-NEXT:    [[TMP46:%.*]] = bitcast i32 [[TMP45]] to <4 x i8>
+; CHECK-NEXT:    [[TMP47:%.*]] = sext <4 x i8> [[TMP46]] to <4 x i32>
+; CHECK-NEXT:    [[TMP48:%.*]] = sext <4 x i8> [[TMP44]] to <4 x i32>
+; CHECK-NEXT:    [[TMP49:%.*]] = mul <4 x i32> [[TMP48]], [[TMP47]]
+; CHECK-NEXT:    [[TMP50:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP49]])
+; CHECK-NEXT:    [[TMP51:%.*]] = add i32 [[TMP42]], [[TMP50]]
+; CHECK-NEXT:    [[TMP52]] = insertelement <256 x i32> [[VEC_C_INNER_PHI]], i32 [[TMP51]], i16 [[TMP37]]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_LATCH]]
+; CHECK:       tiledpbssd.scalarize.inner.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_INNER_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_INNER_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_INNER_COND]], label [[TILEDPBSSD_SCALARIZE_INNER_HEADER]], label [[TILEDPBSSD_SCALARIZE_COLS_LATCH]]
+; CHECK:       tiledpbssd.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_COLS_STEP]], 4
+; CHECK-NEXT:    [[TMP53:%.*]] = extractelement <256 x i32> [[TMP52]], i16 [[TMP37]]
+; CHECK-NEXT:    [[TMP54]] = insertelement <256 x i32> [[VEC_D_PHI_COL]], i32 [[TMP53]], i16 [[TMP37]]
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_COLS_COND]], label [[TILEDPBSSD_SCALARIZE_COLS_HEADER]], label [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tiledpbssd.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_ROWS_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_ROWS_COND]], label [[TILEDPBSSD_SCALARIZE_ROWS_HEADER]], label [[CONTINUE31:%.*]]
+; CHECK:       continue31:
+; CHECK-NEXT:    [[TMP55:%.*]] = bitcast <256 x i32> [[TMP54]] to x86_amx
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tilestore.scalarize.rows.header:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[CONTINUE31]] ], [ [[TILESTORE_SCALARIZE_ROWS_STEP:%.*]], [[TILESTORE_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tilestore.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tilestore.scalarize.cols.header:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILESTORE_SCALARIZE_ROWS_BODY]] ], [ [[TILESTORE_SCALARIZE_COLS_STEP:%.*]], [[TILESTORE_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tilestore.scalarize.cols.body:
+; CHECK-NEXT:    [[TMP56:%.*]] = zext i16 [[TILESTORE_SCALARIZE_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP57:%.*]] = zext i16 [[TILESTORE_SCALARIZE_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP58:%.*]] = mul i64 [[TMP56]], 4
+; CHECK-NEXT:    [[TMP59:%.*]] = add i64 [[TMP58]], [[TMP57]]
+; CHECK-NEXT:    [[TMP60:%.*]] = bitcast i8* [[TMP0]] to i32*
+; CHECK-NEXT:    [[TMP61:%.*]] = getelementptr i32, i32* [[TMP60]], i64 [[TMP59]]
+; CHECK-NEXT:    [[TMP62:%.*]] = mul i16 [[TILESTORE_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP63:%.*]] = add i16 [[TMP62]], [[TILESTORE_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP64:%.*]] = extractelement <256 x i32> [[TMP54]], i16 [[TMP63]]
+; CHECK-NEXT:    store i32 [[TMP64]], i32* [[TMP61]], align 4
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_LATCH]]
+; CHECK:       tilestore.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_STEP]] = add i16 [[TILESTORE_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILESTORE_SCALARIZE_COLS_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILESTORE_SCALARIZE_COLS_COND]], label [[TILESTORE_SCALARIZE_COLS_HEADER]], label [[TILESTORE_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tilestore.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_STEP]] = add i16 [[TILESTORE_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILESTORE_SCALARIZE_ROWS_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILESTORE_SCALARIZE_ROWS_COND]], label [[TILESTORE_SCALARIZE_ROWS_HEADER]], label [[CONTINUE32:%.*]]
+; CHECK:       continue32:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = bitcast i32* %C_mem to i8*
+  %1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 4, i16 16, i8* %0, i64 16)
+  %2 = bitcast i32* %A_mem to i8*
+  %3 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 4, i16 16, i8* %2, i64 16)
+  %4 = bitcast i32* %B_mem to i8*
+  %5 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 4, i16 16, i8* %4, i64 16)
+  %6 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 4, i16 16, i16 16, x86_amx %1, x86_amx %3, x86_amx %5)
+  tail call void @llvm.x86.tilestored64.internal(i16 4, i16 16, i8* %0, i64 16, x86_amx %6)
+  ret void
+}
+
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
+
+attributes #0 = { noinline nounwind optnone }
diff --git a/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll b/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
new file mode 100644 (file)
index 0000000..5420123
--- /dev/null
@@ -0,0 +1,237 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -mtriple=x86_64 -lower-amx-intrinsics %s -S | FileCheck %s
+
+define dso_local void @test_amx_load_non_O0(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr) {
+; CHECK-LABEL: @test_amx_load_non_O0(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[STRIDE:%.*]], 2
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.rows.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILELOAD_SCALARIZE_ROWS_STEP:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP11:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tileload.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.cols.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TILELOAD_SCALARIZE_COLS_STEP:%.*]], [[TILELOAD_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW]], [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TMP11]], [[TILELOAD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tileload.scalarize.cols.body:
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TILELOAD_SCALARIZE_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i16 [[TILELOAD_SCALARIZE_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast i8* [[PTR:%.*]] to i32*
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i32, i32* [[TMP6]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = mul i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP9:%.*]] = add i16 [[TMP8]], [[TILELOAD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP10:%.*]] = load i32, i32* [[TMP7]], align 4
+; CHECK-NEXT:    [[TMP11]] = insertelement <256 x i32> [[VEC_PHI]], i32 [[TMP10]], i16 [[TMP9]]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_LATCH]]
+; CHECK:       tileload.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_STEP]] = add i16 [[TILELOAD_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_COLS_COND]], label [[TILELOAD_SCALARIZE_COLS_HEADER]], label [[TILELOAD_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tileload.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_STEP]] = add i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_ROWS_COND]], label [[TILELOAD_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32> [[TMP11]] to x86_amx
+; CHECK-NEXT:    store <256 x i32> [[TMP11]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %ptr, i64 %stride)
+  %vec = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_load(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_load(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[STRIDE:%.*]], 2
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.rows.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILELOAD_SCALARIZE_ROWS_STEP:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP11:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tileload.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.cols.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TILELOAD_SCALARIZE_COLS_STEP:%.*]], [[TILELOAD_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW]], [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TMP11]], [[TILELOAD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tileload.scalarize.cols.body:
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TILELOAD_SCALARIZE_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i16 [[TILELOAD_SCALARIZE_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast i8* [[PTR:%.*]] to i32*
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i32, i32* [[TMP6]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = mul i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP9:%.*]] = add i16 [[TMP8]], [[TILELOAD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP10:%.*]] = load i32, i32* [[TMP7]], align 4
+; CHECK-NEXT:    [[TMP11]] = insertelement <256 x i32> [[VEC_PHI]], i32 [[TMP10]], i16 [[TMP9]]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_LATCH]]
+; CHECK:       tileload.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_STEP]] = add i16 [[TILELOAD_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_COLS_COND]], label [[TILELOAD_SCALARIZE_COLS_HEADER]], label [[TILELOAD_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tileload.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_STEP]] = add i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_ROWS_COND]], label [[TILELOAD_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32> [[TMP11]] to x86_amx
+; CHECK-NEXT:    store <256 x i32> [[TMP11]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %ptr, i64 %stride)
+  %vec = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_dp(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_dp(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx
+; CHECK-NEXT:    [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx
+; CHECK-NEXT:    [[C_AMX:%.*]] = bitcast <256 x i32> [[C:%.*]] to x86_amx
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i16 [[K:%.*]], 2
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.rows.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILEDPBSSD_SCALARIZE_ROWS_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_PHI_ROW:%.*]] = phi <256 x i32> [ [[C]], [[ENTRY]] ], [ [[TMP18:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    [[VEC_D_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP20:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.cols.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TILEDPBSSD_SCALARIZE_COLS_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_ROW]], [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TMP18]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    [[VEC_D_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_D_PHI_ROW]], [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TMP20]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = add i16 [[TMP2]], [[TILEDPBSSD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.cols.body:
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.inner.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_IV:%.*]] = phi i16 [ 0, [[TILEDPBSSD_SCALARIZE_COLS_BODY]] ], [ [[TILEDPBSSD_SCALARIZE_INNER_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_INNER_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_INNER_PHI:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_COL]], [[TILEDPBSSD_SCALARIZE_COLS_BODY]] ], [ [[TMP18]], [[TILEDPBSSD_SCALARIZE_INNER_LATCH]] ]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.inner.body:
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = add i16 [[TMP4]], [[TILEDPBSSD_SCALARIZE_INNER_IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_INNER_IV]], 16
+; CHECK-NEXT:    [[TMP7:%.*]] = add i16 [[TMP6]], [[TILEDPBSSD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <256 x i32> [[VEC_C_INNER_PHI]], i16 [[TMP3]]
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <256 x i32> [[A]], i16 [[TMP5]]
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast i32 [[TMP9]] to <4 x i8>
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <256 x i32> [[B]], i16 [[TMP7]]
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast i32 [[TMP11]] to <4 x i8>
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i8> [[TMP12]] to <4 x i32>
+; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i8> [[TMP10]] to <4 x i32>
+; CHECK-NEXT:    [[TMP15:%.*]] = mul <4 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT:    [[TMP16:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP15]])
+; CHECK-NEXT:    [[TMP17:%.*]] = add i32 [[TMP8]], [[TMP16]]
+; CHECK-NEXT:    [[TMP18]] = insertelement <256 x i32> [[VEC_C_INNER_PHI]], i32 [[TMP17]], i16 [[TMP3]]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_LATCH]]
+; CHECK:       tiledpbssd.scalarize.inner.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_INNER_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_INNER_STEP]], [[TMP1]]
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_INNER_COND]], label [[TILEDPBSSD_SCALARIZE_INNER_HEADER]], label [[TILEDPBSSD_SCALARIZE_COLS_LATCH]]
+; CHECK:       tiledpbssd.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    [[TMP19:%.*]] = extractelement <256 x i32> [[TMP18]], i16 [[TMP3]]
+; CHECK-NEXT:    [[TMP20]] = insertelement <256 x i32> [[VEC_D_PHI_COL]], i32 [[TMP19]], i16 [[TMP3]]
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_COLS_COND]], label [[TILEDPBSSD_SCALARIZE_COLS_HEADER]], label [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tiledpbssd.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_ROWS_COND]], label [[TILEDPBSSD_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    [[TMP21:%.*]] = bitcast <256 x i32> [[TMP20]] to x86_amx
+; CHECK-NEXT:    store <256 x i32> [[TMP20]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a.amx = bitcast <256 x i32> %a to x86_amx
+  %b.amx = bitcast <256 x i32> %b to x86_amx
+  %c.amx = bitcast <256 x i32> %c to x86_amx
+  %acc = call x86_amx @llvm.x86.tdpbssd.internal(i16 %row, i16 %col, i16 %k, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx)
+  %vec = bitcast x86_amx %acc to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_store(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr, <256 x i32> %vec) #0 {
+; CHECK-LABEL: @test_amx_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AMX:%.*]] = bitcast <256 x i32> [[VEC:%.*]] to x86_amx
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[STRIDE:%.*]], 2
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tilestore.scalarize.rows.header:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILESTORE_SCALARIZE_ROWS_STEP:%.*]], [[TILESTORE_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tilestore.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tilestore.scalarize.cols.header:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILESTORE_SCALARIZE_ROWS_BODY]] ], [ [[TILESTORE_SCALARIZE_COLS_STEP:%.*]], [[TILESTORE_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tilestore.scalarize.cols.body:
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TILESTORE_SCALARIZE_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i16 [[TILESTORE_SCALARIZE_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast i8* [[PTR:%.*]] to i32*
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i32, i32* [[TMP6]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = mul i16 [[TILESTORE_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP9:%.*]] = add i16 [[TMP8]], [[TILESTORE_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <256 x i32> [[VEC]], i16 [[TMP9]]
+; CHECK-NEXT:    store i32 [[TMP10]], i32* [[TMP7]], align 4
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_LATCH]]
+; CHECK:       tilestore.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_STEP]] = add i16 [[TILESTORE_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILESTORE_SCALARIZE_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    br i1 [[TILESTORE_SCALARIZE_COLS_COND]], label [[TILESTORE_SCALARIZE_COLS_HEADER]], label [[TILESTORE_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tilestore.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_STEP]] = add i16 [[TILESTORE_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILESTORE_SCALARIZE_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TILESTORE_SCALARIZE_ROWS_COND]], label [[TILESTORE_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = bitcast <256 x i32> %vec to x86_amx
+  call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %ptr, i64 %stride, x86_amx %amx)
+  ret void
+}
+
+define dso_local void @test_amx_zero(i16 signext %row, i16 signext %col, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_zero(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    store <256 x i32> zeroinitializer, <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
+  %vec = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
+
+attributes #0 = { noinline nounwind optnone }
index a70d09e..b730f3d 100644 (file)
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -lower-amx-type %s -S | FileCheck %s
+; RUN: opt --codegen-opt-level=2 -mtriple=x86_64 -lower-amx-type %s -S | FileCheck %s
 
 %struct.__tile_str = type { i16, i16, <256 x i32> }
 
index 8b0c604..2e1cbac 100644 (file)
@@ -18,6 +18,9 @@
 ; CHECK-NEXT:     Pre-ISel Intrinsic Lowering
 ; CHECK-NEXT:     FunctionPass Manager
 ; CHECK-NEXT:       Expand Atomic instructions
+; CHECK-NEXT:       Dominator Tree Construction
+; CHECK-NEXT:       Natural Loop Information
+; CHECK-NEXT:       Lower AMX intrinsics
 ; CHECK-NEXT:       Lower AMX type for load/store
 ; CHECK-NEXT:       Module Verifier
 ; CHECK-NEXT:       Lower Garbage Collection Instructions
index a99e203..0f92e5a 100644 (file)
 ; CHECK-NEXT:     Pre-ISel Intrinsic Lowering
 ; CHECK-NEXT:     FunctionPass Manager
 ; CHECK-NEXT:       Expand Atomic instructions
+; CHECK-NEXT:       Dominator Tree Construction
+; CHECK-NEXT:       Natural Loop Information
+; CHECK-NEXT:       Lower AMX intrinsics
 ; CHECK-NEXT:       Lower AMX type for load/store
 ; CHECK-NEXT:       Module Verifier
-; CHECK-NEXT:       Dominator Tree Construction
 ; CHECK-NEXT:       Basic Alias Analysis (stateless AA impl)
-; CHECK-NEXT:       Natural Loop Information
 ; CHECK-NEXT:       Canonicalize natural loops
 ; CHECK-NEXT:       Scalar Evolution Analysis
 ; CHECK-NEXT:       Loop Pass Manager
index f690d5f..54d5951 100644 (file)
@@ -513,7 +513,8 @@ static bool shouldPinPassToLegacyPM(StringRef Pass) {
       "expand-reductions",    "indirectbr-expand",
       "generic-to-nvvm",      "expandmemcmp",
       "loop-reduce",          "lower-amx-type",
-      "polyhedral-info",      "replace-with-veclib"};
+      "lower-amx-intrinsics", "polyhedral-info",
+      "replace-with-veclib"};
   for (const auto &P : PassNamePrefix)
     if (Pass.startswith(P))
       return true;