[mlir][gpu] Relax restriction on MMA store op to allow chain of mma ops.
authorthomasraoux <thomasraoux@google.com>
Thu, 27 May 2021 15:58:11 +0000 (08:58 -0700)
committerthomasraoux <thomasraoux@google.com>
Thu, 27 May 2021 16:13:51 +0000 (09:13 -0700)
In order to allow large matmul operations using the MMA ops we need to chain
operations this is not possible unless "DOp" and "COp" type have matching
layout so remove the "DOp" layout and force accumulator and result type to
match.
Added a test for the case where the MMA value is accumulated.

Differential Revision: https://reviews.llvm.org/D103023

mlir/include/mlir/Dialect/GPU/GPUDialect.h
mlir/include/mlir/Dialect/GPU/GPUOps.td
mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f16.mlir
mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f32.mlir

index 7832d48..c6d84f2 100644 (file)
@@ -85,9 +85,9 @@ struct MMAMatrixStorageType : public TypeStorage {
   Type elementType;
 
   /// MMA operand that this MMAMatrix holds. The general form of operation this
-  /// type supports is given by the equation D = (alpha*(A*B)) + (beta*C). This
-  /// field specifies which operand in the given equation is held by this type.
-  /// The valid values are "AOp", "BOp", "COp" and "DOp".
+  /// type supports is given by the equation C += A*B. This field specifies
+  /// which operand in the given equation is held by this type. The valid values
+  /// are "AOp", "BOp" and "COp".
   StringRef operand;
 };
 
@@ -112,13 +112,13 @@ struct MMAMatrixStorageType : public TypeStorage {
 /// and 8 f32s for f32 data type of MMAMatrix. Some other instances of usage
 /// are:-
 ///
-///   %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16,
-///   "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32,
-///                             "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
+///   %3 = gpu.subgroup_mma_compute %0, %1, %2 :
+///   !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
+///    -> !gpu.mma_matrix<16x16xf32, "COp">
 ///
 ///
 ///   gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16
-///           : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
+///           : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
 // TODO: consider moving this to ODS.
 class MMAMatrixType
     : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
@@ -154,9 +154,8 @@ public:
   Type getElementType() const;
 
   /// The general form of operation this type supports is given by the equation
-  /// D = (alpha*(A*B)) + (beta*C). This function returns which operand in the
-  /// given equation is held by this type. String returned can be one of"AOp",
-  /// "BOp", "COp" and "DOp".
+  /// C += A*B. This function returns which operand in the given equation is
+  /// held by this type. String returned can be one of"AOp", "BOp" and "COp".
   StringRef getOperand() const;
 };
 
index a29a22a..bc6f3f3 100644 (file)
@@ -966,7 +966,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
 
     ```mlir
     gpu.subgroup_mma_store_matrix %D, %sg[%i,%j] : { leadDimension = 32 : i32} :
-    !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
+    !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
     ```
   }];
 
@@ -982,7 +982,8 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix",
   let verifier = [{ return ::verify(*this); }];
 }
 
-def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
+def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute",
+   [NoSideEffect, AllTypesMatch<["opC", "res"]>]>{
 
   let summary = "GPU warp synchronous matrix multiply accumulate";
 
@@ -992,7 +993,7 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
 
     This operation takes three `!gpu.mma_matrix`s as arguments. All of them hold `A`,
      `B` and `C`operands for the mma operation. The operation performed is represented
-    as `D = A * B + C`. The op returns a `!gpu.mma_matrix` which contains the result of
+    as `C += A * B`. The op returns a `!gpu.mma_matrix` which contains the result of
     the operation held by the current thread.
 
     This op is meant to be used along with `gpu.subgroup_mma_store_matrix` and
@@ -1002,8 +1003,8 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
 
     ```mlir
     %D = gpu.subgroup_mma_compute_matrix %A, %B, %C :
-    !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">,
-    !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">>
+    -> !gpu.mma_matrix<16x16xf16, "COp">
     ```
   }];
 
