[mlir][spirv] Add conversion from GPU WMMA ops to SPIRV Cooperative matrix
authorNirvedh Meshram <nirvedh@nod-labs.com>
Fri, 21 Oct 2022 00:32:17 +0000 (17:32 -0700)
committerNirvedh Meshram <nirvedh@nod-labs.com>
Sun, 23 Oct 2022 01:29:40 +0000 (18:29 -0700)
Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D136521

mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt
mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp [new file with mode: 0644]
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir [new file with mode: 0644]
mlir/test/Dialect/GPU/invalid.mlir

index 132866f..3c32815 100644 (file)
 #ifndef MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
 #define MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
 
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 class SPIRVTypeConverter;
 
+namespace gpu {
+class MMAMatrixType;
+} // namespace gpu
+
 /// Appends to a pattern list additional patterns for translating GPU Ops to
 /// SPIR-V ops. For a gpu.func to be converted, it should have a
 /// spirv.entry_point_abi attribute.
 void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                 RewritePatternSet &patterns);
+
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV.
+void populateGpuWMMAToSPIRVConversionPatterns(SPIRVTypeConverter &typeConverter,
+                                              RewritePatternSet &patterns);
+
+spirv::CooperativeMatrixNVType convertMMAToSPIRVType(gpu::MMAMatrixType type);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
index f5dedb0..0cf2029 100644 (file)
@@ -1246,19 +1246,35 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
   }];
 }
 
-def GPU_ElementwiseOpAdd  : I32EnumAttrCase<"ADDF", 0, "addf">;
-def GPU_ElementwiseOpMul  : I32EnumAttrCase<"MULF", 1, "mulf">;
-def GPU_ElementwiseOpMaxF : I32EnumAttrCase<"MAXF", 2, "maxf">;
-def GPU_ElementwiseOpMinF : I32EnumAttrCase<"MINF", 3, "minf">;
-def GPU_ElementwiseOpDivF : I32EnumAttrCase<"DIVF", 4, "divf">;
+def GPU_ElementwiseOpAddF  : I32EnumAttrCase<"ADDF", 0, "addf">;
+def GPU_ElementwiseOpMulF  : I32EnumAttrCase<"MULF", 1, "mulf">;
+def GPU_ElementwiseOpSUBF  : I32EnumAttrCase<"SUBF", 2, "subf">;
+def GPU_ElementwiseOpMaxF : I32EnumAttrCase<"MAXF", 3, "maxf">;
+def GPU_ElementwiseOpMinF : I32EnumAttrCase<"MINF", 4, "minf">;
+def GPU_ElementwiseOpDivF : I32EnumAttrCase<"DIVF", 5, "divf">;
+def GPU_ElementwiseOpAddI  : I32EnumAttrCase<"ADDI", 6, "addi">;
+def GPU_ElementwiseOpMulI  : I32EnumAttrCase<"MULI", 7, "muli">;
+def GPU_ElementwiseOpSUBI  : I32EnumAttrCase<"SUBI", 8, "subi">;
+def GPU_ElementwiseOpDivS : I32EnumAttrCase<"DIVS", 9, "divs">;
+def GPU_ElementwiseOpDivU : I32EnumAttrCase<"DIVU", 10, "divu">;
+def GPU_ElementwiseOpNEGF : I32EnumAttrCase<"NEGATEF", 11, "negatef">;
+def GPU_ElementwiseOpNEGS : I32EnumAttrCase<"NEGATES", 12, "negates">;
 
 def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
   "elementwise operation to apply to mma matrix", [
-    GPU_ElementwiseOpAdd,
-    GPU_ElementwiseOpMul,
+    GPU_ElementwiseOpAddF,
+    GPU_ElementwiseOpMulF,
+    GPU_ElementwiseOpSUBF,
     GPU_ElementwiseOpMaxF,
     GPU_ElementwiseOpMinF,
-    GPU_ElementwiseOpDivF
+    GPU_ElementwiseOpDivF,
+    GPU_ElementwiseOpAddI,
+    GPU_ElementwiseOpMulI,
+    GPU_ElementwiseOpSUBI,
+    GPU_ElementwiseOpDivS,
+    GPU_ElementwiseOpDivU,
+    GPU_ElementwiseOpNEGF,
+    GPU_ElementwiseOpNEGS
   ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::gpu";
index 50b5c49..b124625 100644 (file)
@@ -42,7 +42,18 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
       // Operand type same as result type.
       SPIRV_UnaryOp<mnemonic, type, type,
                    !listconcat(traits,
-                               [Pure, SameOperandsAndResultType])>;
+                               [Pure, SameOperandsAndResultType])> {
+  // In addition to normal types arithmetic instructions can support cooperative
+  // matrix.
+  let arguments = (ins
+    SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand
+  );
+
+  let results = (outs
+    SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$result
+  );
+  let assemblyFormat = "operands attr-dict `:` type($result)";
+  }
 
 // -----
 
