[mlir][NVGPU] Verifiers for nvgpu.mma.sync Op
authorManish Gupta <manigupta@google.com>
Wed, 13 Jul 2022 17:53:52 +0000 (17:53 +0000)
committerThomas Raoux <thomasraoux@google.com>
Wed, 13 Jul 2022 18:57:07 +0000 (18:57 +0000)
- Adds verification for `nvgpu.mma.sync` op
- Adds tests to `mlir/test/Dialect/NVGPU/invalid.mlir`
- `nvgpu.mma.sync` verifier caught a bug and triggered a failure in m16n8k4_tf32_f32 variant in `mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir`
     - The output shape of vector holding thread-level accumulators was inconsistent  and fixed in this change

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D129400

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
mlir/test/Dialect/NVGPU/invalid.mlir

index 52338c0..ec0c18b 100644 (file)
@@ -81,7 +81,10 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix",
   }];
 }
 
-def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
+def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
+    NoSideEffect,
+    PredOpTrait<"matrixA and matrixB have same element type", TCopVTEtIsSameAs<0, 1>>,
+  ]> {
   let description = [{
   The `nvgpu.mma.sync` op represents the distributed form of a collective
   matrix-multiply-and-accumulate (mma) operation that is compatible with
@@ -112,6 +115,8 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
     `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
   }];
+
+  let hasVerifier = 1;
 }
 
 
index c31a168..ac937e0 100644 (file)
@@ -88,5 +88,103 @@ LogicalResult DeviceAsyncCopyOp::verify() {
   return success();
 }
 
+LogicalResult MmaSyncOp::verify() {
+
+  // Fundamental tensor core mma.sync op
+  // For F32 (TF32), F16, S8, and S4 data types fundamental tensor core
+  // operation is of shape: 8-by-8-by-128b. F64 is an exception. The
+  // verification for mma.sync covering various shapes and data types is based
+  // on the fundamental tensor core operionation.
+  constexpr int kThreads = 32; // 32 threads per warp
+  int64_t shapeM = 8;
+  int64_t shapeN = 8;
+  int64_t shapeK; // set based on data type (128b for all data types except F64)
+
+  // Number of elements A, B, and C per thread per fundamental tensor core tile
+  int64_t numElementA;    // set based on data type (32b except F64)
+  int64_t numElementB;    // set based on data type (32b except F64)
+  int64_t numElementC{2}; // two accumulator elements per fundamental tile
+
+  // nvgpu.mma.sync vector operands (per thread)
+  auto aVector = getMatrixA().getType().cast<VectorType>();
+  auto bVector = getMatrixB().getType().cast<VectorType>();
+  auto cVector = getMatrixC().getType().cast<VectorType>();
+
+  // vector shapes
+  ArrayRef<int64_t> aShape = aVector.getShape();
+  ArrayRef<int64_t> bShape = bVector.getShape();
+  ArrayRef<int64_t> cShape = cVector.getShape();
+
+  // vector element type
+  Type aType = aVector.getElementType();
+
+  // nvgpu.mma.sync shape (per 32 threads or per warp)
+  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
+  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
+  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();
+
+  if (aType.isF64()) {
+    // exception to 8-by-8-128b fundamental tensor core tile size
+    shapeK = 4;
+    numElementA = 1;
+    numElementB = 1;
+  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
+             aType.isInteger(8) || aType.isInteger(4)) {
+    // 8-by-8-128b fundamental tensor core tile size
+    int operandBitwidth = aType.getIntOrFloatBitWidth();
+    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
+    numElementA = 32 / operandBitwidth; // 32b wide operand A
+    numElementB = 32 / operandBitwidth; // 32b wide operand B
+  } else {
+    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
+                          "supported by nvgpu.mma.sync";
+  }
+
+  //
+  // Basic verification
+  //
+
+  // verify warp-wide size for vector a
+  if (aShape[0] * aShape[1] * kThreads != m * k)
+    return emitOpError() << "expected " << m * k
+                         << " warp-wide matrix A elements";
+
+  // verify warp-wide size for vector b
+  if (bShape[0] * bShape[1] * kThreads != k * n)
+    return emitOpError() << "expected " << k * n
+                         << " warp-wide matrix B elements";
+
+  // verify warp-wide size for vector c
+  if (cShape[0] * cShape[1] * kThreads != m * n)
+    return emitOpError() << "expected " << m * n
+                         << " warp-wide matrix C elements";
+
+  //
+  // Extended verification
+  //
+
+  // tiles of fundamental tensor core operations
+  int64_t mTile = m / shapeM;
+  int64_t nTile = n / shapeN;
+  int64_t kTile = k / shapeK;
+
+  // verify shape of aVector
+  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
+    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
+                         << " x " << numElementA << ")";
+
+  // verify shape of bVector
+  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
+    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
+                         << " x " << numElementB << ")";
+
+  // verify shape of cVector
+  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
+    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
+                         << " x " << numElementC << ")";
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
index 0d0f784..55b8df6 100644 (file)
@@ -205,7 +205,7 @@ func.func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) ->  vector<1x2xf16> {
 // -----
 
 // CHECK-LABEL: @m16n8k4_tf32
-func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<4x1xf32>) -> vector<4x1xf32> {  
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {  
   // The A, B operand should be bitcast to i32
   // CHECK: llvm.extractvalue
   // CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32  
@@ -219,17 +219,22 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
   // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>  
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<4x1xf32>) -> vector<4x1xf32>  
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][0]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][1]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][2]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][3]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK-COUNT-4: llvm.insertvalue {{.*}} : !llvm.array<4 x vector<1xf32>>
-  return %d : vector<4x1xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>  
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d00:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d01:%.+]] = llvm.insertelement {{%.+}}, [[d00]][{{.*}}] : vector<2xf32>
+
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>  
+  // CHECK-DAG: llvm.extractvalue [[d]][2] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][3] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d10:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d11:%.+]] = llvm.insertelement {{%.+}}, [[d10]][{{.*}}] : vector<2xf32>
+  
+  // CHECK-DAG: llvm.insertvalue [[d01]], {{%.+}}[0] : !llvm.array<2 x vector<2xf32>>
+  // CHECK-DAG: llvm.insertvalue [[d11]], {{%.+}}[1] : !llvm.array<2 x vector<2xf32>>   
+  return %d : vector<2x2xf32>
 }
 
 // -----
index 7a9acb4..6be9cda 100644 (file)
@@ -1,4 +1,73 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
+func.func @m16n8k16_fp16_vector_shape_a(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected 256 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_b(%arg0: vector<4x2xf16>, %arg1: vector<2x4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected 128 warp-wide matrix B elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x4xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_c(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x4xf16>) -> vector<2x4xf16> {
+  // expected-error @+1 {{expected 128 warp-wide matrix C elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x4xf16>) -> vector<2x4xf16>    
+  return %d : vector<2x4xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected matrix A to be shaped (4 x 2)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<2x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>    
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{expected 128 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>    
+  return %d : vector<2x2xf32>
+}
+// -----
+
+func.func @m16n8k8_fp32_vector_shape_a_extended(%arg0: vector<1x4xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{expected matrix A to be shaped (4 x 1)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<1x4xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>    
+  return %d : vector<2x2xf32>
+}
+// -----
+
+func.func @m8n8k4_fp64_vector_shape_a(%arg0: vector<1x2xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+  // expected-error @+1 {{expected 32 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x2xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>    
+  return %d : vector<1x2xf64>
+}
+// -----
+
+func.func @m8n8k4_fp64_vector_shape_c_extended(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<2x1xf64>) -> vector<2x1xf64> {
+  // expected-error @+1 {{expected matrix C to be shaped (1 x 2)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<2x1xf64>) -> vector<2x1xf64>    
+  return %d : vector<2x1xf64>
+}
+// -----
+
+func.func @m16n8k32_int8_vector_shape_b(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // expected-error @+1 {{expected 256 warp-wide matrix B elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+// -----
+
+func.func @m16n8k32_int32_datatype(%arg0: vector<4x4xi32>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // expected-error @+1 {{op failed to verify that matrixA and matrixB have same element type}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi32>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+// -----
 
 func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
   // expected-error @+1 {{destination memref must have memory space 3}}