[mlir][spirv] Relax restriction on pointer type for CooperativeMatrix load/store
authorThomas Raoux <thomasraoux@google.com>
Fri, 31 Jul 2020 15:02:21 +0000 (08:02 -0700)
committerThomas Raoux <thomasraoux@google.com>
Fri, 31 Jul 2020 15:02:21 +0000 (08:02 -0700)
This change allow CooperativeMatrix Load/Store operations to use pointer type
that may not match the matrix element type. This allow us to declare buffer
with a larger type size than the matrix element type. This follows SPIR-V spec
and this is needed to be able to use cooperative matrix in combination with
shared local memory efficiently.

Differential Revision: https://reviews.llvm.org/D84993

mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
mlir/test/Dialect/SPIRV/cooperative-matrix.mlir

index 9c3462a..720cfd6 100644 (file)
@@ -101,16 +101,17 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> {
 
     ``` {.ebnf}
     cooperative-matrixload-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV`
-                              storage-class ssa-use `,` ssa-use `,` ssa-use
+                              ssa-use `,` ssa-use `,` ssa-use
                               (`[` memory-access `]`)? ` : `
+                              pointer-type `as`
                               cooperative-matrix-type
     ```
 
     For example:
 
     ```
-    %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %colMajor
-         : !spv.coopmatrix<i32, Workgroup, 16, 8>
+    %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %colMajor
+         : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<i32, Workgroup, 16, 8>
     ```
   }];
 
@@ -243,16 +244,17 @@ def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> {
 
     ``` {.ebnf}
     coop-matrix-store-op ::= `spv.CooperativeMatrixStoreNV `
-                              storage-class ssa-use `, ` ssa-use `, `
                               ssa-use `, ` ssa-use `, `
-                  (`[` memory-access `]`)? `:` spirv-element-type
+                              ssa-use `, ` ssa-use `, `
+                              (`[` memory-access `]`)? `:`
+                              pointer-type `,` spirv-element-type
     ```
 
     For example:
 
     ```
-      spv.CooperativeMatrixStoreNV "StorageBuffer" %arg0, %arg2, %arg1, %arg3 :
-        !spv.coopmatrix<Workgroup, i32, 16, 8>
+      spv.CooperativeMatrixStoreNV %arg0, %arg2, %arg1, %arg3 :
+        !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<Workgroup, i32, 16, 8>
     ```
   }];
 
index b0235d4..bac65a0 100644 (file)
@@ -2793,21 +2793,16 @@ static LogicalResult verify(spirv::VariableOp varOp) {
 
 static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
                                                   OperationState &state) {
-  spirv::StorageClass storageClass;
   SmallVector<OpAsmParser::OperandType, 3> operandInfo;
   Type strideType = parser.getBuilder().getIntegerType(32);
   Type columnMajorType = parser.getBuilder().getIntegerType(1);
+  Type ptrType;
   Type elementType;
-  if (parseEnumStrAttr(storageClass, parser) ||
-      parser.parseOperandList(operandInfo, 3) ||
+  if (parser.parseOperandList(operandInfo, 3) ||
       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
-      parser.parseType(elementType)) {
+      parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
     return failure();
   }
-
-  auto ptrType = spirv::PointerType::get(
-      elementType.cast<spirv::CooperativeMatrixNVType>().getElementType(),
-      storageClass);
   SmallVector<Type, 3> OperandType = {ptrType, strideType, columnMajorType};
   if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(),
                              state.operands)) {
@@ -2819,25 +2814,30 @@ static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
 }
 
 static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
-  StringRef sc = stringifyStorageClass(
-      M.pointer().getType().cast<spirv::PointerType>().getStorageClass());
-  printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " \"" << sc
-          << "\" " << M.pointer() << ", " << M.stride() << ", "
-          << M.columnmajor();
+  printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " "
+          << M.pointer() << ", " << M.stride() << ", " << M.columnmajor();
   // Print optional memory access attribute.
   if (auto memAccess = M.memory_access())
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
-  printer << " : " << M.getType();
+  printer << " : " << M.pointer().getType() << " as " << M.getType();
 }
 
 static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
                                                     Type coopMatrix) {
-  if (pointer.cast<spirv::PointerType>().getPointeeType() !=
-      coopMatrix.cast<spirv::CooperativeMatrixNVType>().getElementType())
+  Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
+  if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
     return op->emitError(
-               "expected the same type for pointer and the cooperative matrix"
-               "element, bu provided ")
-           << pointer << " and " << coopMatrix;
+               "Pointer must point to a scalar or vector type but provided ")
+           << pointeeType;
+  spirv::StorageClass storage =
+      pointer.cast<spirv::PointerType>().getStorageClass();
+  if (storage != spirv::StorageClass::Workgroup &&
+      storage != spirv::StorageClass::StorageBuffer &&
+      storage != spirv::StorageClass::PhysicalStorageBuffer)
+    return op->emitError(
+               "Pointer storage class must be Workgroup, StorageBuffer or "
+               "PhysicalStorageBufferEXT but provided ")
+           << stringifyStorageClass(storage);
   return success();
 }
 
@@ -2847,21 +2847,17 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
 
 static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
                                                    OperationState &state) {
-  spirv::StorageClass storageClass;
   SmallVector<OpAsmParser::OperandType, 4> operandInfo;
   Type strideType = parser.getBuilder().getIntegerType(32);
   Type columnMajorType = parser.getBuilder().getIntegerType(1);
+  Type ptrType;
   Type elementType;
-  if (parseEnumStrAttr(storageClass, parser) ||
-      parser.parseOperandList(operandInfo, 4) ||
+  if (parser.parseOperandList(operandInfo, 4) ||
       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
+      parser.parseType(ptrType) || parser.parseComma() ||
       parser.parseType(elementType)) {
     return failure();
   }
-
-  auto ptrType = spirv::PointerType::get(
-      elementType.cast<spirv::CooperativeMatrixNVType>().getElementType(),
-      storageClass);
   SmallVector<Type, 4> OperandType = {ptrType, elementType, strideType,
                                       columnMajorType};
   if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(),
@@ -2874,17 +2870,14 @@ static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
 
 static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
                   OpAsmPrinter &printer) {
-  StringRef sc = stringifyStorageClass(coopMatrix.pointer()
-                                           .getType()
-                                           .cast<spirv::PointerType>()
-                                           .getStorageClass());
-  printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " \""
-          << sc << "\" " << coopMatrix.pointer() << ", " << coopMatrix.object()
-          << ", " << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
+  printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " "
+          << coopMatrix.pointer() << ", " << coopMatrix.object() << ", "
+          << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
   // Print optional memory access attribute.
   if (auto memAccess = coopMatrix.memory_access())
     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
-  printer << " : " << coopMatrix.getOperand(1).getType();
+  printer << " : " << coopMatrix.pointer().getType() << ", "
+          << coopMatrix.getOperand(1).getType();
 }
 
 //===----------------------------------------------------------------------===//
index ad913df..3f8c4bf 100644 (file)
@@ -3,29 +3,29 @@
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [CooperativeMatrixNV], [SPV_NV_cooperative_matrix]> {
   // CHECK-LABEL: @cooperative_matrix_load
   spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-    // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
-    %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+    // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
+    %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
     spv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_load_memaccess
   spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-    // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
-    %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+    // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+    %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
     spv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_store
   spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" {
-    // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
-    spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+    // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
+    spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<16x8xi32, Workgroup>
     spv.Return
   }
 
   // CHECK-LABEL: @cooperative_matrix_store_memaccess
   spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-    // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
-    spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+    // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
+    spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
     spv.Return
   }
 
index 523bd6b..f0bb50d 100644 (file)
@@ -2,30 +2,37 @@
 
 // CHECK-LABEL: @cooperative_matrix_load
 spv.func @cooperative_matrix_load(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup>
-  %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup>
+  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
+  %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<16x8xi32, Workgroup>
   spv.Return
 }
 
 // -----
 // CHECK-LABEL: @cooperative_matrix_load_memaccess
 spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
-  %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}
+
+// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
+spv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spv.ptr<vector<4xi32>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+  // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr<vector<4xi32>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
   spv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_store
 spv.func @cooperative_matrix_store(%ptr : !spv.ptr<i32, StorageBuffer>, %stride : i32, %m : !spv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
-  // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Workgroup>
-  spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<8x16xi32, Workgroup>
+  // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
+  spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Workgroup>
   spv.Return
 }
 
 // CHECK-LABEL: @cooperative_matrix_store_memaccess
 spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr<i32, StorageBuffer>, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-  // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
-  spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup>
+  // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr<i32, StorageBuffer>, !spv.coopmatrix<8x16xi32, Subgroup>
   spv.Return
 }
 
@@ -134,3 +141,18 @@ spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<8x16xf32, Subgroup>, %b
   spv.Return
 }
 
+// -----
+
+spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>, %stride : i32, %b : i1) "None" {
+  // expected-error @+1 {{Pointer must point to a scalar or vector type}}
+  %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}
+
+// -----
+
+spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr<i32, Function>, %stride : i32, %b : i1) "None" {
+  // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or PhysicalStorageBufferEXT}}
+  %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr<i32, Function> as !spv.coopmatrix<8x16xi32, Subgroup>
+  spv.Return
+}