index 8ecbdf2..963a718 100644 (file)
@@ -311,8 +311,9 @@ static Value createScalarOp(OpBuilder &builder, Location loc,
   case gpu::MMAElementwiseOp::MINF:
     return createMinMaxF(builder, loc, operands[0], operands[1],
                          /*isMin=*/true);
+  default:
+    llvm_unreachable("unknown op");
   }
-  llvm_unreachable("unknown op");
 }
 
 /// Convert GPU MMA elementwise ops to extract + op + insert.
index 7ee339a..3b3156d 100644 (file)
@@ -1,6 +1,7 @@
 add_mlir_conversion_library(MLIRGPUToSPIRV
   GPUToSPIRV.cpp
   GPUToSPIRVPass.cpp
+  WmmaOpsToSPIRV.cpp
 
   DEPENDS
   MLIRConversionPassIncGen
index c425346..f1c4e32 100644 (file)
@@ -86,9 +86,12 @@ void GPUToSPIRVPass::runOnOperation() {
         SPIRVConversionTarget::get(targetAttr);
 
     SPIRVTypeConverter typeConverter(targetAttr);
+    typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
+      return convertMMAToSPIRVType(type);
+    });
     RewritePatternSet patterns(context);
     populateGPUToSPIRVPatterns(typeConverter, patterns);
-
+    populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);
     // TODO: Change SPIR-V conversion to be progressive and remove the following
     // patterns.
     mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp
