[NVPTX] Make tensor shape part of WMMA intrinsic's name.
Author:    Artem Belevich <tra@google.com>
Committer: Artem Belevich <tra@google.com>
Date:      Wed, 21 Mar 2018 21:55:02 +0000
This is needed for the upcoming implementation of the
new 8x32x16 and 32x8x16 variants of WMMA instructions
introduced in CUDA 9.1.

Differential Revision: https://reviews.llvm.org/D44719

llvm-svn: 328158
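
As a rough illustration (not part of this patch): the CUDA device code below
mirrors calls from the builtins-nvptx-sm_70.cu test, and the comments show how
the intrinsic selected by clang's CodeGen for each builtin is renamed by this
change; the old/new names are taken verbatim from the test updates further down.

  // Sketch only: __hmma_* builtins as exercised in builtins-nvptx-sm_70.cu
  // (needs sm_70 / ptx60 when actually compiled).
  __device__ void wmma_naming_example(int *dst, int *src,
                                      float *fdst, float *fsrc, int ldm) {
    // was:  llvm.nvvm.wmma.load.a.sync.row.m16n16k16.stride.f16
    // now:  llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16
    __hmma_m16n16k16_ld_a(dst, src, ldm, /*layout=*/0);

    // was:  llvm.nvvm.wmma.mma.sync.row.col.m16n16k16.f32.f32
    // now:  llvm.nvvm.wmma.m16n16k16.mma.row.col.f32.f32
    __hmma_m16n16k16_mma_f32f32(fdst, src, src, fsrc, /*layout=*/1, /*satf=*/0);
  }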

clang/lib/CodeGen/CGBuiltin.cpp
clang/test/CodeGen/builtins-nvptx-sm_70.cu
llvm/include/llvm/IR/IntrinsicsNVVM.td
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/test/CodeGen/NVPTX/wmma.py

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 996e5e7..d3ea1f2 100644
@@ -10515,23 +10515,23 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
     unsigned NumResults;
     switch (BuiltinID) {
     case NVPTX::BI__hmma_m16n16k16_ld_a:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_load_a_f16_col_stride
-                       : Intrinsic::nvvm_wmma_load_a_f16_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride;
       NumResults = 8;
       break;
     case NVPTX::BI__hmma_m16n16k16_ld_b:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_load_b_f16_col_stride
-                       : Intrinsic::nvvm_wmma_load_b_f16_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride;
       NumResults = 8;
       break;
     case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_load_c_f16_col_stride
-                       : Intrinsic::nvvm_wmma_load_c_f16_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride;
       NumResults = 4;
       break;
     case NVPTX::BI__hmma_m16n16k16_ld_c_f32:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_load_c_f32_col_stride
-                       : Intrinsic::nvvm_wmma_load_c_f32_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride;
       NumResults = 8;
       break;
     default:
@@ -10566,13 +10566,13 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
     // for some reason nvcc builtins use _c_.
     switch (BuiltinID) {
     case NVPTX::BI__hmma_m16n16k16_st_c_f16:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_store_d_f16_col_stride
-                       : Intrinsic::nvvm_wmma_store_d_f16_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride;
       NumResults = 4;
       break;
     case NVPTX::BI__hmma_m16n16k16_st_c_f32:
-      IID = isColMajor ? Intrinsic::nvvm_wmma_store_d_f32_col_stride
-                       : Intrinsic::nvvm_wmma_store_d_f32_row_stride;
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride
+                       : Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride;
       break;
     default:
       llvm_unreachable("Unexpected builtin ID.");
@@ -10591,8 +10591,8 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
     return Result;
   }
 
-  // BI__hmma_m16n16k16_mma_<Dtype><CType>(d, a, b, c, layout, satf)
-  //  --> Intrinsic::nvvm_wmma_mma_sync<layout A,B><DType><CType><Satf>
+  // BI__hmma_m16n16k16_mma_<Dtype><CType>(d, a, b, c, layout, satf) -->
+  // Intrinsic::nvvm_wmma_m16n16k16_mma_sync<layout A,B><DType><CType><Satf>
   case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
   case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
   case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
@@ -10613,15 +10613,15 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
     bool Satf = SatfArg.getSExtValue();
 
     // clang-format off
-#define MMA_VARIANTS(type) {{                                   \
-      Intrinsic::nvvm_wmma_mma_sync_row_row_##type,             \
-      Intrinsic::nvvm_wmma_mma_sync_row_row_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_mma_sync_row_col_##type,             \
-      Intrinsic::nvvm_wmma_mma_sync_row_col_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_mma_sync_col_row_##type,             \
-      Intrinsic::nvvm_wmma_mma_sync_col_row_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_mma_sync_col_col_##type,             \
-      Intrinsic::nvvm_wmma_mma_sync_col_col_##type##_satfinite  \
+#define MMA_VARIANTS(type) {{                                        \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type,             \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type,             \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type,             \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type,             \
+      Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type##_satfinite  \
     }}
     // clang-format on
 
diff --git a/clang/test/CodeGen/builtins-nvptx-sm_70.cu b/clang/test/CodeGen/builtins-nvptx-sm_70.cu
index 09e5b6b..1e9133b 100644
@@ -22,145 +22,145 @@ typedef unsigned long long uint64_t;
 __device__ void nvvm_wmma(int *src, int *dst,
                           float *fsrc, float *fdst,
                           int ldm) {
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.load.a.sync.row.m16n16k16.stride.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16
   // expected-error@+1 {{'__hmma_m16n16k16_ld_a' needs target feature ptx60}}
   __hmma_m16n16k16_ld_a(dst, src, ldm, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.load.a.sync.col.m16n16k16.stride.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16
   // expected-error@+1 {{'__hmma_m16n16k16_ld_a' needs target feature ptx60}}
   __hmma_m16n16k16_ld_a(dst, src+1, ldm, 1);
 
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.load.b.sync.row.m16n16k16.stride.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.f16
   // expected-error@+1 {{'__hmma_m16n16k16_ld_b' needs target feature ptx60}}
   __hmma_m16n16k16_ld_b(dst, src, ldm, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.load.b.sync.col.m16n16k16.stride.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.f16
   // expected-error@+1 {{'__hmma_m16n16k16_ld_b' needs target feature ptx60}}
   __hmma_m16n16k16_ld_b(dst, src+2, ldm, 1);
 
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.load.c.sync.row.m16n16k16.stride.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f16
   // expected-error@+1 {{'__hmma_m16n16k16_ld_c_f16' needs target feature ptx60}}
   __hmma_m16n16k16_ld_c_f16(dst, src, ldm, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.load.c.sync.col.m16n16k16.stride.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f16
   // expected-error@+1 {{'__hmma_m16n16k16_ld_c_f16' needs target feature ptx60}}
   __hmma_m16n16k16_ld_c_f16(dst, src, ldm, 1);
 
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.load.c.sync.row.m16n16k16.stride.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32
   // expected-error@+1 {{'__hmma_m16n16k16_ld_c_f32' needs target feature ptx60}}
   __hmma_m16n16k16_ld_c_f32(fdst, fsrc, ldm, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.load.c.sync.col.m16n16k16.stride.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32
   // expected-error@+1 {{'__hmma_m16n16k16_ld_c_f32' needs target feature ptx60}}
   __hmma_m16n16k16_ld_c_f32(fdst, fsrc, ldm, 1);
 
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.store.d.sync.row.m16n16k16.stride.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16
   // expected-error@+1 {{'__hmma_m16n16k16_st_c_f16' needs target feature ptx60}}
   __hmma_m16n16k16_st_c_f16(dst, src, ldm, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.store.d.sync.col.m16n16k16.stride.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f16
   // expected-error@+1 {{'__hmma_m16n16k16_st_c_f16' needs target feature ptx60}}
   __hmma_m16n16k16_st_c_f16(dst, src, ldm, 1);
 
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.store.d.sync.row.m16n16k16.stride.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32
   // expected-error@+1 {{'__hmma_m16n16k16_st_c_f32' needs target feature ptx60}}
   __hmma_m16n16k16_st_c_f32(fdst, fsrc, ldm, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.store.d.sync.col.m16n16k16.stride.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32
   // expected-error@+1 {{'__hmma_m16n16k16_st_c_f32' needs target feature ptx60}}
   __hmma_m16n16k16_st_c_f32(fdst, fsrc, ldm, 1);
 
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.row.m16n16k16.f16.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f16(dst, src, src, src, 0, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.row.m16n16k16.f16.f16.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f16(dst, src, src, src, 0, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.col.m16n16k16.f16.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.f16.f16
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f16(dst, src, src, src, 1, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.col.m16n16k16.f16.f16.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.f16.f16.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f16(dst, src, src, src, 1, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.row.m16n16k16.f16.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.row.f16.f16
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f16(dst, src, src, src, 2, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.row.m16n16k16.f16.f16.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.row.f16.f16.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f16(dst, src, src, src, 2, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.col.m16n16k16.f16.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f16(dst, src, src, src, 3, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.col.m16n16k16.f16.f16.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f16.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f16(dst, src, src, src, 3, 1);
 
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.row.m16n16k16.f16.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f32
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f32(dst, src, src, fsrc, 0, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.row.m16n16k16.f16.f32.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f32.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f32(dst, src, src, fsrc, 0, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.col.m16n16k16.f16.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.f16.f32
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f32(dst, src, src, fsrc, 1, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.col.m16n16k16.f16.f32.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.f16.f32.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f32(dst, src, src, fsrc, 1, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.row.m16n16k16.f16.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.row.f16.f32
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f32(dst, src, src, fsrc, 2, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.row.m16n16k16.f16.f32.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.row.f16.f32.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f32(dst, src, src, fsrc, 2, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.col.m16n16k16.f16.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f32
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f32(dst, src, src, fsrc, 3, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.col.m16n16k16.f16.f32.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.col.f16.f32.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f16f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f16f32(dst, src, src, fsrc, 3, 1);
 
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.row.m16n16k16.f32.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f16
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f16(fdst, src, src, src, 0, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.row.m16n16k16.f32.f16.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f16.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f16(fdst, src, src, src, 0, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.col.m16n16k16.f32.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.f32.f16
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f16(fdst, src, src, src, 1, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.col.m16n16k16.f32.f16.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.f32.f16.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f16(fdst, src, src, src, 1, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.row.m16n16k16.f32.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.row.f32.f16
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f16(fdst, src, src, src, 2, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.row.m16n16k16.f32.f16.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.row.f32.f16.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f16(fdst, src, src, src, 2, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.col.m16n16k16.f32.f16
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f16
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f16(fdst, src, src, src, 3, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.col.m16n16k16.f32.f16.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f16.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f16' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f16(fdst, src, src, src, 3, 1);
 
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.row.m16n16k16.f32.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f32(fdst, src, src, fsrc, 0, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.row.m16n16k16.f32.f32.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.f32.f32.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f32(fdst, src, src, fsrc, 0, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.col.m16n16k16.f32.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.f32.f32
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f32(fdst, src, src, fsrc, 1, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.row.col.m16n16k16.f32.f32.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.f32.f32.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f32(fdst, src, src, fsrc, 1, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.row.m16n16k16.f32.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.row.f32.f32
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f32(fdst, src, src, fsrc, 2, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.row.m16n16k16.f32.f32.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.row.f32.f32.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f32(fdst, src, src, fsrc, 2, 1);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.col.m16n16k16.f32.f32
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f32(fdst, src, src, fsrc, 3, 0);
-  // CHECK: call {{.*}} @llvm.nvvm.wmma.mma.sync.col.col.m16n16k16.f32.f32.satfinite
+  // CHECK: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32.satfinite
   // expected-error@+1 {{'__hmma_m16n16k16_mma_f32f32' needs target feature ptx60}}
   __hmma_m16n16k16_mma_f32f32(fdst, src, src, fsrc, 3, 1);
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index e6734ed..609aebd 100644
@@ -3884,39 +3884,53 @@ def int_nvvm_match_all_sync_i64p :
 //
 
 // WMMA.LOAD
-class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Type,
-                         LLVMType regty, int WithStride>
+class NVVM_WMMA_LD_GALSTS<string Geometry, string Abc, string Layout,
+                          string Type, LLVMType regty, int WithStride>
   : Intrinsic<!if(!eq(Abc#Type,"cf16"),
                   [regty, regty, regty, regty],
                   [regty, regty, regty, regty,
                    regty, regty, regty, regty]),
               !if(WithStride, [llvm_anyptr_ty, llvm_i32_ty], [llvm_anyptr_ty]),
               [IntrReadMem, IntrArgMemOnly, ReadOnly<0>, NoCapture<0>],
-              "llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
-                #!if(WithStride,".stride","")
-                #"."#Type>;
-
-multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout, string Type,
-                            LLVMType regty> {
-  def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 1>;
-  def NAME   : NVVM_WMMA_LD_ALSTS<Abc, Layout, Type, regty, 0>;
+              "llvm.nvvm.wmma."
+                # Geometry
+                # ".load"
+                # "." # Abc
+                # "." # Layout
+                # !if(WithStride, ".stride", "")
+                # "." # Type>;
+
+multiclass NVVM_WMMA_LD_GALT<string Geometry, string Abc, string Layout,
+                             string Type, LLVMType regty> {
+  def _stride: NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 1>;
+  def NAME   : NVVM_WMMA_LD_GALSTS<Geometry, Abc, Layout, Type, regty, 0>;
 }
 
-multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
-  defm _row: NVVM_WMMA_LD_ALT<Abc, "row", Type, regty>;
-  defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
+multiclass NVVM_WMMA_LD_GAT<string Geometry, string Abc,
+                           string Type, LLVMType regty> {
+  defm _row: NVVM_WMMA_LD_GALT<Geometry, Abc, "row", Type, regty>;
+  defm _col: NVVM_WMMA_LD_GALT<Geometry, Abc, "col", Type, regty>;
 }
 
-defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
-defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
-defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
-defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
+multiclass NVVM_WMMA_LD_G<string Geometry> {
+  defm _a_f16: NVVM_WMMA_LD_GAT<Geometry, "a", "f16", llvm_v2f16_ty>;
+  defm _b_f16: NVVM_WMMA_LD_GAT<Geometry, "b", "f16", llvm_v2f16_ty>;
+  defm _c_f16: NVVM_WMMA_LD_GAT<Geometry, "c", "f16", llvm_v2f16_ty>;
+  defm _c_f32: NVVM_WMMA_LD_GAT<Geometry, "c", "f32", llvm_float_ty>;
+}
+
+multiclass NVVM_WMMA_LD {
+  defm _m16n16k16_load: NVVM_WMMA_LD_G<"m16n16k16">;
+}
+
+defm int_nvvm_wmma: NVVM_WMMA_LD;
 
 // WMMA.STORE.D
-class NVVM_WMMA_STD_LSTS<string Layout, string Type, LLVMType regty, int WithStride,
-                         // This is only used to create a typed empty array we
-                         // need to pass to !if below.
-                         list<LLVMType>Empty=[]>
+class NVVM_WMMA_STD_GLSTS<string Geometry, string Layout,
+                          string Type, LLVMType regty, int WithStride,
+                          // This is only used to create a typed empty array we
+                          // need to pass to !if below.
+                          list<LLVMType>Empty=[]>
   : Intrinsic<[],
               !listconcat(
                 [llvm_anyptr_ty],
@@ -3926,29 +3940,40 @@ class NVVM_WMMA_STD_LSTS<string Layout, string Type, LLVMType regty, int WithStr
                      regty, regty, regty, regty]),
                 !if(WithStride, [llvm_i32_ty], Empty)),
               [IntrWriteMem, IntrArgMemOnly, WriteOnly<0>, NoCapture<0>],
-              "llvm.nvvm.wmma.store.d.sync."#Layout
-                   #".m16n16k16"
-                   #!if(WithStride,".stride","")
-                   #"."#Type>;
-
-multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
-  def _stride: NVVM_WMMA_STD_LSTS<Layout, Type, regty, 1>;
-  def NAME:    NVVM_WMMA_STD_LSTS<Layout, Type, regty, 0>;
+              "llvm.nvvm.wmma."
+                   # Geometry
+                   # ".store.d"
+                   # "." # Layout
+                   # !if(WithStride, ".stride", "")
+                   # "." # Type>;
+
+multiclass NVVM_WMMA_STD_GLT<string Geometry, string Layout, 
+                             string Type, LLVMType regty> {
+  def _stride: NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 1>;
+  def NAME:    NVVM_WMMA_STD_GLSTS<Geometry, Layout, Type, regty, 0>;
+}
+
+multiclass NVVM_WMMA_STD_GT<string Geometry, string Type, LLVMType regty> {
+  defm _row: NVVM_WMMA_STD_GLT<Geometry, "row", Type, regty>;
+  defm _col: NVVM_WMMA_STD_GLT<Geometry, "col", Type, regty>;
+}
+multiclass NVVM_WMMA_STD_G<string Geometry> {
+  defm _d_f16: NVVM_WMMA_STD_GT<Geometry, "f16", llvm_v2f16_ty>;
+  defm _d_f32: NVVM_WMMA_STD_GT<Geometry, "f32", llvm_float_ty>;
 }
 
-multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
-  defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>;
-  defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
+multiclass NVVM_WMMA_STD {
+  defm _m16n16k16_store: NVVM_WMMA_STD_G<"m16n16k16">;
 }
 
-defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
-defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
+defm int_nvvm_wmma: NVVM_WMMA_STD;
 
 // WMMA.MMA
-class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
-                          string DType, LLVMType d_regty,
-                          string CType, LLVMType c_regty,
-                          string Satfinite = "">
+class NVVM_WMMA_MMA_GABDCS<string Geometry,
+                           string ALayout, string BLayout,
+                           string DType, LLVMType d_regty,
+                           string CType, LLVMType c_regty,
+                           string Satfinite = "">
   : Intrinsic<!if(!eq(DType,"f16"),
                       [d_regty, d_regty, d_regty, d_regty],
                       [d_regty, d_regty, d_regty, d_regty,
@@ -3965,39 +3990,52 @@ class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
                       [c_regty, c_regty, c_regty, c_regty,
                        c_regty, c_regty, c_regty, c_regty])),
               [IntrNoMem],
-              "llvm.nvvm.wmma.mma.sync."#ALayout#"."#BLayout
-                 #".m16n16k16."#DType#"."#CType#Satfinite>;
-
-multiclass NVVM_WMMA_MMA_ABDC<string ALayout, string BLayout,
-                              string DType, LLVMType d_regty,
-                              string CType, LLVMType c_regty> {
-  def NAME : NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
-                                 DType, d_regty,
-                                 CType, c_regty>;
-  def _satfinite: NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
-                                      DType, d_regty,
-                                      CType, c_regty,".satfinite">;
+              "llvm.nvvm.wmma."
+                # Geometry
+                # ".mma"
+                # "." # ALayout
+                # "." # BLayout
+                # "." # DType
+                # "." # CType
+                # Satfinite> {
 }
 
-multiclass NVVM_WMMA_MMA_ABD<string ALayout, string BLayout,
+multiclass NVVM_WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout,
+                               string DType, LLVMType d_regty,
+                               string CType, LLVMType c_regty> {
+  def NAME : NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
+                                  DType, d_regty, CType, c_regty>;
+  def _satfinite: NVVM_WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
+                                       DType, d_regty, CType, c_regty,".satfinite">;
+}
+
+multiclass NVVM_WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout,
                               string DType, LLVMType d_regty> {
-  defm _f16: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
+  defm _f16: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_regty,
                                 "f16", llvm_v2f16_ty>;
-  defm _f32: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
+  defm _f32: NVVM_WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_regty,
                                 "f32", llvm_float_ty>;
 }
 
-multiclass NVVM_WMMA_MMA_AB<string ALayout, string BLayout> {
-  defm _f16: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f16", llvm_v2f16_ty>;
-  defm _f32: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f32", llvm_float_ty>;
+multiclass NVVM_WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> {
+  defm _f16: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", llvm_v2f16_ty>;
+  defm _f32: NVVM_WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", llvm_float_ty>;
+}
+
+multiclass NVVM_WMMA_MMA_GA<string Geometry, string ALayout> {
+  defm _col: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "col">;
+  defm _row: NVVM_WMMA_MMA_GAB<Geometry, ALayout, "row">;
+}
+
+multiclass NVVM_WMMA_MMA_G<string Geometry> {
+  defm _col: NVVM_WMMA_MMA_GA<Geometry, "col">;
+  defm _row: NVVM_WMMA_MMA_GA<Geometry, "row">;
 }
 
-multiclass NVVM_WMMA_MMA_A<string ALayout> {
-  defm _col: NVVM_WMMA_MMA_AB<ALayout, "col">;
-  defm _row: NVVM_WMMA_MMA_AB<ALayout, "row">;
+multiclass NVVM_WMMA_MMA {
+  defm _m16n16k16_mma : NVVM_WMMA_MMA_G<"m16n16k16">;
 }
 
-defm int_nvvm_wmma_mma_sync_col: NVVM_WMMA_MMA_A<"col">;
-defm int_nvvm_wmma_mma_sync_row: NVVM_WMMA_MMA_A<"row">;
+defm int_nvvm_wmma : NVVM_WMMA_MMA;
 
 } // let TargetPrefix = "nvvm"
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 2fcfe95..3a5c607 100644
@@ -3323,14 +3323,14 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     // Our result depends on both our and other thread's arguments.
     Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOStore;
     return true;
-  case Intrinsic::nvvm_wmma_load_a_f16_col:
-  case Intrinsic::nvvm_wmma_load_a_f16_row:
-  case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
-  case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_col:
-  case Intrinsic::nvvm_wmma_load_b_f16_row:
-  case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
-  case Intrinsic::nvvm_wmma_load_b_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3340,10 +3340,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
-  case Intrinsic::nvvm_wmma_load_c_f16_col:
-  case Intrinsic::nvvm_wmma_load_c_f16_row:
-  case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
-  case Intrinsic::nvvm_wmma_load_c_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v4f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3353,10 +3353,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
-  case Intrinsic::nvvm_wmma_load_c_f32_col:
-  case Intrinsic::nvvm_wmma_load_c_f32_row:
-  case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
-  case Intrinsic::nvvm_wmma_load_c_f32_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3366,10 +3366,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
-  case Intrinsic::nvvm_wmma_store_d_f16_col:
-  case Intrinsic::nvvm_wmma_store_d_f16_row:
-  case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
-  case Intrinsic::nvvm_wmma_store_d_f16_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v4f16;
     Info.ptrVal = I.getArgOperand(0);
@@ -3379,10 +3379,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
-  case Intrinsic::nvvm_wmma_store_d_f32_col:
-  case Intrinsic::nvvm_wmma_store_d_f32_row:
-  case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
-  case Intrinsic::nvvm_wmma_store_d_f32_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v8f32;
     Info.ptrVal = I.getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index ba3f2e3..b46247f 100644
@@ -7375,16 +7375,15 @@ def INT_PTX_SREG_WARPSIZE :
 
 class EmptyNVPTXInst : NVPTXInst<(outs), (ins), "?", []>;
 
-class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
-                           string Type, NVPTXRegClass regclass,
-                           DAGOperand SrcOp, bit WithStride>
+class WMMA_LOAD_GALSTOS<string Geometry, string Abc, string Layout,
+                        string Space, string Type, NVPTXRegClass regclass,
+                        DAGOperand SrcOp, bit WithStride>
   : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
   // Pattern (created by WMMA_LOAD_INTR_HELPER below) that matches the intrinsic
   // for this function.
-  PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_LOAD_"
-                                       # !subst("a", "A",
-                                         !subst("b", "B",
-                                         !subst("c", "C_" # Type, Abc)))
+  PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_"
+                                       # Geometry # "_load_"
+                                       # !subst("c", "c_" # Type, Abc)
                                        # "_" # Layout
                                        # !subst(".", "_", Space)
                                        # !if(WithStride,"_stride", "")
@@ -7419,23 +7418,28 @@ class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
   let Pattern = [!con(PatOuts, (set PatArgs))];
   let OutOperandList = Outs;
   let InOperandList = Ins;
-  let AsmString = "wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
-                 #!if(!eq(Abc#Type,"cf16"),
-                      "{{$r0, $r1, $r2, $r3}}",
-                      "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
-                 #", [$src]"
-                 #!if(WithStride, ", $ldm", "")
-                 #";";
+  let AsmString = "wmma.load."
+                  # Abc
+                  # ".sync."
+                  # Layout
+                  # ".m16n16k16"
+                  # Space 
+                  # "." # Type # " \t"
+                  # !if(!eq(Abc#Type, "cf16"),
+                        "{{$r0, $r1, $r2, $r3}}",
+                        "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
+                  # ", [$src]"
+                  # !if(WithStride, ", $ldm", "")
+                  # ";";
 }
 
-class WMMA_LOAD_INTR_HELPER<string Abc, string Layout, string Space,
-                           string Type, bit WithStride>
+class WMMA_LOAD_INTR_HELPER<string Geometry, string Abc, string Layout,
+                            string Space, string Type, bit WithStride>
                            : PatFrag <(ops),(ops)> {
   // Intrinsic that matches this instruction.
-  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_load_"
-                                    # Abc
-                                    # "_" # Type
-                                    # "_" # Layout
+  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma"
+                                    # "_" # Geometry # "_load_"
+                                    # Abc # "_" # Type # "_" # Layout
                                     # !if(WithStride,"_stride", ""));
   code match_generic = [{
    return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GENERIC);
@@ -7453,62 +7457,81 @@ class WMMA_LOAD_INTR_HELPER<string Abc, string Layout, string Space,
                       !if(!eq(Space, ".global"), match_global, match_generic));
 }
 
-multiclass WMMA_LOAD_ALSTS<string Abc, string Layout, string Space,
-                          string Type, NVPTXRegClass regclass, bit WithStride> {
-  def _avar:  WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, imem, WithStride>;
-  def _areg: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int32Regs, WithStride>;
-  def _areg64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, Int64Regs, WithStride>;
-  def _ari: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri, WithStride>;
-  def _ari64: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, MEMri64, WithStride>;
+multiclass WMMA_LOAD_GALSTS<string Geometry, string Abc, string Layout,
+                            string Space, string Type, NVPTXRegClass regclass,
+                            bit WithStride> {
+  def _avar:  WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
+                                imem, WithStride>;
+  def _areg: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
+                                Int32Regs, WithStride>;
+  def _areg64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
+                                Int64Regs, WithStride>;
+  def _ari: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
+                                MEMri, WithStride>;
+  def _ari64: WMMA_LOAD_GALSTOS<Geometry, Abc, Layout, Space, Type, regclass,
+                                MEMri64, WithStride>;
 }
 
-multiclass WMMA_LOAD_ALSTSh<string Abc, string Layout, string Space,
-                            string Type, NVPTXRegClass regclass, bit WithStride> {
+multiclass WMMA_LOAD_GALSTSh<string Geometry, string Abc, string Layout,
+                             string Space, string Type, NVPTXRegClass regclass,
+                             bit WithStride> {
   // Define a PatFrag that matches appropriate intrinsic that loads from the
   // given address space.
-  def _Intr : WMMA_LOAD_INTR_HELPER<Abc, Layout, Space, Type, WithStride>;
-  defm NAME:  WMMA_LOAD_ALSTS<Abc, Layout, Space, Type, regclass, WithStride>;
+  def _Intr:  WMMA_LOAD_INTR_HELPER<Geometry, Abc, Layout, Space, Type,
+                                    WithStride>;
+  defm NAME:  WMMA_LOAD_GALSTS<Geometry, Abc, Layout, Space, Type, regclass,
+                               WithStride>;
 }
 
-multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
-                           string Type, NVPTXRegClass regclass> {
-  defm _stride: WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 1>;
-  defm NAME:    WMMA_LOAD_ALSTSh<Abc, Layout, Space, Type, regclass, 0>;
+multiclass WMMA_LOAD_GALST<string Geometry, string Abc, string Layout,
+                           string Space, string Type, NVPTXRegClass regclass> {
+  defm _stride: WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 1>;
+  defm NAME:    WMMA_LOAD_GALSTSh<Geometry, Abc, Layout, Space, Type, regclass, 0>;
 }
 
-multiclass WMMA_LOAD_ALT<string Abc, string Layout,
-                        string Type, NVPTXRegClass regclass> {
-  defm _global: WMMA_LOAD_ALST<Abc, Layout, ".global", Type, regclass>;
-  defm _shared: WMMA_LOAD_ALST<Abc, Layout, ".shared", Type, regclass>;
-  defm NAME:    WMMA_LOAD_ALST<Abc, Layout,        "", Type, regclass>;
+multiclass WMMA_LOAD_GALT<string Geometry, string Abc, string Layout,
+                          string Type, NVPTXRegClass regclass> {
+  defm _global: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".global",
+                                Type, regclass>;
+  defm _shared: WMMA_LOAD_GALST<Geometry, Abc, Layout, ".shared",
+                                Type, regclass>;
+  defm NAME:    WMMA_LOAD_GALST<Geometry, Abc, Layout,        "",
+                                Type, regclass>;
 }
 
-multiclass WMMA_LOAD_AT<string Abc, string Type, NVPTXRegClass regclass> {
-  defm _row: WMMA_LOAD_ALT<Abc, "row", Type, regclass>;
-  defm _col: WMMA_LOAD_ALT<Abc, "col", Type, regclass>;
+multiclass WMMA_LOAD_GAT<string Geometry, string Abc,
+                         string Type, NVPTXRegClass regclass> {
+  defm _row: WMMA_LOAD_GALT<Geometry, Abc, "row", Type, regclass>;
+  defm _col: WMMA_LOAD_GALT<Geometry, Abc, "col", Type, regclass>;
 }
 
-defm INT_WMMA_LOAD_A: WMMA_LOAD_AT<"a", "f16", Float16x2Regs>;
-defm INT_WMMA_LOAD_B: WMMA_LOAD_AT<"b", "f16", Float16x2Regs>;
-defm INT_WMMA_LOAD_C_f16: WMMA_LOAD_AT<"c", "f16", Float16x2Regs>;
-defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;
+multiclass WMMA_LOAD_G<string Geometry> {
+  defm _load_a: WMMA_LOAD_GAT<Geometry, "a", "f16", Float16x2Regs>;
+  defm _load_b: WMMA_LOAD_GAT<Geometry, "b", "f16", Float16x2Regs>;
+  defm _load_c_f16: WMMA_LOAD_GAT<Geometry, "c", "f16", Float16x2Regs>;
+  defm _load_c_f32: WMMA_LOAD_GAT<Geometry, "c", "f32", Float32Regs>;
+}
+
+defm INT_WMMA_m16n16k16: WMMA_LOAD_G<"m16n16k16">;
 
 //
 // wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
 //
-class WMMA_STORE_D_LSTSO<string Layout, string Space,
-                         string Type, NVPTXRegClass regclass,
-                         bit WithStride, DAGOperand DstOp>
+class WMMA_STORE_D_GLSTSO<string Geometry, string Layout, string Space,
+                          string Type, NVPTXRegClass regclass,
+                          bit WithStride, DAGOperand DstOp>
   : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
-  PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA_STORE_D"
+  PatFrag IntrMatcher = !cast<PatFrag>("INT_WMMA"
+                                       # "_" # Geometry # "_store_d"
                                        # "_" # Type
                                        # "_" # Layout
                                        # !subst(".", "_", Space)
                                        # !if(WithStride,"_stride", "")
                                        # "_Intr");
-
-  dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3);
-  dag InsR47 = (ins regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7);
+  dag InsR03 = (ins DstOp:$src, regclass:$r0, regclass:$r1,
+                                regclass:$r2, regclass:$r3);
+  dag InsR47 = (ins regclass:$r4, regclass:$r5,
+                    regclass:$r6, regclass:$r7);
   dag InsR = !if(!eq(Type,"f16"), InsR03, !con(InsR03, InsR47));
   dag StrideArg = !if(WithStride, (ins Int32Regs:$ldm), (ins));
   dag Ins = !con(InsR, StrideArg);
@@ -7525,7 +7548,7 @@ class WMMA_STORE_D_LSTSO<string Layout, string Space,
   let InOperandList = Ins;
   let AsmString = "wmma.store.d.sync."
                   # Layout
-                  # ".m16n16k16"
+                  # "." # Geometry
                   # Space
                   # "." # Type
                   # " \t[$src],"
@@ -7537,11 +7560,13 @@ class WMMA_STORE_D_LSTSO<string Layout, string Space,
 
 }
 
-class WMMA_STORE_INTR_HELPER<string Layout, string Space,
+class WMMA_STORE_INTR_HELPER<string Geometry, string Layout, string Space,
                              string Type, bit WithStride>
                             : PatFrag <(ops),(ops)> {
   // Intrinsic that matches this instruction.
-  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_store_d"
+  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
+                                    # Geometry
+                                    # "_store_d"
                                     # "_" # Type
                                     # "_" # Layout
                                     # !if(WithStride, "_stride", ""));
@@ -7566,57 +7591,77 @@ class WMMA_STORE_INTR_HELPER<string Layout, string Space,
                       !if(!eq(Space, ".global"), match_global, match_generic));
 }
 
-multiclass WMMA_STORE_D_LSTS<string Layout, string Space,
-                            string Type, NVPTXRegClass regclass, bit WithStride> {
-  def _avar:   WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, imem>;
-  def _areg:   WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int32Regs>;
-  def _areg64: WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, Int64Regs>;
-  def _ari:    WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri>;
-  def _ari64:  WMMA_STORE_D_LSTSO<Layout, Space, Type, regclass, WithStride, MEMri64>;
+multiclass WMMA_STORE_D_GLSTS<string Geometry, string Layout, string Space,
+                              string Type, NVPTXRegClass regclass,
+                              bit WithStride> {
+  def _avar:   WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
+                                   WithStride, imem>;
+  def _areg:   WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
+                                   WithStride, Int32Regs>;
+  def _areg64: WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
+                                   WithStride, Int64Regs>;
+  def _ari:    WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
+                                   WithStride, MEMri>;
+  def _ari64:  WMMA_STORE_D_GLSTSO<Geometry, Layout, Space, Type, regclass,
+                                   WithStride, MEMri64>;
 }
 
-multiclass WMMA_STORE_D_LSTSh<string Layout, string Space,
-                              string Type, NVPTXRegClass regclass, bit WithStride> {
+multiclass WMMA_STORE_D_GLSTSh<string Geometry, string Layout, string Space,
+                               string Type, NVPTXRegClass regclass,
+                               bit WithStride> {
   // Define a PatFrag that matches appropriate intrinsic that loads from the
   // given address space.
-  def _Intr:    WMMA_STORE_INTR_HELPER<Layout, Space, Type, WithStride>;
-  defm NAME:    WMMA_STORE_D_LSTS<Layout, Space, Type, regclass, WithStride>;
+  def _Intr:    WMMA_STORE_INTR_HELPER<Geometry, Layout, Space, Type,
+                                       WithStride>;
+  defm NAME:    WMMA_STORE_D_GLSTS<Geometry, Layout, Space, Type, regclass,
+                                   WithStride>;
 }
 
-multiclass WMMA_STORE_D_LST<string Layout, string Space,
+multiclass WMMA_STORE_D_GLST<string Geometry, string Layout, string Space,
                              string Type, NVPTXRegClass regclass > {
-  defm _stride: WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 1>;
-  defm NAME:    WMMA_STORE_D_LSTSh<Layout, Space, Type, regclass, 0>;
+  defm _stride: WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 1>;
+  defm NAME:    WMMA_STORE_D_GLSTSh<Geometry, Layout, Space, Type, regclass, 0>;
 }
 
-multiclass WMMA_STORE_D_LT<string Layout,
+multiclass WMMA_STORE_D_GLT<string Geometry, string Layout,
                            string Type, NVPTXRegClass regclass> {
-  defm _global: WMMA_STORE_D_LST<Layout, ".global", Type, regclass>;
-  defm _shared: WMMA_STORE_D_LST<Layout, ".shared", Type, regclass>;
-  defm NAME:    WMMA_STORE_D_LST<Layout,        "", Type, regclass>;
+  defm _global: WMMA_STORE_D_GLST<Geometry, Layout, ".global", Type, regclass>;
+  defm _shared: WMMA_STORE_D_GLST<Geometry, Layout, ".shared", Type, regclass>;
+  defm NAME:    WMMA_STORE_D_GLST<Geometry, Layout,        "", Type, regclass>;
 }
 
-multiclass WMMA_STORE_D_T<string Type, NVPTXRegClass regclass> {
-  defm _row:    WMMA_STORE_D_LT<"row", Type, regclass>;
-  defm _col:    WMMA_STORE_D_LT<"col", Type, regclass>;
+multiclass WMMA_STORE_D_GT<string Geometry, string Type,
+                           NVPTXRegClass regclass> {
+  defm _row:    WMMA_STORE_D_GLT<Geometry, "row", Type, regclass>;
+  defm _col:    WMMA_STORE_D_GLT<Geometry, "col", Type, regclass>;
 }
 
-defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>;
-defm INT_WMMA_STORE_D_f32: WMMA_STORE_D_T<"f32", Float32Regs>;
+multiclass WMMA_STORE_D_G<string Geometry> {
+  defm _store_d_f16: WMMA_STORE_D_GT<Geometry, "f16", Float16x2Regs>;
+  defm _store_d_f32: WMMA_STORE_D_GT<Geometry, "f32", Float32Regs>;
+}
+
+// multiclass WMMA_STORE_D {
+//   defm _m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
+// }
+
+defm INT_WMMA_m16n16k16: WMMA_STORE_D_G<"m16n16k16">;
 
 // WMMA.MMA
-class WMMA_MMA_ABDCS<string ALayout, string BLayout,
+class WMMA_MMA_GABDCS<string Geometry, string ALayout, string BLayout,
                      string DType, NVPTXRegClass d_reg,
                      string CType, NVPTXRegClass c_reg,
                      NVPTXRegClass ab_reg,
                      string Satfinite = "">
   : EmptyNVPTXInst, Requires<[hasPTX60, hasSM70]> {
-  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_mma_sync_"
-                                    # ALayout
+  Intrinsic Intr = !cast<Intrinsic>("int_nvvm_wmma_"
+                                    # Geometry
+                                    # "_mma"
+                                    # "_" # ALayout
                                     # "_" # BLayout
                                     # "_" # DType
                                     # "_" # CType
-                                    # !subst(".","_",Satfinite));
+                                    # !subst(".", "_", Satfinite));
   dag Outs = !if(!eq(DType,"f16"),
                  (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
                  (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
@@ -7655,33 +7700,38 @@ class WMMA_MMA_ABDCS<string ALayout, string BLayout,
                         "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};");
 }
 
-multiclass WMMA_MMA_ABDC<string ALayout, string BLayout,
+multiclass WMMA_MMA_GABDC<string Geometry, string ALayout, string BLayout,
                          string DType, NVPTXRegClass d_reg,
                          string CType, NVPTXRegClass c_reg> {
-  def _satfinite: WMMA_MMA_ABDCS<ALayout, BLayout,
+  def _satfinite: WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
                                  DType, d_reg, CType, c_reg,
                                  Float16x2Regs, ".satfinite">;
-  def NAME:       WMMA_MMA_ABDCS<ALayout, BLayout,
+  def NAME:       WMMA_MMA_GABDCS<Geometry, ALayout, BLayout,
                                  DType, d_reg, CType, c_reg,
                                  Float16x2Regs>;
 }
 
-multiclass WMMA_MMA_ABD<string ALayout, string BLayout,
+multiclass WMMA_MMA_GABD<string Geometry, string ALayout, string BLayout,
                         string DType, NVPTXRegClass d_reg> {
-  defm _f16: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f16", Float16x2Regs>;
-  defm _f32: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f32", Float32Regs>;
+  defm _f16: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg,
+                            "f16", Float16x2Regs>;
+  defm _f32: WMMA_MMA_GABDC<Geometry, ALayout, BLayout, DType, d_reg,
+                            "f32", Float32Regs>;
 }
 
-multiclass WMMA_MMA_AB<string ALayout, string BLayout> {
-  defm _f16: WMMA_MMA_ABD<ALayout, BLayout, "f16", Float16x2Regs>;
-  defm _f32: WMMA_MMA_ABD<ALayout, BLayout, "f32", Float32Regs>;
+multiclass WMMA_MMA_GAB<string Geometry, string ALayout, string BLayout> {
+  defm _f16: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f16", Float16x2Regs>;
+  defm _f32: WMMA_MMA_GABD<Geometry, ALayout, BLayout, "f32", Float32Regs>;
 }
 
-multiclass WMMA_MMA_A<string ALayout> {
-  defm _col: WMMA_MMA_AB<ALayout, "col">;
-  defm _row: WMMA_MMA_AB<ALayout, "row">;
+multiclass WMMA_MMA_GA<string Geometry, string ALayout> {
+  defm _col: WMMA_MMA_GAB<Geometry, ALayout, "col">;
+  defm _row: WMMA_MMA_GAB<Geometry, ALayout, "row">;
 }
 
-defm INT_WMMA_MMA_col: WMMA_MMA_A<"col">;
-defm INT_WMMA_MMA_row: WMMA_MMA_A<"row">;
+multiclass WMMA_MMA_G<string Geometry> {
+  defm _col: WMMA_MMA_GA<Geometry, "col">;
+  defm _row: WMMA_MMA_GA<Geometry, "row">;
+}
 
+defm INT_WMMA_MMA_m16n16k16 : WMMA_MMA_G<"m16n16k16">;
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index d0fa90c..7dc2d2e 100644
@@ -38,29 +38,29 @@ check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
 
 def gen_wmma_load_tests():
   load_template = """
-declare ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
+declare ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
 
-; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}(
-define ${ret_ty} @test_wmma_load_${function_suffix}(i8 ${as}* %src ${extra_args}) {
-; CHECK wmma.load.${intrinsic_suffix}
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define ${ret_ty} @test_${function}(i8 ${as}* %src ${extra_args}) {
+; CHECK ${instruction}
 ; CHECK: {${check_result}}
 ; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
-  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src ${extra_args});
+  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src ${extra_args});
   ret ${ret_ty} %v0;
 }
 
-; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o(
-define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8 ${as}* %src ${extra_args}) {
-; CHECK wmma.load.${intrinsic_suffix}
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define ${ret_ty} @test_${function}_o(i8 ${as}* %src ${extra_args}) {
+; CHECK ${instruction}
 ; CHECK: {${check_result}}
 ; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
   %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
-  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8 ${as}* %src1 ${extra_args});
+  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1 ${extra_args});
   ret ${ret_ty} %v0;
 }
 """
-  suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
-  instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
+  intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
+  instruction_template = "wmma.load.${abc}.sync.${geom}.${layout}${space}.${itype}"
 
   for abc, layout, space, stride, itype in product(
       "abc",
@@ -76,16 +76,17 @@ define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8 ${as}* %src ${extra_arg
         "stride" : stride,
         "itype" : itype,
         "pspace" : get_pspace(space),
-        "as"     : "addrspace(%d)" % get_aspace(space)
+        "as"     : "addrspace(%d)" % get_aspace(space),
+        "geom"   : "m16n16k16",
     }
 
     if itype == "f32" and abc != "c":
       continue
 
     test_params = params
-    test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
-    test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
-    test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
+    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+    test_params["function"] = test_params["intrinsic"].replace(".","_")
+    test_params["instruction"] = Template(instruction_template).substitute(params)
     test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
     if abc == "c" :
       test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
@@ -107,29 +108,29 @@ def make_wmma_slice_args(itype, abcd, prefix="v"):
 
 def gen_wmma_store_tests():
   store_template = """
-declare void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args}${extra_args});
+declare void @${intrinsic}(i8 ${as}* %src, ${args}${extra_args});
 
-; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}(
-define void @test_wmma_store_${function_suffix}(i8 ${as}* %src, ${args}${extra_args}) {
-; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define void @test_${function}(i8 ${as}* %src, ${args}${extra_args}) {
+; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}
 ; CHECK: {${check_args}}
 ; CHECK: ${stride_pattern}
-  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src, ${args} ${extra_args});
+  call void @${intrinsic}(i8 ${as}* %src, ${args} ${extra_args});
   ret void
 }
 
-; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o(
-define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra_args}) {
-; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}+128]
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define void @test_${function}_o(i8 ${as}* %src, ${args}${extra_args}) {
+; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}+128]
 ; CHECK: ${check_args}
 ; CHECK: ${stride_pattern}
   %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
-  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8 ${as}* %src1, ${args}${extra_args});
+  call void @${intrinsic}(i8 ${as}* %src1, ${args}${extra_args});
   ret void
 }
 """
-  suffix_template = "${abc}.sync.${layout}.m16n16k16${stride}.${itype}.${pspace}"
-  instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
+  intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
+  instruction_template = "wmma.store.${abc}.sync.${geom}.${layout}${space}.${itype}"
 
   for abc, layout, space, stride, itype in product(
       "d",
@@ -145,13 +146,14 @@ define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra
         "stride" : stride,
         "itype" : itype,
         "pspace" : get_pspace(space),
-        "as"     : "addrspace(%d)" % get_aspace(space)
+        "as"     : "addrspace(%d)" % get_aspace(space),
+        "geom"   : "m16n16k16",
     }
 
     test_params = params
-    test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
-    test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
-    test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
+    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+    test_params["function"] = test_params["intrinsic"].replace(".","_")
+    test_params["instruction"] = Template(instruction_template).substitute(params)
     test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
     test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
     if stride:
@@ -166,23 +168,24 @@ define void @test_wmma_store_${function_suffix}_o(i8 ${as}* %src, ${args}${extra
 
 def gen_wmma_mma_tests():
   mma_template = """
-declare ${ret_ty} @llvm.nvvm.wmma.mma.sync.$intrinsic_suffix(
+declare ${ret_ty} @${intrinsic}(
         ${args});
 
-; CHECK-LABEL: .func {{.*}}test_wmma_mma_${function_suffix}(
-define ${ret_ty} @test_wmma_mma_${function_suffix}(
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define ${ret_ty} @test_${function}(
         ${args}) {
-; CHECK wmma.mma.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
+; CHECK ${instruction} {{.*}}[%rd{{[0-9+]}}
 ; CHECK ${check_d}
 ; CHECK ${check_ab}
 ; CHECK ${check_ab}
 ; CHECK ${check_c}
-  %r = call ${ret_ty} @llvm.nvvm.wmma.mma.sync.${intrinsic_suffix}(
+  %r = call ${ret_ty} @${intrinsic}(
         ${args});
   ret ${ret_ty} %r;
 }
 """
-  suffix_template = "${alayout}.${blayout}.m16n16k16.${dtype}.${ctype}${satf}"
+  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
+  instruction_template = "wmma.mma.sync.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
 
   for alayout, blayout, ctype, dtype, satf in product(
       ["row","col"],
@@ -196,12 +199,14 @@ define ${ret_ty} @test_wmma_mma_${function_suffix}(
         "blayout" : blayout,
         "ctype" : ctype,
         "dtype" : dtype,
-        "satf"  : satf
+        "satf"  : satf,
+        "geom"  : "m16n16k16",
     }
 
     test_params = params
-    test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
-    test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".", "_")
+    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+    test_params["function"] = test_params["intrinsic"].replace(".", "_")
+    test_params["instruction"] = Template(instruction_template).substitute(params)
     test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
     test_params["check_ab"] = check_f16_8
     test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8