}];
}
-def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
+def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
+ NoSideEffect,
+ PredOpTrait<"matrixA and matrixB have same element type", TCopVTEtIsSameAs<0, 1>>,
+ ]> {
let description = [{
The `nvgpu.mma.sync` op represents the distributed form of a collective
matrix-multiply-and-accumulate (mma) operation that is compatible with
`nvvm.mma.sync`.
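
Example (illustrative; the operand names and the f16 m16n8k16
configuration shown are one sketch of a shape accepted by the verifier):

```mlir
%d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]} :
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
```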
}];

let assemblyFormat = [{
`(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
}];
+
+ let hasVerifier = 1;
}
return success();
}
+LogicalResult MmaSyncOp::verify() {
+
+ // Fundamental tensor core mma.sync op
+ // For the F32 (TF32), F16, S8, and S4 data types, the fundamental tensor
+ // core operation has the shape 8-by-8-by-128b; F64 is an exception. The
+ // verification for mma.sync covering various shapes and data types is based
+ // on this fundamental tensor core operation.
+ constexpr int kThreads = 32; // 32 threads per warp
+ int64_t shapeM = 8;
+ int64_t shapeN = 8;
+ int64_t shapeK; // set based on data type (128b for all data types except F64)
+
+ // Number of elements of A, B, and C per thread per fundamental tensor core tile
+ int64_t numElementA; // set based on data type (32b except F64)
+ int64_t numElementB; // set based on data type (32b except F64)
+ int64_t numElementC{2}; // two accumulator elements per fundamental tile
+
+ // nvgpu.mma.sync vector operands (per thread)
+ auto aVector = getMatrixA().getType().cast<VectorType>();
+ auto bVector = getMatrixB().getType().cast<VectorType>();
+ auto cVector = getMatrixC().getType().cast<VectorType>();
+
+ // vector shapes
+ ArrayRef<int64_t> aShape = aVector.getShape();
+ ArrayRef<int64_t> bShape = bVector.getShape();
+ ArrayRef<int64_t> cShape = cVector.getShape();
+
+ // vector element type
+ Type aType = aVector.getElementType();
+
+ // nvgpu.mma.sync shape (per 32 threads or per warp)
+ int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
+ int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
+ int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();
+
+ if (aType.isF64()) {
+ // exception to the 8-by-8-by-128b fundamental tensor core tile size
+ shapeK = 4;
+ numElementA = 1;
+ numElementB = 1;
+ } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
+ aType.isInteger(8) || aType.isInteger(4)) {
+ // 8-by-8-by-128b fundamental tensor core tile size
+ int operandBitwidth = aType.getIntOrFloatBitWidth();
+ shapeK = 128 / operandBitwidth; // 128b wide shapeK
+ numElementA = 32 / operandBitwidth; // 32b wide operand A
+ numElementB = 32 / operandBitwidth; // 32b wide operand B
+ } else {
+ return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
+ "supported by nvgpu.mma.sync";
+ }
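+
+ // Worked example (illustrative): for f16 operands this computes
+ // shapeK = 128 / 16 = 8 and numElementA = numElementB = 32 / 16 = 2;
+ // for tf32 (f32) it computes shapeK = 4 and numElementA = numElementB = 1.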
+
+ //
+ // Basic verification
+ //
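+ // Each thread holds shape[0] * shape[1] elements of an operand vector, so
+ // summed over the kThreads lanes of a warp these must equal the warp-wide
+ // m * k (matrix A), k * n (matrix B), and m * n (matrix C) element counts.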
+
+ // verify warp-wide size for vector a
+ if (aShape[0] * aShape[1] * kThreads != m * k)
+ return emitOpError() << "expected " << m * k
+ << " warp-wide matrix A elements";
+
+ // verify warp-wide size for vector b
+ if (bShape[0] * bShape[1] * kThreads != k * n)
+ return emitOpError() << "expected " << k * n
+ << " warp-wide matrix B elements";
+
+ // verify warp-wide size for vector c
+ if (cShape[0] * cShape[1] * kThreads != m * n)
+ return emitOpError() << "expected " << m * n
+ << " warp-wide matrix C elements";
+
+ //
+ // Extended verification
+ //
+
+ // tiles of fundamental tensor core operations
+ int64_t mTile = m / shapeM;
+ int64_t nTile = n / shapeN;
+ int64_t kTile = k / shapeK;
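+
+ // Worked example (illustrative): f16 with mmaShape = [16, 8, 16] gives
+ // mTile = 2, nTile = 1, kTile = 2, so per thread matrix A must be
+ // vector<4x2xf16>, matrix B vector<2x2xf16>, and matrix C vector<2x2xf16>.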
+
+ // verify shape of aVector
+ if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
+ return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
+ << " x " << numElementA << ")";
+
+ // verify shape of bVector
+ if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
+ return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
+ << " x " << numElementB << ")";
+
+ // verify shape of cVector
+ if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
+ return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
+ << " x " << numElementC << ")";
+
+ return success();
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
// -----
// CHECK-LABEL: @m16n8k4_tf32
-func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<4x1xf32>) -> vector<4x1xf32> {
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// The A and B operands should be bitcast to i32
// CHECK: llvm.extractvalue
// CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
// CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
- %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<4x1xf32>) -> vector<4x1xf32>
- // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][0]
- // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
- // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][1]
- // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
- // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][2]
- // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
- // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][3]
- // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
- // CHECK-COUNT-4: llvm.insertvalue {{.*}} : !llvm.array<4 x vector<1xf32>>
- return %d : vector<4x1xf32>
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+ // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
+ // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
+ // CHECK: [[d00:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+ // CHECK: [[d01:%.+]] = llvm.insertelement {{%.+}}, [[d00]][{{.*}}] : vector<2xf32>
+
+ // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+ // CHECK-DAG: llvm.extractvalue [[d]][2] : !llvm.struct<(f32, f32, f32, f32)>
+ // CHECK-DAG: llvm.extractvalue [[d]][3] : !llvm.struct<(f32, f32, f32, f32)>
+ // CHECK: [[d10:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+ // CHECK: [[d11:%.+]] = llvm.insertelement {{%.+}}, [[d10]][{{.*}}] : vector<2xf32>
+
+ // CHECK-DAG: llvm.insertvalue [[d01]], {{%.+}}[0] : !llvm.array<2 x vector<2xf32>>
+ // CHECK-DAG: llvm.insertvalue [[d11]], {{%.+}}[1] : !llvm.array<2 x vector<2xf32>>
+ return %d : vector<2x2xf32>
}
// -----
// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+func.func @m16n8k16_fp16_vector_shape_a(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+ // expected-error @+1 {{expected 256 warp-wide matrix A elements}}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_b(%arg0: vector<4x2xf16>, %arg1: vector<2x4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+ // expected-error @+1 {{expected 128 warp-wide matrix B elements}}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_c(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x4xf16>) -> vector<2x4xf16> {
+ // expected-error @+1 {{expected 128 warp-wide matrix C elements}}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x4xf16>) -> vector<2x4xf16>
+ return %d : vector<2x4xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+ // expected-error @+1 {{expected matrix A to be shaped (4 x 2)}}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<2x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+ return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+ // expected-error @+1 {{expected 128 warp-wide matrix A elements}}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ return %d : vector<2x2xf32>
+}
+// -----
+
+func.func @m16n8k8_fp32_vector_shape_a_extended(%arg0: vector<1x4xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+ // expected-error @+1 {{expected matrix A to be shaped (4 x 1)}}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<1x4xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+ return %d : vector<2x2xf32>
+}
+// -----
+
+func.func @m8n8k4_fp64_vector_shape_a(%arg0: vector<1x2xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+ // expected-error @+1 {{expected 32 warp-wide matrix A elements}}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x2xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+ return %d : vector<1x2xf64>
+}
+// -----
+
+func.func @m8n8k4_fp64_vector_shape_c_extended(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<2x1xf64>) -> vector<2x1xf64> {
+ // expected-error @+1 {{expected matrix C to be shaped (1 x 2)}}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<2x1xf64>) -> vector<2x1xf64>
+ return %d : vector<2x1xf64>
+}
+// -----
+
+func.func @m16n8k32_int8_vector_shape_b(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+ // expected-error @+1 {{expected 256 warp-wide matrix B elements}}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ return %d : vector<2x2xi32>
+}
+// -----
+
+func.func @m16n8k32_int32_datatype(%arg0: vector<4x4xi32>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+ // expected-error @+1 {{op failed to verify that matrixA and matrixB have same element type}}
+ %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi32>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ return %d : vector<2x2xi32>
+}
+// -----
func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
// expected-error @+1 {{destination memref must have memory space 3}}