[mlir][nvvm] Fix bug in ldmatrix intrinsic conversion
authorThomas Raoux <thomasraoux@google.com>
Tue, 15 Mar 2022 04:55:07 +0000 (04:55 +0000)
committerThomas Raoux <thomasraoux@google.com>
Tue, 15 Mar 2022 05:04:09 +0000 (05:04 +0000)
The ldmatrix intrinsic trans option was inverted.

Bug found by @christopherbate!

Differential Revision: https://reviews.llvm.org/D121666

mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
mlir/test/Target/LLVMIR/nvvmir.mlir

index 88f8af0eef136a1c2cf594330391b7db0b3355f4..f39b0d337811554cda8065252f500977340cf93b 100644 (file)
@@ -67,7 +67,7 @@ static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
 /// Return the intrinsic ID associated with ldmatrix for the given paramters.
 static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
                                                   int32_t num) {
-  if (layout == NVVM::MMALayout::col) {
+  if (layout == NVVM::MMALayout::row) {
     switch (num) {
     case 1:
       return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
index ef7a1f94105982751ff4f72489b90759eab3d639..b62913b7c273793e171377d055672a739bc3df23 100644 (file)
@@ -178,12 +178,18 @@ llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
 
 // CHECK-LABEL: @ld_matrix(
 llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
-  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3i32(i32 addrspace(3)* %{{.*}})
   %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
-  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p3i32(i32 addrspace(3)* %{{.*}})
   %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
-  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3i32(i32 addrspace(3)* %{{.*}})
   %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  %l1t = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<i32, 3>) -> i32
+  // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  %l2t = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+  // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+  %l4t = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<col>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
   llvm.return
 }