[X86][AMX] Fix the shape dependency issue.
authorLuo, Yuanke <yuanke.luo@intel.com>
Mon, 14 Nov 2022 02:20:15 +0000 (10:20 +0800)
committerLuo, Yuanke <yuanke.luo@intel.com>
Wed, 16 Nov 2022 02:47:11 +0000 (10:47 +0800)
AMX shape should be defined before AMX intrinsics. However for below
case, the shape a.row is defined after tile load of b. If we transform
`load b` to `@llvm.x86.tileloadd64 intrinsic`, the shape dependency
doesn't meet.
```
void test_tile_dpbsud(__tile1024i a, __tile1024i b, __tile1024i c) {
  __tile_dpbsud(&c, a, b);
}
```
This patch is to store the tile b to stack and reloaded it after the
def of b.row. It would cause redundant store/load, but it is simple
to avoid generating invalid IR.
The better way may hoist `def b.row` before tile load instruction,
but it seems more complicated to recursively hoist its operands.

Differential Revision: https://reviews.llvm.org/D137923

llvm/lib/Target/X86/X86LowerAMXType.cpp
llvm/test/CodeGen/X86/AMX/amx-combine.ll

index 540182c..9419a3e 100644 (file)
@@ -700,11 +700,12 @@ namespace {
 
 class X86LowerAMXCast {
   Function &Func;
+  std::unique_ptr<DominatorTree> DT;
 
 public:
-  X86LowerAMXCast(Function &F) : Func(F) {}
+  X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
   void combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
-  void combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
+  bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
   bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
   bool combineAMXcast(TargetLibraryInfo *TLI);
   bool transformAMXCast(IntrinsicInst *AMXCast);
@@ -942,7 +943,8 @@ void X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
 // -->
 // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
 //                                                   i8* %p, i64 64)
-void X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
+bool X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
+  bool EraseLoad = true;
   Value *Row = nullptr, *Col = nullptr;
   Use &U = *(Cast->use_begin());
   unsigned OpNo = U.getOperandNo();
@@ -950,18 +952,37 @@ void X86LowerAMXCast::combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
   // TODO: If it is cast intrinsic or phi node, we can propagate the
   // shape information through def-use chain.
   if (!isAMXIntrinsic(II))
-    return;
+    return false;
   std::tie(Row, Col) = getShape(II, OpNo);
   IRBuilder<> Builder(LD);
   // Use the maximun column as stride.
   Value *Stride = Builder.getInt64(64);
-  Value *I8Ptr =
-      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
+  Value *I8Ptr;
+
+  // To save compiling time, we create doninator tree when it is really
+  // needed.
+  if (!DT)
+    DT.reset(new DominatorTree(Func));
+  if (!DT->dominates(Row, LD) || !DT->dominates(Col, LD)) {
+    // store the value to stack and reload it from stack before cast.
+    auto *AllocaAddr =
+        createAllocaInstAtEntry(Builder, Cast->getParent(), LD->getType());
+    Builder.SetInsertPoint(&*std::next(LD->getIterator()));
+    Builder.CreateStore(LD, AllocaAddr);
+
+    Builder.SetInsertPoint(Cast);
+    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
+    EraseLoad = false;
+  } else {
+    I8Ptr = Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
+  }
   std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
 
   Value *NewInst =
       Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
   Cast->replaceAllUsesWith(NewInst);
+
+  return EraseLoad;
 }
 
 bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
@@ -995,10 +1016,11 @@ bool X86LowerAMXCast::combineLdSt(SmallVectorImpl<Instruction *> &Casts) {
       // -->
       // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
       //                                                   i8* %p, i64 64)
-      combineLoadCast(cast<IntrinsicInst>(Cast), Load);
-      // Set the operand is null so that load instruction can be erased.
-      Cast->setOperand(0, nullptr);
-      Load->eraseFromParent();
+      if (combineLoadCast(cast<IntrinsicInst>(Cast), Load)) {
+        // Set the operand is null so that load instruction can be erased.
+        Cast->setOperand(0, nullptr);
+        Load->eraseFromParent();
+      }
     }
   }
   return Change;
@@ -1198,6 +1220,7 @@ public:
     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
     TargetLibraryInfo *TLI =
         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+
     X86LowerAMXCast LAC(F);
     C |= LAC.combineAMXcast(TLI);
     // There might be remaining AMXcast after combineAMXcast and they should be
index 1f4ed0d..0dc1f4b 100644 (file)
@@ -18,9 +18,9 @@ define <256 x i32> @combine_store_2user(ptr%p) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
 ; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[TMP1]], i64 64, x86_amx [[T1]])
-; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, ptr [[TMP1]], align 1024
+; CHECK-NEXT:    [[TMP2:%.*]] = load <256 x i32>, ptr [[TMP1]], align 1024
 ; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64, x86_amx [[T1]])
-; CHECK-NEXT:    ret <256 x i32> [[TMP3]]
+; CHECK-NEXT:    ret <256 x i32> [[TMP2]]
 ;
   %t1 = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
   %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
@@ -30,8 +30,8 @@ define <256 x i32> @combine_store_2user(ptr%p) {
 
 define void @combine_load(ptr%p, ptr%p2) {
 ; CHECK-LABEL: @combine_load(
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP1]])
 ; CHECK-NEXT:    ret void
 ;
   %t1 = load <256 x i32>, ptr %p, align 64
