let hasVerifier = 1;
}
+def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
+ Results<(outs AnyType:$res)>,
+ Arguments<(ins LLVM_AnyPointer: $ptr, I32Attr:$num, MMALayoutAttr:$layout)> {
+
+ let summary = "cooperative matrix load";
+
+ string llvmBuilder = [{
+ auto operands = moduleTranslation.lookupValues(opInst.getOperands());
+ auto intId = getLdMatrixIntrinsicId($layout, $num);
+ $res = createIntrinsicCall(builder, intId, operands, {operands[0]->getType()});
+ }];
+
+ string baseDescription = [{
+ The `nvvm.ldmatrix` operation collectively loads one or more matrices across
+ all threads in a warp from the location indicated by the address operand
+ `ptr` from shared memory.
+
+ The attribute `num` indicates how many 8x8 16-bit matrices are to be loaded.
+
+ All the threads in the warp must execute the same ldmatrix operations.
+
+ Each row of 8 elements needs to be consecutive in memory. Each lane of the
+ warp contains the start address of a row of 8 elements laid out as below:
+
+ ```
+ num | lane 0--7 | Threads 8--15 | Threads 16--31
+ 1 | addr0--addr7 | |
+ 2 | addr0--addr7 | addr8--addr15 |
+ 4 | addr0--addr7 | addr8--addr15 | addr16--addr31
+ ```
+
+ Example:
+ ```mlir
+ %l1 = nvvm.ldmatrix %ptr {num = 1 : i32, layout = #nvvm.mma_layout<row>} :
+ (!llvm.ptr<i32, 3>) -> i32
+ %l2 = nvvm.ldmatrix %ptr {num = 4 : i32, layout = #nvvm.mma_layout<row>} :
+ (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ ```
+ }];
+
+ let assemblyFormat = "$ptr attr-dict `:` functional-type($ptr, $res)";
+ let hasVerifier = 1;
+}
+
#endif // NVVMIR_OPS
return success();
}
+LogicalResult NVVM::LdMatrixOp::verify() {
+ unsigned addressSpace =
+ ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
+ if (addressSpace != 3)
+ return emitOpError("expected source pointer in memory space 3");
+
+ if (num() != 1 && num() != 2 && num() != 4)
+ return emitOpError("expected num attribute to be 1, 2 or 4");
+
+ Type i32 = IntegerType::get(getContext(), 32);
+ if (num() == 1 && getType() != i32)
+ return emitOpError("expected destination type is i32");
+ if (num() == 2 || num() == 4) {
+ Type dstType = LLVM::LLVMStructType::getLiteral(
+ getContext(), SmallVector<Type>(num(), i32));
+ if (getType() != dstType)
+ return emitOpError("expected destination type is a structure of ")
+ << num() << " elements of type i32";
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
llvm_unreachable("unknown shuffle kind");
}
+/// Return the intrinsic ID associated with ldmatrix for the given paramters.
+static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
+ int32_t num) {
+ if (layout == NVVM::MMALayout::col) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
+ case 2:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
+ case 4:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
+ default:
+ llvm_unreachable("unsupported number of matrix");
+ }
+
+ } else {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
+ case 2:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
+ case 4:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
+ default:
+ llvm_unreachable("unsupported number of matrix");
+ }
+ }
+}
+
namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the NVVM dialect to LLVM IR.
// -----
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected source pointer in memory space 3}}
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32>) -> i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected num attribute to be 1, 2 or 4}}
+ %l = nvvm.ldmatrix %arg0 {num = 3 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
+ llvm.return
+}
+
+// -----
+
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is i32}}
+ %l = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32)>
+ llvm.return
+}
+
+// -----
+
+llvm.func @wmmald_matrix(%arg0: !llvm.ptr<i32, 3>) {
+ // expected-error@+1 {{'nvvm.ldmatrix' op expected destination type is a structure of 4 elements of type i32}}
+ %l = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+ llvm.return
+}
+
+// -----
+
llvm.func @caller() {
// expected-error @below {{expected function call to produce a value}}
llvm.call @callee() : () -> ()
llvm.return
}
+// CHECK-LABEL: llvm.func @ld_matrix
+llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
+ // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} : (!llvm.ptr<i32, 3>) -> i32
+ %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
+ // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 2 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+ %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: nvvm.ldmatrix %{{.*}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ llvm.return
+}
// -----
// expected-error@below {{attribute attached to unexpected op}}
llvm.return
}
+// CHECK-LABEL: @ld_matrix(
+llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
+ // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+ %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
+ // CHECK: call { i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+ %l2 = nvvm.ldmatrix %arg0 {num = 2 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32)>
+ // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p3i32(i32 addrspace(3)* %{{.*}})
+ %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ llvm.return
+}
+
// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
llvm.func @kernel_func() attributes {nvvm.kernel} {