#ifndef MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
#define MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class SPIRVTypeConverter;
+namespace gpu {
+class MMAMatrixType;
+} // namespace gpu
+
/// Appends to a pattern list additional patterns for translating GPU Ops to
/// SPIR-V ops. For a gpu.func to be converted, it should have a
/// spirv.entry_point_abi attribute.
void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns);
+
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV.
+void populateGpuWMMAToSPIRVConversionPatterns(SPIRVTypeConverter &typeConverter,
+ RewritePatternSet &patterns);
+/// Returns the spirv::CooperativeMatrixNVType corresponding to `type`.
+spirv::CooperativeMatrixNVType convertMMAToSPIRVType(gpu::MMAMatrixType type);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
}];
}
-def GPU_ElementwiseOpAdd : I32EnumAttrCase<"ADDF", 0, "addf">;
-def GPU_ElementwiseOpMul : I32EnumAttrCase<"MULF", 1, "mulf">;
-def GPU_ElementwiseOpMaxF : I32EnumAttrCase<"MAXF", 2, "maxf">;
-def GPU_ElementwiseOpMinF : I32EnumAttrCase<"MINF", 3, "minf">;
-def GPU_ElementwiseOpDivF : I32EnumAttrCase<"DIVF", 4, "divf">;
+def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">;
+def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">;
+def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">;
+def GPU_ElementwiseOpMaxF : I32EnumAttrCase<"MAXF", 3, "maxf">;
+def GPU_ElementwiseOpMinF : I32EnumAttrCase<"MINF", 4, "minf">;
+def GPU_ElementwiseOpDivF : I32EnumAttrCase<"DIVF", 5, "divf">;
+def GPU_ElementwiseOpAddI : I32EnumAttrCase<"ADDI", 6, "addi">;
+def GPU_ElementwiseOpMulI : I32EnumAttrCase<"MULI", 7, "muli">;
+def GPU_ElementwiseOpSUBI : I32EnumAttrCase<"SUBI", 8, "subi">;
+def GPU_ElementwiseOpDivS : I32EnumAttrCase<"DIVS", 9, "divs">;
+def GPU_ElementwiseOpDivU : I32EnumAttrCase<"DIVU", 10, "divu">;
+def GPU_ElementwiseOpNEGF : I32EnumAttrCase<"NEGATEF", 11, "negatef">;
+def GPU_ElementwiseOpNEGS : I32EnumAttrCase<"NEGATES", 12, "negates">;
def MMAElementWise : I32EnumAttr<"MMAElementwiseOp",
"elementwise operation to apply to mma matrix", [
- GPU_ElementwiseOpAdd,
- GPU_ElementwiseOpMul,
+ GPU_ElementwiseOpAddF,
+ GPU_ElementwiseOpMulF,
+ GPU_ElementwiseOpSUBF,
GPU_ElementwiseOpMaxF,
GPU_ElementwiseOpMinF,
- GPU_ElementwiseOpDivF
+ GPU_ElementwiseOpDivF,
+ GPU_ElementwiseOpAddI,
+ GPU_ElementwiseOpMulI,
+ GPU_ElementwiseOpSUBI,
+ GPU_ElementwiseOpDivS,
+ GPU_ElementwiseOpDivU,
+ GPU_ElementwiseOpNEGF,
+ GPU_ElementwiseOpNEGS
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::gpu";
// Operand type same as result type.
SPIRV_UnaryOp<mnemonic, type, type,
!listconcat(traits,
- [Pure, SameOperandsAndResultType])>;
+ [Pure, SameOperandsAndResultType])> {
+ // In addition to normal types arithmetic instructions can support cooperative
+ // matrix.
+ let arguments = (ins
+ SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$operand
+ );
+
+ let results = (outs
+ SPIRV_ScalarOrVectorOrCoopMatrixOf<type>:$result
+ );
+ let assemblyFormat = "operands attr-dict `:` type($result)";
+ }
// -----
case gpu::MMAElementwiseOp::MINF:
return createMinMaxF(builder, loc, operands[0], operands[1],
/*isMin=*/true);
+ default:
+ llvm_unreachable("unknown op");
}
- llvm_unreachable("unknown op");
}
/// Convert GPU MMA elementwise ops to extract + op + insert.
add_mlir_conversion_library(MLIRGPUToSPIRV
GPUToSPIRV.cpp
GPUToSPIRVPass.cpp
+ WmmaOpsToSPIRV.cpp
DEPENDS
MLIRConversionPassIncGen
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
+ typeConverter.addConversion([&](gpu::MMAMatrixType type) -> Type {
+ return convertMMAToSPIRVType(type);
+ });
RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
-
+ populateGpuWMMAToSPIRVConversionPatterns(typeConverter, patterns);
// TODO: Change SPIR-V conversion to be progressive and remove the following
// patterns.
mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
--- /dev/null
+//===------ WmmaOpsToSPIRV.cpp - WMMA LD/ST/Compute to SPIRV lowering------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
+// SPIRV Dialect ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
+#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+/// Replaces `op` with the SPIR-V arithmetic op that matches its elementwise
+/// kind, producing a result of type `coopType` from `operands`.
+///
+/// Only the kinds listed below are handled (see SPV_NV_cooperative_matrix for
+/// the supported element wise ops); any other kind reaches llvm_unreachable.
+static void createElementWiseOp(ConversionPatternRewriter &builder,
+                                gpu::SubgroupMmaElementwiseOp op,
+                                spirv::CooperativeMatrixNVType coopType,
+                                ValueRange operands) {
+  switch (op.getOpType()) {
+  // Addition.
+  case gpu::MMAElementwiseOp::ADDF:
+    builder.replaceOpWithNewOp<spirv::FAddOp>(op, coopType, operands);
+    break;
+  case gpu::MMAElementwiseOp::ADDI:
+    builder.replaceOpWithNewOp<spirv::IAddOp>(op, coopType, operands);
+    break;
+  // Subtraction.
+  case gpu::MMAElementwiseOp::SUBF:
+    builder.replaceOpWithNewOp<spirv::FSubOp>(op, coopType, operands);
+    break;
+  case gpu::MMAElementwiseOp::SUBI:
+    builder.replaceOpWithNewOp<spirv::ISubOp>(op, coopType, operands);
+    break;
+  // Division.
+  case gpu::MMAElementwiseOp::DIVF:
+    builder.replaceOpWithNewOp<spirv::FDivOp>(op, coopType, operands);
+    break;
+  case gpu::MMAElementwiseOp::DIVS:
+    builder.replaceOpWithNewOp<spirv::SDivOp>(op, coopType, operands);
+    break;
+  case gpu::MMAElementwiseOp::DIVU:
+    builder.replaceOpWithNewOp<spirv::UDivOp>(op, coopType, operands);
+    break;
+  // Negation.
+  case gpu::MMAElementwiseOp::NEGATEF:
+    builder.replaceOpWithNewOp<spirv::FNegateOp>(op, coopType, operands);
+    break;
+  case gpu::MMAElementwiseOp::NEGATES:
+    builder.replaceOpWithNewOp<spirv::SNegateOp>(op, coopType, operands);
+    break;
+  default:
+    llvm_unreachable("unknown op");
+  }
+}
+
+namespace {
+
+/// Lowers gpu.subgroup_mma_load_matrix to spirv.NV.CooperativeMatrixLoad.
+///
+/// The converted source memref and indices are turned into an element
+/// pointer, the `leadDimension` attribute becomes an i32 stride constant, and
+/// the column-major flag is always emitted as `false`.
+struct WmmaLoadOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaLoadMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = loadOp->getLoc();
+
+    // Result type of the lowered load.
+    auto resultMatrixType =
+        loadOp.getRes().getType().cast<gpu::MMAMatrixType>();
+    spirv::CooperativeMatrixNVType coopType =
+        convertMMAToSPIRVType(resultMatrixType);
+
+    // Pointer to the first element to load.
+    auto srcType = loadOp.getSrcMemref().getType().cast<MemRefType>();
+    auto &spirvConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Value elementPtr =
+        spirv::getElementPtr(spirvConverter, srcType, adaptor.getSrcMemref(),
+                             adaptor.getIndices(), loc, rewriter);
+
+    // Stride operand, taken from the op's leading dimension.
+    auto i32Type = rewriter.getI32Type();
+    int64_t leadDim = loadOp.getLeadDimension().getSExtValue();
+    auto strideConst = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, leadDim));
+
+    // Column-major operand; this lowering always passes `false`.
+    auto columnMajor = rewriter.create<spirv::ConstantOp>(
+        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+
+    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixLoadOp>(
+        loadOp, coopType, elementPtr, strideConst, columnMajor,
+        spirv::MemoryAccessAttr());
+    return success();
+  }
+};
+
+/// Lowers gpu.subgroup_mma_store_matrix to spirv.NV.CooperativeMatrixStore.
+///
+/// The converted destination memref and indices are turned into an element
+/// pointer, the `leadDimension` attribute becomes an i32 stride constant, and
+/// the column-major flag is always emitted as `false`.
+struct WmmaStoreOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaStoreMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = storeOp->getLoc();
+
+    // Pointer to the first element to store to.
+    auto dstType = storeOp.getDstMemref().getType().cast<MemRefType>();
+    auto &spirvConverter = *getTypeConverter<SPIRVTypeConverter>();
+    Value elementPtr =
+        spirv::getElementPtr(spirvConverter, dstType, adaptor.getDstMemref(),
+                             adaptor.getIndices(), loc, rewriter);
+
+    // Stride operand, taken from the op's leading dimension.
+    auto i32Type = rewriter.getI32Type();
+    int64_t leadDim = storeOp.getLeadDimension().getSExtValue();
+    auto strideConst = rewriter.create<spirv::ConstantOp>(
+        loc, i32Type, IntegerAttr::get(i32Type, leadDim));
+
+    // Column-major operand; this lowering always passes `false`.
+    auto columnMajor = rewriter.create<spirv::ConstantOp>(
+        loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
+
+    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixStoreOp>(
+        storeOp, elementPtr, adaptor.getSrc(), strideConst, columnMajor,
+        spirv::MemoryAccessAttr());
+    return success();
+  }
+};
+
+/// Lowers gpu.subgroup_mma_compute to spirv.NV.CooperativeMatrixMulAdd.
+///
+/// The result type is taken from the already-converted C operand.
+struct WmmaMmaOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaComputeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp computeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value matrixA = adaptor.getOpA();
+    Value matrixB = adaptor.getOpB();
+    Value accumulator = adaptor.getOpC();
+    rewriter.replaceOpWithNewOp<spirv::NVCooperativeMatrixMulAddOp>(
+        computeOp, accumulator.getType(), matrixA, matrixB, accumulator);
+    return success();
+  }
+};
+
+/// Lowers gpu.subgroup_mma_constant_matrix to a spirv.CompositeConstruct that
+/// builds the cooperative matrix from the op's single converted operand.
+struct WmmaConstantOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaConstantMatrixOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaConstantMatrixOp constantOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto matrixType = constantOp.getType().cast<gpu::MMAMatrixType>();
+    spirv::CooperativeMatrixNVType coopType =
+        convertMMAToSPIRVType(matrixType);
+    Value splatValue = adaptor.getOperands()[0];
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+        constantOp, coopType, splatValue);
+    return success();
+  }
+};
+
+/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops.
+///
+/// The pattern fails to match (instead of aborting) when the elementwise kind
+/// is not one that createElementWiseOp can lower, or when an operand has not
+/// been converted to a cooperative matrix type.
+struct WmmaElementwiseOpToSPIRVLowering
+    : public OpConversionPattern<gpu::SubgroupMmaElementwiseOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaElementwiseOp subgroupMmaElementwiseOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only the kinds handled by createElementWiseOp may proceed; every other
+    // kind (e.g. MULF, MULI, MAXF, MINF) would otherwise hit its
+    // llvm_unreachable, so report a match failure for those here.
+    switch (subgroupMmaElementwiseOp.getOpType()) {
+    case gpu::MMAElementwiseOp::ADDF:
+    case gpu::MMAElementwiseOp::ADDI:
+    case gpu::MMAElementwiseOp::SUBF:
+    case gpu::MMAElementwiseOp::SUBI:
+    case gpu::MMAElementwiseOp::DIVF:
+    case gpu::MMAElementwiseOp::DIVS:
+    case gpu::MMAElementwiseOp::DIVU:
+    case gpu::MMAElementwiseOp::NEGATEF:
+    case gpu::MMAElementwiseOp::NEGATES:
+      break;
+    default:
+      return rewriter.notifyMatchFailure(
+          subgroupMmaElementwiseOp,
+          "elementwise op kind has no cooperative matrix lowering");
+    }
+
+    // All operands should be of cooperative matrix types.
+    for (Value operand : adaptor.getOperands()) {
+      if (!operand.getType().isa<spirv::CooperativeMatrixNVType>())
+        return failure();
+    }
+
+    auto coopType = convertMMAToSPIRVType(
+        subgroupMmaElementwiseOp.getType().cast<gpu::MMAMatrixType>());
+    createElementWiseOp(rewriter, subgroupMmaElementwiseOp, coopType,
+                        adaptor.getOperands());
+    return success();
+  }
+};
+
+} // namespace
+
+/// Return the SPIR-V cooperative matrix type corresponding to the
+/// MMAMatrixType `type`: same element type, Subgroup scope, and the first two
+/// entries of the shape used as the row and column counts.
+mlir::spirv::CooperativeMatrixNVType
+mlir::convertMMAToSPIRVType(gpu::MMAMatrixType type) {
+  ArrayRef<int64_t> shape = type.getShape();
+  int64_t rows = shape[0];
+  int64_t columns = shape[1];
+  return spirv::CooperativeMatrixNVType::get(
+      type.getElementType(), spirv::Scope::Subgroup, rows, columns);
+}
+
+/// Adds the WMMA load, compute, store, constant and elementwise lowering
+/// patterns defined in this file to `patterns`, all built with `converter`.
+void mlir::populateGpuWMMAToSPIRVConversionPatterns(
+    SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<WmmaLoadOpToSPIRVLowering, WmmaMmaOpToSPIRVLowering,
+               WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaElementwiseOpToSPIRVLowering>(converter, context);
+}
\ No newline at end of file
auto resMatrixType = resType.cast<gpu::MMAMatrixType>();
auto operand = resMatrixType.getOperand();
auto srcMemrefType = srcType.cast<MemRefType>();
- auto srcMemSpace = srcMemrefType.getMemorySpaceAsInt();
if (!isLastMemrefDimUnitStride(srcMemrefType))
return emitError(
"expected source memref most minor dim must have unit stride");
- if (srcMemSpace != kGenericMemorySpace && srcMemSpace != kSharedMemorySpace &&
- srcMemSpace != kGlobalMemorySpace)
- return emitError(
- "source memorySpace kGenericMemorySpace, kSharedMemorySpace or "
- "kGlobalMemorySpace only allowed");
-
if (!operand.equals("AOp") && !operand.equals("BOp") &&
!operand.equals("COp"))
return emitError("only AOp, BOp and COp can be loaded");
auto dstType = getDstMemref().getType();
auto srcMatrixType = srcType.cast<gpu::MMAMatrixType>();
auto dstMemrefType = dstType.cast<MemRefType>();
- auto dstMemSpace = dstMemrefType.getMemorySpaceAsInt();
if (!isLastMemrefDimUnitStride(dstMemrefType))
return emitError(
"expected destination memref most minor dim must have unit stride");
- if (dstMemSpace != kGenericMemorySpace && dstMemSpace != kSharedMemorySpace &&
- dstMemSpace != kGlobalMemorySpace)
- return emitError("destination memorySpace of kGenericMemorySpace, "
- "kGlobalMemorySpace or kSharedMemorySpace only allowed");
-
if (!srcMatrixType.getOperand().equals("COp"))
return emitError(
"expected the operand matrix being stored to have 'COp' operand type");
--- /dev/null
+// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -split-input-file -verify-diagnostics %s | FileCheck %s
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-LABEL: spirv.func @gpu_wmma_load_op
+ // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+ // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
+ gpu.func @gpu_wmma_load_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+ %i = arith.constant 16 : index
+ %j = arith.constant 16 : index
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<f32, StorageBuffer> as !spirv.coopmatrix<16x16xf16, Subgroup>
+ %0 = gpu.subgroup_mma_load_matrix %arg0[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, #spirv.storage_class<StorageBuffer>> -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-LABEL: spirv.func @gpu_wmma_store_op
+ // CHECK-SAME: {{%.*}}: !spirv.ptr<!spirv.struct<(!spirv.array<512 x f32, stride=4> [0])>, StorageBuffer> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
+ // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
+ gpu.func @gpu_wmma_store_op(%arg0 : memref<32x32xf16, #spirv.storage_class<StorageBuffer>>, %arg1 : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+ %i = arith.constant 16 : index
+ %j = arith.constant 16 : index
+ // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr<f32, StorageBuffer>, !spirv.coopmatrix<16x16xf16, Subgroup>
+ gpu.subgroup_mma_store_matrix %arg1, %arg0[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, #spirv.storage_class<StorageBuffer>>
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-LABEL: spirv.func @gpu_wmma_mma_op
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>}
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 2)>})
+ // CHECK-SAME: spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
+ gpu.func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+ // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, !spirv.coopmatrix<16x16xf16, Subgroup> -> !spirv.coopmatrix<16x16xf16, Subgroup>
+ %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-LABEL: spirv.func @gpu_wmma_constant_op
+ gpu.func @gpu_wmma_constant_op() kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+ // CHECK: {{%.*}} = spirv.Constant
+ %cst = arith.constant 1.0 : f16
+ // CHECK: {{%.*}} = spirv.CompositeConstruct {{%.*}} : (f16) -> !spirv.coopmatrix<16x16xf16, Subgroup>
+ %C = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, CooperativeMatrixNV, Float16], [SPV_KHR_storage_buffer_storage_class, SPV_NV_cooperative_matrix]>, #spirv.resource_limits<>>} {
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Logical GLSL450 {
+ // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 0)>}
+ // CHECK-SAME: {{%.*}}: !spirv.coopmatrix<16x16xf16, Subgroup> {spirv.interface_var_abi = #spirv.interface_var_abi<(0, 1)>})
+ gpu.func @gpu_wmma_elementwise_op(%A : !gpu.mma_matrix<16x16xf16, "COp">, %B : !gpu.mma_matrix<16x16xf16, "COp">) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+ // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+ %C = gpu.subgroup_mma_elementwise addf %A, %B : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: {{%.*}} = spirv.FNegate {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+ %D = gpu.subgroup_mma_elementwise negatef %C : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>
+ %E = gpu.subgroup_mma_elementwise divf %D, %A : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
\ No newline at end of file
// -----
-func.func @mmaLoadOp_invalid_mem_space(){
- %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 5>
- %i = arith.constant 16 : index
- // expected-error @+1 {{source memorySpace kGenericMemorySpace, kSharedMemorySpace or kGlobalMemorySpace only allowed}}
- %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 5> -> !gpu.mma_matrix<16x16xf16, "AOp">
- return
-}
-
-// -----
-
#layout_map_col_major = affine_map<(i, j) -> (j, i)>
func.func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
// -----
-func.func @wmmaStoreOp_invalid_mem_space(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
- %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 5>
- %i = arith.constant 16 : index
- %j = arith.constant 16 : index
- // expected-error @+1 {{destination memorySpace of kGenericMemorySpace, kGlobalMemorySpace or kSharedMemorySpace only allowed}}
- gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 5>
- return
-}
-
-// -----
-
func.func @wmmaStoreOp_invalid_store_operand(%arg0 : !gpu.mma_matrix<16x16xf16, "AOp">) -> () {
%sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3>
%i = arith.constant 16 : index