@@ -42,9 +42,9 @@ define void @combine_load(ptr%p, ptr%p2) {
 
 define void @combine_cast_across_store(ptr%p, ptr%p2) {
 ; CHECK-LABEL: @combine_cast_across_store(
-; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64)
+; CHECK-NEXT:    [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[P:%.*]], i64 64)
 ; CHECK-NEXT:    store <256 x i32> zeroinitializer, ptr [[P]], align 64
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]])
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP1]])
 ; CHECK-NEXT:    ret void
 ;
   %t1 = load <256 x i32>, ptr %p, align 64
@@ -59,8 +59,8 @@ define <256 x i32> @combine_load_2user(ptr%p, ptr%p2) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
 ; CHECK-NEXT:    [[T1:%.*]] = load <256 x i32>, ptr [[P:%.*]], align 64
 ; CHECK-NEXT:    store <256 x i32> [[T1]], ptr [[TMP1]], align 1024
-; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[TMP1]], i64 64)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP3]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, ptr [[TMP1]], i64 64)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]])
 ; CHECK-NEXT:    ret <256 x i32> [[T1]]
 ;
   %t1 = load <256 x i32>, ptr %p, align 64
@@ -75,9 +75,9 @@ define <256 x i32> @combine_load_3user(ptr%p, ptr%p2) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
 ; CHECK-NEXT:    [[T1:%.*]] = load <256 x i32>, ptr [[P:%.*]], align 64
 ; CHECK-NEXT:    store <256 x i32> [[T1]], ptr [[TMP1]], align 1024
-; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, ptr [[TMP1]], i64 16)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP3]])
-; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 16, i16 64, x86_amx [[TMP3]], x86_amx [[TMP3]], x86_amx [[TMP3]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 16, ptr [[TMP1]], i64 16)
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 16, i16 64, ptr [[P2:%.*]], i64 64, x86_amx [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 16, i16 64, x86_amx [[TMP2]], x86_amx [[TMP2]], x86_amx [[TMP2]])
 ; CHECK-NEXT:    ret <256 x i32> [[T1]]
 ;
   %t1 = load <256 x i32>, ptr %p, align 64
@@ -88,6 +88,48 @@ define <256 x i32> @combine_load_3user(ptr%p, ptr%p2) {
   ret <256 x i32> %t3
 }
 
+; the shape is loaded after tile.
+%struct.__tile1024i_str = type <{ i16, i16, [60 x i8], <256 x i32> }>
+define void @test_tile_dpbssd(ptr byval(%struct.__tile1024i_str) align 64 %a, ptr byval(%struct.__tile1024i_str) align 64 %b, ptr byval(%struct.__tile1024i_str) align 64 %c) {
+; CHECK-LABEL: @test_tile_dpbssd(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[B_ROW_PTR:%.*]] = getelementptr inbounds i8, ptr [[B:%.*]], i64 2
+; CHECK-NEXT:    [[B_ROW:%.*]] = load i16, ptr [[B_ROW_PTR]], align 2
+; CHECK-NEXT:    [[B_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 64
+; CHECK-NEXT:    [[B_TILE:%.*]] = load <256 x i32>, ptr [[B_TILE_PTR]], align 64
+; CHECK-NEXT:    store <256 x i32> [[B_TILE]], ptr [[TMP0]], align 1024
+; CHECK-NEXT:    [[A_ROW:%.*]] = load i16, ptr [[A:%.*]], align 64
+; CHECK-NEXT:    [[A_COL_PTR:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 2
+; CHECK-NEXT:    [[A_COL:%.*]] = load i16, ptr [[A_COL_PTR]], align 2
+; CHECK-NEXT:    [[TMP1:%.*]] = udiv i16 [[A_COL]], 4
+; CHECK-NEXT:    [[A_TILE_PTR:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 64
+; CHECK-NEXT:    [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[A_COL]], ptr [[A_TILE_PTR]], i64 64)
+; CHECK-NEXT:    [[C_TILE_PTR:%.*]] = getelementptr inbounds [[STRUCT___TILE1024I_STR:%.*]], ptr [[C:%.*]], i64 0, i32 3
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_ROW]], ptr [[C_TILE_PTR]], i64 64)
+; CHECK-NEXT:    [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[B_ROW]], ptr [[TMP0]], i64 64)
+; CHECK-NEXT:    [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_ROW]], i16 [[A_COL]], x86_amx [[TMP3]], x86_amx [[TMP2]], x86_amx [[TMP4]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %b.row.ptr= getelementptr inbounds i8, ptr %b, i64 2
+  %b.row = load i16, ptr %b.row.ptr, align 2
+  %b.tile.ptr = getelementptr inbounds i8, ptr %b, i64 64
+  %b.tile = load <256 x i32>, ptr %b.tile.ptr, align 64
+  %a.row = load i16, ptr %a, align 64
+  %a.col.ptr = getelementptr inbounds i8, ptr %a, i64 2
+  %a.col = load i16, ptr %a.col.ptr, align 2
+  %a.tile.ptr = getelementptr inbounds i8, ptr %a, i64 64
+  %a.tile = load <256 x i32>, ptr %a.tile.ptr, align 64
+  %c.tile.ptr = getelementptr inbounds %struct.__tile1024i_str, ptr %c, i64 0, i32 3
+  %c.tile = load <256 x i32>, ptr %c.tile.ptr, align 64
+  %c.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %c.tile)
+  %a.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %a.tile)
+  %b.amx = tail call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %b.tile)
+  %res = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %a.row, i16 %b.row, i16 %a.col, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx)
+  ret void
+}
+
 declare x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32>)
 declare <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx)
 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)