new file mode 100644 (file)
index 0000000..c890d41
--- /dev/null
@@ -0,0 +1,203 @@
+//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
+// SPIRV Dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
+#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+// See SPV_NV_cooperative_matrix for supported element wise ops.
+static void createElementWiseOp(ConversionPatternRewriter &builder,
+                                gpu::SubgroupMmaElementwiseOp op,
+                                spirv::CooperativeMatrixNVType coopType,
+                                ValueRange operands) {
+  switch (op.getOpType()) {
+  case gpu::MMAElementwiseOp::ADDF:
+    builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::ADDI:
+    builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::SUBF:
+    builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::SUBI:
+    builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::DIVF:
+    builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::DIVS:
+    builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::DIVU:
+    builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::NEGATEF:
+    builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
+    return;
+  case gpu::MMAElementwiseOp::NEGATES:
+    builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
+    return;
+  default:
+    llvm_unreachable("unknown op");
+  }
+}
+
+namespace {
+
+/// This class implements the conversion of GPU MMA loadOp to
+/// CooperativeMatrixLoad op in the SPIRV dialect.
+struct WmmaLoadOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = subgroupMmaLoadMatrixOp->getLoc();
+    gpu::MMAMatrixType retType =
+        subgroupMmaLoadMatrixOp.getRes().getType().cast<gpu::MMAMatrixType>();
+    auto memrefType =
+        subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast<MemRefType>();
+    Value bufferPtr = spirv::getElementPtr(
+        *getTypeConverter<SPIRVTypeConverter>(), memrefType,
+        adaptor.getSrcMemref(), adaptor.getIndices(), loc, rewriter);
+    auto coopType = convertMMAToSPIRVType(retType);
+    int64_t stride = subgroupMmaLoadMatrixOp.getLeadDimension().getSExtValue();
+    auto i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+    auto coloumnMajor = rewriter.create<spirv::ConstantOp>(
+        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
+        subgroupMmaLoadMatrixOp, coopType, bufferPtr, strideValue, coloumnMajor,
+        spirv::MemoryAccessAttr());
+    return success();
+  }
+};
+
+/// This class implements the conversion of GPU MMA StoreOp to
+/// CooperativeMatrixStore op in the SPIRV dialect.
+struct WmmaStoreOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = subgroupMmaStoreMatrixOp->getLoc();
+    auto memrefType =
+        subgroupMmaStoreMatrixOp.getDstMemref().getType().cast<MemRefType>();
+    Value bufferPtr = spirv::getElementPtr(
+        *getTypeConverter<SPIRVTypeConverter>(), memrefType,
+        adaptor.getDstMemref(), adaptor.getIndices(), loc, rewriter);
+    int64_t stride = subgroupMmaStoreMatrixOp.getLeadDimension().getSExtValue();
+    auto i32Type = rewriter.getI32Type();
+    auto strideValue = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, stride));
+    auto coloumnMajor = rewriter.create<spirv::ConstantOp>(
+        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixStoreOp>(
+        subgroupMmaStoreMatrixOp, bufferPtr, adaptor.getSrc(), strideValue,
+        coloumnMajor, spirv::MemoryAccessAttr());
+    return success();
+  }
+};
+
+/// This class implements the conversion of GPU MMA Compute to
+/// CooperativeMatrixMulAdd op in the SPIRV dialect.
+struct WmmaMmaOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixMulAddOp>(
+        subgroupMmaComputeOp, adaptor.getOpC().getType(), adaptor.getOpA(),
+        adaptor.getOpB(), adaptor.getOpC());
+    return success();
+  }
+};
+
+/// Convert GPU MMA ConstantMatrixOp to constant SPIR-V cooperative matrix ops.
+struct WmmaConstantOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp subgroupMmaConstantMatrixOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value cst = adaptor.getOperands()[0];
+    auto coopType = convertMMAToSPIRVType(
+        subgroupMmaConstantMatrixOp.getType().cast<gpu::MMAMatrixType>());
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+        subgroupMmaConstantMatrixOp, coopType, cst);
+    return success();
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops.
+struct WmmaElementwiseOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // All operands should be of cooperative matrix types.
+    for (Value operand : adaptor.getOperands()) {
+      if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+        return failure();
+    }
+    auto coopType = convertMMAToSPIRVType(
+        subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    createElementWiseOp(rewriter, subgroupMmaElementwiseOp, coopType,
+                        adaptor.getOperands());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
+mlir::spirv::CooperativeMatrixNVType
+mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
+  ArrayRef<int64_t> retTypeShape = type.getShape();
+  Type elementType = type.getElementType();
+  return spirv::CooperativeMatrixNVType::get(
+      elementType, spirv::Scope::Subgroup, retTypeShape[0], retTypeShape[1]);
+}
+
+void mlir::populateGpuWMMAToSPIRVConversionPatterns(
+    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
+               WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaElementwiseOpToSPIRVLowering>(converter,
+                                                 patterns.getContext());
+}
\ No newline at end of file
index 4f5cfb6..3eee608 100644 (file)
@@ -1192,18 +1192,11 @@ LogicalResult SubgroupMmaLoadMatrixOp::verify() {
   auto resMatrixType = resType.cast<gpu::MMAMatrixType>();
   auto operand = resMatrixType.getOperand();
   auto srcMemrefType = srcType.cast<MemRefType>();
-  auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
 
   if (!isLastMemrefDimUnitStride(srcMemrefType))
     return emitError(
         "expected source memref most minor dim must have unit stride");
 
-  if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
-      srcMemSpace != kGlobalMemorySpace)
-    return emitError(
-        "source memorySpace kGenericMemorySpace, kSharedMemorySpace or "
-        "kGlobalMemorySpace only allowed");
-
   if (!operand.equals("AOp") && !operand.equals("BOp") &&
       !operand.equals("COp"))
     return emitError("only AOp, BOp and COp can be loaded");