@@ -1014,7 +1015,7 @@ def GPU_SubgroupMmaComputeOp : GPU_Op<"subgroup_mma_compute", []>{
   let results = (outs GPU_MMAMatrix:$res);
 
   let assemblyFormat = [{
-    $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB)`,` type($opC) `->` type($res)
+    $opA`,` $opB`,` $opC attr-dict `:` type($opA)`,` type($opB) `->` type($res)
   }];
 
   let verifier = [{ return ::verify(*this); }];
index 42e64e5..ea336dc 100644 (file)
@@ -135,11 +135,9 @@ struct LowerGpuOpsToNVVMOpsPass
       numElemsPerThreadF16["AOp"] = 8;
       numElemsPerThreadF16["BOp"] = 8;
       numElemsPerThreadF16["COp"] = 4;
-      numElemsPerThreadF16["DOp"] = 4;
       numElemsPerThreadF32["AOp"] = 8;
       numElemsPerThreadF32["BOp"] = 8;
       numElemsPerThreadF32["COp"] = 8;
-      numElemsPerThreadF32["DOp"] = 8;
       Type structToReturn;
       if (type.getElementType().isF16()) {
         // Number of f16's in 32-bit.
index c4458aa..5f5213a 100644 (file)
@@ -29,7 +29,6 @@ public:
     numHalfsInOpFrags[A] = 8;
     numHalfsInOpFrags[B] = 8;
     numHalfsInOpFrags[C] = 4;
-    numHalfsInOpFrags[D] = 4;
     i32Ty = IntegerType::get(context, 32);
     f16Ty = FloatType::getF16(context);
     f32Ty = FloatType::getF32(context);
@@ -63,7 +62,7 @@ public:
   SmallVector<unsigned, 4> numHalfsInOpFrags;
   /// Represents the operands of a MMA operation of the form D = (alpha*(A*B)) +
   /// (beta*C).
-  enum OperandMap { A, B, C, D };
+  enum OperandMap { A, B, C };
 };
 
 /// Checks if all the operands of the op being lowered are of LLVM Types. The
@@ -305,7 +304,7 @@ public:
             .getType()
             .cast<gpu::MMAMatrixType>()
             .getElementType() == f16Ty) {
-      for (unsigned i = 0, e = numHalfsInOpFrags[D]; i < e; ++i) {
+      for (unsigned i = 0, e = numHalfsInOpFrags[C]; i < e; ++i) {
         Value toUse = rewriter.create<LLVM::ExtractValueOp>(
             loc, f16x2Ty, operands[0], rewriter.getI32ArrayAttr(i));
         storeOpOperands.push_back(toUse);
index 1f081d8..39acf41 100644 (file)
@@ -64,8 +64,8 @@ MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
                       ArrayRef<int64_t> shape, Type elementType,
                       StringRef operand) {
   if (!operand.equals("AOp") && !operand.equals("BOp") &&
-      !operand.equals("COp") && !operand.equals("DOp"))
-    return emitError() << "operand expected to be one of AOp, BOp, COp or DOp";
+      !operand.equals("COp"))
+    return emitError() << "operand expected to be one of AOp, BOp or COp";
 
   if (shape.size() != 2)
     return emitError() << "MMAMatrixType must have exactly two dimensions";
@@ -1027,9 +1027,9 @@ static LogicalResult verify(SubgroupMmaStoreMatrixOp op) {
         "destination memorySpace of kGenericMemorySpace, "
         "kGlobalMemorySpace or kSharedMemorySpace only allowed");
 
-  if (!srcMatrixType.getOperand().equals("DOp"))
+  if (!srcMatrixType.getOperand().equals("COp"))
     return op.emitError(
-        "expected the operand matrix being stored to have 'DOp' operand type");
+        "expected the operand matrix being stored to have 'COp' operand type");
 
   return success();
 }
