``` {.ebnf}
cooperative-matrixload-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV`
- storage-class ssa-use `,` ssa-use `,` ssa-use
+ ssa-use `,` ssa-use `,` ssa-use
(`[` memory-access `]`)? ` : `
+ pointer-type `as`
cooperative-matrix-type
```
For example:
```
- %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %colMajor
- : !spv.coopmatrix<i32, Workgroup, 16, 8>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %colMajor
+ : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<i32, Workgroup, 16, 8>
```
}];
``` {.ebnf}
coop-matrix-store-op ::= `spv.CooperativeMatrixStoreNV `
- storage-class ssa-use `, ` ssa-use `, `
ssa-use `, ` ssa-use `, `
- (`[` memory-access `]`)? `:` spirv-element-type
+ ssa-use `, ` ssa-use `, `
+ (`[` memory-access `]`)? `:`
+ pointer-type `,` spirv-element-type
```
For example:
```
- spv.CooperativeMatrixStoreNV "StorageBuffer" %arg0, %arg2, %arg1, %arg3 :
- !spv.coopmatrix<Workgroup, i32, 16, 8>
+ spv.CooperativeMatrixStoreNV %arg0, %arg2, %arg1, %arg3 :
+ !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<Workgroup, i32, 16, 8>
```
}];
static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
OperationState &state) {
- spirv::StorageClass storageClass;
SmallVector<OpAsmParser::OperandType, 3> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
Type columnMajorType = parser.getBuilder().getIntegerType(1);
+ Type ptrType;
Type elementType;
- if (parseEnumStrAttr(storageClass, parser) ||
- parser.parseOperandList(operandInfo, 3) ||
+ if (parser.parseOperandList(operandInfo, 3) ||
parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
- parser.parseType(elementType)) {
+ parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
return failure();
}
-
- auto ptrType = spirv::PointerType::get(
- elementType.cast<spirv::CooperativeMatrixNVType>().getElementType(),
- storageClass);
SmallVector<Type, 3> OperandType = {ptrType, strideType, columnMajorType};
if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(),
state.operands)) {
}
static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
- StringRef sc = stringifyStorageClass(
- M.pointer().getType().cast<spirv::PointerType>().getStorageClass());
- printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " \"" << sc
- << "\" " << M.pointer() << ", " << M.stride() << ", "
- << M.columnmajor();
+ printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " "
+ << M.pointer() << ", " << M.stride() << ", " << M.columnmajor();
// Print optional memory access attribute.
if (auto memAccess = M.memory_access())
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
- printer << " : " << M.getType();
+ printer << " : " << M.pointer().getType() << " as " << M.getType();
}
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
Type coopMatrix) {
- if (pointer.cast<spirv::PointerType>().getPointeeType() !=
- coopMatrix.cast<spirv::CooperativeMatrixNVType>().getElementType())
+ Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
+ if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
return op->emitError(
- "expected the same type for pointer and the cooperative matrix"
- "element, bu provided ")
- << pointer << " and " << coopMatrix;
+ "Pointer must point to a scalar or vector type but provided ")
+ << pointeeType;
+ spirv::StorageClass storage =
+ pointer.cast<spirv::PointerType>().getStorageClass();
+ if (storage != spirv::StorageClass::Workgroup &&
+ storage != spirv::StorageClass::StorageBuffer &&
+ storage != spirv::StorageClass::PhysicalStorageBuffer)
+ return op->emitError(
+ "Pointer storage class must be Workgroup, StorageBuffer or "
+ "PhysicalStorageBufferEXT but provided ")
+ << stringifyStorageClass(storage);
return success();
}
static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
OperationState &state) {
- spirv::StorageClass storageClass;
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
Type strideType = parser.getBuilder().getIntegerType(32);
Type columnMajorType = parser.getBuilder().getIntegerType(1);
+ Type ptrType;
Type elementType;
- if (parseEnumStrAttr(storageClass, parser) ||
- parser.parseOperandList(operandInfo, 4) ||
+ if (parser.parseOperandList(operandInfo, 4) ||
parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
+ parser.parseType(ptrType) || parser.parseComma() ||
parser.parseType(elementType)) {
return failure();
}
-
- auto ptrType = spirv::PointerType::get(
- elementType.cast<spirv::CooperativeMatrixNVType>().getElementType(),
- storageClass);
SmallVector<Type, 4> OperandType = {ptrType, elementType, strideType,
columnMajorType};
if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(),
static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
OpAsmPrinter &printer) {
- StringRef sc = stringifyStorageClass(coopMatrix.pointer()
- .getType()
- .cast<spirv::PointerType>()
- .getStorageClass());
- printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " \""
- << sc << "\" " << coopMatrix.pointer() << ", " << coopMatrix.object()
- << ", " << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
+ printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " "
+ << coopMatrix.pointer() << ", " << coopMatrix.object() << ", "
+ << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
// Print optional memory access attribute.
if (auto memAccess = coopMatrix.memory_access())
printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
- printer << " : " << coopMatrix.getOperand(1).getType();
+ printer << " : " << coopMatrix.pointer().getType() << ", "
+ << coopMatrix.getOperand(1).getType();
}
//===----------------------------------------------------------------------===//
spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
// CHECK-LABEL: @cooperative_matrix_load
spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
- %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_load_memaccess
spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
- %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_store
spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
- // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
- spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+ // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
+ spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_store_memaccess
spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
- // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
- spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_load
spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
- %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
spv.Return
}
// -----
// CHECK-LABEL: @cooperative_matrix_load_memaccess
spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
- // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
- %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
+spv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_store
spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
- // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Workgroup>
- spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<8x16xi32, Workgroup>
+ // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
+ spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
spv.Return
}
// CHECK-LABEL: @cooperative_matrix_store_memaccess
spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
- // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
- spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+ // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
spv.Return
}
spv.Return
}
+// -----
+
+spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+ // expected-error @+1 {{Pointer must point to a scalar or vector type}}
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.Return
+}
+
+// -----
+
+spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
+ // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
+ %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, Function> as !spv.coopmatrix<8x16xi32, Subgroup>
+ spv.Return
+}