@@ -1220,17 +1213,11 @@ LogicalResult SubgroupMmaStoreMatrixOp::verify() {
   auto dstType = getDstMemref().getType();
   auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
   auto dstMemrefType = dstType.cast<MemRefType>();
-  auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
 
   if (!isLastMemrefDimUnitStride(dstMemrefType))
     return emitError(
         "expected destination memref most minor dim must have unit stride");
 
-  if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
-      dstMemSpace != kGlobalMemorySpace)
-    return emitError("destination memorySpace of kGenericMemorySpace, "
-                     "kGlobalMemorySpace or kSharedMemorySpace only allowed");
-
   if (!srcMatrixType.getOperand().equals("COp"))
     return emitError(
         "expected the operand matrix being stored to have 'COp' operand type");
diff --git a/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir b/mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv.mlir
new file mode 100644 (file)
index 0000000..9f2a27c
--- /dev/null
@@ -0,0 +1,110 @@
+// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -split-input-file -verify-diagnostics %s | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
+    // CHECK-LABEL: spirv.func @gpu_wmma_load_op
+    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
+    gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+      %i = arith.constant 16 : index
+      %j = arith.constant 16 : index
+      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} :  !spirv.ptr<f32, StorageBuffer> as !spirv.coopmatrix<16x16xf16, Subgroup>
+      %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: spirv.Return
+      gpu.return
+    }
+  }
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
+    // CHECK-LABEL: spirv.func @gpu_wmma_store_op
+    // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
+    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
+    gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+      %i = arith.constant 16 : index
+      %j = arith.constant 16 : index
+      //  CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
+      gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16,  #spirv.storage_class<StorageBuffer>>
+      // CHECK: spirv.Return
+      gpu.return
+    }
+  }
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
+    // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
+    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 2)>})
+    // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
+    gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+      // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
+      %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: spirv.Return
+      gpu.return
+    }
+  }
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
+    // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
+    gpu.func @gpu_wmma_constant_op() kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+      // CHECK: {{%.*}} = spirv.Constant
+      %cst = arith.constant 1.0 : f16
+      // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup>
+      %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: spirv.Return
+      gpu.return
+    }
+  }
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+  gpu.module @kernels {
+    // CHECK:       spirv.module @{{.*}} Logical GLSL450 {
+    // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op
+    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+    // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
+    gpu.func @gpu_wmma_elementwise_op(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+      // CHECK:  {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+      %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK:  {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+      %D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK:  {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+      %E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+      // CHECK: spirv.Return
+      gpu.return
+    }
+  }
+}
\ No newline at end of file
index b029d2f..6390de3 100644 (file)
@@ -515,16 +515,6 @@ func.func @mmaLoadOp_identity_layout(){
 
 // -----
 
-func.func @mmaLoadOp_invalid_mem_space(){
-    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 5>
-    %i = arith.constant 16 : index
-    // expected-error @+1 {{source memorySpace kGenericMemorySpace, kSharedMemorySpace or kGlobalMemorySpace only allowed}}
-    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 5> -> !gpu.mma_matrix<16x16xf16, "AOp">
-    return
-}
-
-// -----
-
 #layout_map_col_major = affine_map<(i, j) -> (j, i)>
 
 func.func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
@@ -538,17 +528,6 @@ func.func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) ->
 
 // -----
 
-func.func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
-    %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5>
-    %i = arith.constant 16 : index
-    %j = arith.constant 16 : index
-    // expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}}
-    gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 5>
-    return
-}
-
-// -----
-
 func.func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp">) -> () {
     %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
     %i = arith.constant 16 : index