index c44d7f0..de5d0d3 100644 (file)
@@ -31,11 +31,11 @@ gpu.module @test_module {
 
   // CHECK-LABEL: func @gpu_wmma_store_op
   // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
-  func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+  func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
     %i = constant 16 : index
     %j = constant 16 : index
-    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3>
     // CHECK:  %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
     // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
     // CHECK:  %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
@@ -61,9 +61,9 @@ gpu.module @test_module {
 gpu.module @test_module {
 
   // CHECK-LABEL: func @gpu_wmma_mma_op
-  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
-  func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
-    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+  func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> (!gpu.mma_matrix<16x16xf16, "COp">) {
+    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
     // CHECK:  %[[A1:.*]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[A2:.*]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[A3:.*]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
@@ -84,8 +84,70 @@ gpu.module @test_module {
     // CHECK:  %[[C2:.*]] = llvm.extractvalue %[[C]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[C3:.*]] = llvm.extractvalue %[[C]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
     // CHECK:  %[[C4:.*]] = llvm.extractvalue %[[C]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-    // CHECK:  %{{.*}} = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
-    // CHECK:  llvm.return
-    return
+    // CHECK:  %[[RES:.*]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK:  llvm.return %[[RES]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    return %D : !gpu.mma_matrix<16x16xf16, "COp">
   }
 }
+
+// -----
+
+gpu.module @test_module {
+
+// CHECK-LABEL: func @gpu_wmma_mma_loop_op
+//       CHECK:   %[[C:.+]] = nvvm.wmma.m16n16k16.load.c.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   llvm.br ^bb1(%{{.*}}, %[[C]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+//       CHECK:  ^bb1(%{{.*}}: i32, %[[ACC:.+]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>):  // 2 preds: ^bb0, ^bb2
+//       CHECK:   llvm.cond_br %38, ^bb2, ^bb3
+//       CHECK:  ^bb2:  // pred: ^bb1
+//       CHECK:   %[[A:.+]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B:.+]] = nvvm.wmma.m16n16k16.load.b.f16.row.stride %{{.*}}, %{{.*}} : (!llvm.ptr<i32>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A0:.+]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A1:.+]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A2:.+]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A3:.+]] = llvm.extractvalue %[[A]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A4:.+]] = llvm.extractvalue %[[A]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A5:.+]] = llvm.extractvalue %[[A]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A6:.+]] = llvm.extractvalue %[[A]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[A7:.+]] = llvm.extractvalue %[[A]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B0:.+]] = llvm.extractvalue %[[B]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B1:.+]] = llvm.extractvalue %[[B]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B2:.+]] = llvm.extractvalue %[[B]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B3:.+]] = llvm.extractvalue %[[B]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B4:.+]] = llvm.extractvalue %[[B]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B5:.+]] = llvm.extractvalue %[[B]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B6:.+]] = llvm.extractvalue %[[B]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[B7:.+]] = llvm.extractvalue %[[B]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[ACC0:.+]] = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[ACC1:.+]] = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[ACC2:.+]] = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[ACC3:.+]] = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %[[ACC_MUL:.+]] = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[ACC0]], %[[ACC1]], %[[ACC2]], %[[ACC3]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   llvm.br ^bb1(%{{.*}}, %[[ACC_MUL]] : i32, !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>)
+//       CHECK:  ^bb3:  // pred: ^bb1
+//       CHECK:   %87 = llvm.extractvalue %[[ACC]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %88 = llvm.extractvalue %[[ACC]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %89 = llvm.extractvalue %[[ACC]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   %90 = llvm.extractvalue %[[ACC]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+//       CHECK:   nvvm.wmma.m16n16k16.store.d.f16.row.stride %86, %87, %88, %89, %90, %79 : !llvm.ptr<i32>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+
+  func @gpu_wmma_mma_loop_op(%arg0: memref<128x128xf16>, %arg1: memref<128x128xf16>, %arg2: memref<128x128xf16>) {
+      %c0 = constant 0 : index
+      %c128 = constant 128 : index
+      %c32 = constant 32 : index
+      %0 = gpu.subgroup_mma_load_matrix %arg2[%c0, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+      br ^bb1(%c0, %0 : index, !gpu.mma_matrix<16x16xf16, "COp">)
+    ^bb1(%1: index, %2: !gpu.mma_matrix<16x16xf16, "COp">):  // 2 preds: ^bb0, ^bb2
+      %3 = cmpi slt, %1, %c128 : index
+      cond_br %3, ^bb2, ^bb3
+    ^bb2:  // pred: ^bb1
+      %4 = gpu.subgroup_mma_load_matrix %arg0[%c0, %1] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+      %5 = gpu.subgroup_mma_load_matrix %arg1[%1, %c0] {leadDimension = 128 : index} : memref<128x128xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+      %6 = gpu.subgroup_mma_compute %4, %5, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+      %7 = addi %1, %c32 : index
+      br ^bb1(%7, %6 : index, !gpu.mma_matrix<16x16xf16, "COp">)
+    ^bb3:  // pred: ^bb1
+      gpu.subgroup_mma_store_matrix %2, %arg2[%c0, %c0] {leadDimension = 128 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf16>
+      return
+    }
+}
index 58eca3b..f399ddd 100644 (file)
@@ -474,7 +474,7 @@ func @mmamatrix_invalid_shape(){
 func @mmamatrix_operand_type(){
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
     %i = constant 16 : index
-    // expected-error @+1 {{operand expected to be one of AOp, BOp, COp or DOp}}
+    // expected-error @+1 {{operand expected to be one of AOp, BOp or COp}}
     %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "EOp">
     return
 }
@@ -513,35 +513,25 @@ func @mmaLoadOp_invalid_mem_space(){
 
 // -----
 
-func @mmaLoadOp_operand_type(){
-    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
-    %i = constant 16 : index
-    // expected-error @+1 {{only AOp, BOp and COp can be loaded}}
-    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "DOp">
-    return
-}
-
-// -----
-
 #layout_map_col_major = affine_map<(i, j) -> (j, i)>
 
-func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, #layout_map_col_major, 3>
     %i = constant 16 : index
     %j = constant 16 : index
     // expected-error @+1 {{expected identity layout map for destination memref}}
-    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16,#layout_map_col_major, 3>
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,#layout_map_col_major, 3>
     return
 }
 
 // -----
 
-func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5>
     %i = constant 16 : index
     %j = constant 16 : index
     // expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}}
-    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 5>
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 5>
     return
 }
 
@@ -551,7 +541,7 @@ func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp"
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
     %i = constant 16 : index
     %j = constant 16 : index
-    // expected-error @+1 {{expected the operand matrix being stored to have 'DOp' operand type}}
+    // expected-error @+1 {{expected the operand matrix being stored to have 'COp' operand type}}
     gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "AOp">, memref<32x32xf16, 3>
     return
 }
@@ -560,7 +550,7 @@ func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp"
 
 func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     // expected-error @+1 {{operands must be in the order AOp, BOp, COp}}
-    %D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    %D = gpu.subgroup_mma_compute %B, %A, %C : !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "AOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
     return
 }
 
@@ -568,6 +558,6 @@ func @wmmaMmaOp_invalid_operand_order(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B
 
 func @wmmaMmaOp_invalid_operand_shapes(%A : !gpu.mma_matrix<16x32xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
     // expected-error @+1 {{operand shapes do not satisfy matmul constraints}}
-    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
     return
 }
index 9e6802a..3dc911b 100644 (file)
@@ -82,9 +82,9 @@ module attributes {gpu.container_module}  {
       %1 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {operand = "BOp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
       %2 = gpu.subgroup_mma_load_matrix %arg22[%c0, %c0] {operand = "COp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
 
-      %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+      %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
 
-      gpu.subgroup_mma_store_matrix %3, %arg0[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "DOp">, memref<16x16xf16>
+      gpu.subgroup_mma_store_matrix %3, %arg0[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
 
       gpu.return
     }
index 5ca0147..ba948ea 100644 (file)
@@ -73,9 +73,9 @@ module attributes {gpu.container_module}  {
       %1 = gpu.subgroup_mma_load_matrix %arg0[%c0, %c0] {operand = "BOp", leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
       %2 = gpu.subgroup_mma_load_matrix %arg22[%c0, %c0] {operand = "COp", leadDimension = 16 : index} : memref<16x16xf32> -> !gpu.mma_matrix<16x16xf32, "COp">
 
-      %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "DOp">
+      %3 = gpu.subgroup_mma_compute %0, %1, %2 : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
 
-      gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "DOp">, memref<16x16xf32>
+      gpu.subgroup_mma_store_matrix %3, %arg22[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
 
       gpu.return
     }