[mlir][spirv] Fix shader ABI attribute prefix and add verification
authorLei Zhang <antiagainst@google.com>
Fri, 3 Jan 2020 12:37:19 +0000 (07:37 -0500)
committerLei Zhang <antiagainst@google.com>
Fri, 3 Jan 2020 12:44:27 +0000 (07:44 -0500)
This commit fixes shader ABI attributes to use `spv.` as the prefix
so that they match the dialect's namespace. This enables us to add
verification hooks in the SPIR-V dialect to verify them.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D72062

mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h
mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp
mlir/lib/Dialect/SPIRV/TargetAndABI.cpp
mlir/test/Conversion/GPUToSPIRV/load-store.mlir
mlir/test/Conversion/GPUToSPIRV/simple.mlir
mlir/test/Dialect/SPIRV/Transforms/abi-load-store.mlir
mlir/test/Dialect/SPIRV/Transforms/abi-simple.mlir
mlir/test/Dialect/SPIRV/target-and-abi.mlir [new file with mode: 0644]

index 0c0eebd..303895e 100644 (file)
@@ -45,6 +45,23 @@ public:
   /// Provides a hook for materializing a constant to this dialect.
   Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
                                  Location loc) override;
+
+  /// Provides a hook for verifying SPIR-V dialect attributes attached to the
+  /// given op.
+  LogicalResult verifyOperationAttribute(Operation *op,
+                                         NamedAttribute attribute) override;
+
+  /// Provides a hook for verifying SPIR-V dialect attributes attached to the
+  /// given op's region argument.
+  LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIndex,
+                                         unsigned argIndex,
+                                         NamedAttribute attribute) override;
+
+  /// Provides a hook for verifying SPIR-V dialect attributes attached to the
+  /// given op's region result.
+  LogicalResult verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
+                                            unsigned resultIndex,
+                                            NamedAttribute attribute) override;
 };
 
 } // end namespace spirv
index 144252b..d6fd354 100644 (file)
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/MLIRContext.h"
@@ -637,3 +638,62 @@ Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
 
   return builder.create<spirv::ConstantOp>(loc, type, value);
 }
+
+//===----------------------------------------------------------------------===//
+// Shader Interface ABI
+//===----------------------------------------------------------------------===//
+
+LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
+                                                     NamedAttribute attribute) {
+  StringRef symbol = attribute.first.strref();
+  Attribute attr = attribute.second;
+
+  if (symbol != spirv::getEntryPointABIAttrName())
+    return op->emitError("found unsupported '")
+           << symbol << "' attribute on operation";
+
+  if (!spirv::EntryPointABIAttr::classof(attr))
+    return op->emitError("'")
+           << symbol
+           << "' attribute must be a dictionary attribute containing one "
+              "integer elements attribute: 'local_size'";
+
+  return success();
+}
+
+// Verifies the given SPIR-V `attribute` attached to a region's argument or
+// result and reports error to the given location if invalid.
+static LogicalResult
+verifyRegionAttribute(Location loc, NamedAttribute attribute, bool forArg) {
+  StringRef symbol = attribute.first.strref();
+  Attribute attr = attribute.second;
+
+  if (symbol != spirv::getInterfaceVarABIAttrName())
+    return emitError(loc, "found unsupported '")
+           << symbol << "' attribute on region "
+           << (forArg ? "argument" : "result");
+
+  if (!spirv::InterfaceVarABIAttr::classof(attr))
+    return emitError(loc, "'")
+           << symbol
+           << "' attribute must be a dictionary attribute containing three "
+              "integer attributes: 'descriptor_set', 'binding', and "
+              "'storage_class'";
+
+  return success();
+}
+
+LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
+                                                     unsigned /*regionIndex*/,
+                                                     unsigned /*argIndex*/,
+                                                     NamedAttribute attribute) {
+  return verifyRegionAttribute(op->getLoc(), attribute,
+                               /*forArg=*/true);
+}
+
+LogicalResult SPIRVDialect::verifyRegionResultAttribute(
+    Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
+    NamedAttribute attribute) {
+  return verifyRegionAttribute(op->getLoc(), attribute,
+                               /*forArg=*/false);
+}
index 696b8b5..5aa8282 100644 (file)
@@ -17,7 +17,7 @@ namespace mlir {
 }
 
 StringRef mlir::spirv::getInterfaceVarABIAttrName() {
-  return "spirv.interface_var_abi";
+  return "spv.interface_var_abi";
 }
 
 mlir::spirv::InterfaceVarABIAttr
@@ -32,7 +32,7 @@ mlir::spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding,
 }
 
 StringRef mlir::spirv::getEntryPointABIAttrName() {
-  return "spirv.entry_point_abi";
+  return "spv.entry_point_abi";
 }
 
 mlir::spirv::EntryPointABIAttr
index 3244256..d104c96 100644 (file)
@@ -22,13 +22,13 @@ module attributes {gpu.container_module} {
     // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr<vector<3xi32>, Input>
     // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr<vector<3xi32>, Input>
     // CHECK-LABEL:    func @load_store_kernel
-    // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG3:%.*]]: i32 {spirv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG4:%.*]]: i32 {spirv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG5:%.*]]: i32 {spirv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: [[ARG6:%.*]]: i32 {spirv.interface_var_abi = {binding = 6 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG0:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG1:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG2:%.*]]: !spv.ptr<!spv.struct<!spv.array<48 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 2 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG3:%.*]]: i32 {spv.interface_var_abi = {binding = 3 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG4:%.*]]: i32 {spv.interface_var_abi = {binding = 4 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG5:%.*]]: i32 {spv.interface_var_abi = {binding = 5 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: [[ARG6:%.*]]: i32 {spv.interface_var_abi = {binding = 6 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
     gpu.func @load_store_kernel(%arg0: memref<12x4xf32>, %arg1: memref<12x4xf32>, %arg2: memref<12x4xf32>, %arg3: index, %arg4: index, %arg5: index, %arg6: index)
       attributes  {gpu.kernel} {
       // CHECK: [[ADDRESSWORKGROUPID:%.*]] = spv._address_of [[WORKGROUPIDVAR]]
index c1f4324..e1b687c 100644 (file)
@@ -5,9 +5,9 @@ module attributes {gpu.container_module} {
   module @kernels attributes {gpu.kernel_module} {
     // CHECK:       spv.module "Logical" "GLSL450" {
     // CHECK-LABEL: func @kernel_1
-    // CHECK-SAME: {{%.*}}: f32 {spirv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer> {spirv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
-    // CHECK-SAME: spirv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
+    // CHECK-SAME: {{%.*}}: f32 {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer> {spv.interface_var_abi = {binding = 1 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32{{[}][}]}}
+    // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
     gpu.func @kernel_1(%arg0 : f32, %arg1 : memref<12xf32>) attributes {gpu.kernel} {
       // CHECK: spv.Return
       gpu.return
index ebfec94..173218c 100644 (file)
@@ -19,34 +19,34 @@ spv.module "Logical" "GLSL450" {
   // CHECK-DAG: spv.globalVariable [[VAR6:@.*]] bind(0, 6) : !spv.ptr<!spv.struct<i32 [0]>, StorageBuffer>
   // CHECK: func [[FN:@.*]]()
   func @load_store_kernel(%arg0: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
-                          {spirv.interface_var_abi = {binding = 0 : i32,
-                                                      descriptor_set = 0 : i32,
-                                                      storage_class = 12 : i32}},
+                          {spv.interface_var_abi = {binding = 0 : i32,
+                                                    descriptor_set = 0 : i32,
+                                                    storage_class = 12 : i32}},
                           %arg1: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
-                          {spirv.interface_var_abi = {binding = 1 : i32,
-                                                      descriptor_set = 0 : i32,
-                                                      storage_class = 12 : i32}},
+                          {spv.interface_var_abi = {binding = 1 : i32,
+                                                    descriptor_set = 0 : i32,
+                                                    storage_class = 12 : i32}},
                           %arg2: !spv.ptr<!spv.struct<!spv.array<12 x !spv.array<4 x f32>>>, StorageBuffer>
-                          {spirv.interface_var_abi = {binding = 2 : i32,
-                                                      descriptor_set = 0 : i32,
-                                                      storage_class = 12 : i32}},
+                          {spv.interface_var_abi = {binding = 2 : i32,
+                                                    descriptor_set = 0 : i32,
+                                                    storage_class = 12 : i32}},
                           %arg3: i32
-                          {spirv.interface_var_abi = {binding = 3 : i32,
-                                                      descriptor_set = 0 : i32,
-                                                      storage_class = 12 : i32}},
+                          {spv.interface_var_abi = {binding = 3 : i32,
+                                                    descriptor_set = 0 : i32,
+                                                    storage_class = 12 : i32}},
                           %arg4: i32
-                          {spirv.interface_var_abi = {binding = 4 : i32,
-                                                      descriptor_set = 0 : i32,
-                                                      storage_class = 12 : i32}},
+                          {spv.interface_var_abi = {binding = 4 : i32,
+                                                    descriptor_set = 0 : i32,
+                                                    storage_class = 12 : i32}},
                           %arg5: i32
-                          {spirv.interface_var_abi = {binding = 5 : i32,
-                                                      descriptor_set = 0 : i32,
-                                                      storage_class = 12 : i32}},
+                          {spv.interface_var_abi = {binding = 5 : i32,
+                                                    descriptor_set = 0 : i32,
+                                                    storage_class = 12 : i32}},
                           %arg6: i32
-                          {spirv.interface_var_abi = {binding = 6 : i32,
-                                                      descriptor_set = 0 : i32,
-                                                      storage_class = 12 : i32}})
-  attributes  {spirv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+                          {spv.interface_var_abi = {binding = 6 : i32,
+                                                    descriptor_set = 0 : i32,
+                                                    storage_class = 12 : i32}})
+  attributes  {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
     // CHECK: [[ADDRESSARG6:%.*]] = spv._address_of [[VAR6]]
     // CHECK: [[CONST6:%.*]] = spv.constant 0 : i32
     // CHECK: [[ARG6PTR:%.*]] = spv.AccessChain [[ADDRESSARG6]]{{\[}}[[CONST6]]
index aa16877..97035eb 100644 (file)
@@ -6,14 +6,14 @@ spv.module "Logical" "GLSL450" {
   // CHECK-DAG:    spv.globalVariable [[VAR1:@.*]] bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
   // CHECK:    func [[FN:@.*]]()
   func @kernel_1(%arg0: f32
-                {spirv.interface_var_abi = {binding = 0 : i32,
-                                            descriptor_set = 0 : i32,
-                                            storage_class = 12 : i32}},
+                {spv.interface_var_abi = {binding = 0 : i32,
+                                          descriptor_set = 0 : i32,
+                                          storage_class = 12 : i32}},
                  %arg1: !spv.ptr<!spv.struct<!spv.array<12 x f32>>, StorageBuffer>
-                 {spirv.interface_var_abi = {binding = 1 : i32,
-                                             descriptor_set = 0 : i32,
-                                             storage_class = 12 : i32}})
-  attributes  {spirv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+                 {spv.interface_var_abi = {binding = 1 : i32,
+                                           descriptor_set = 0 : i32,
+                                           storage_class = 12 : i32}})
+  attributes  {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
     // CHECK: [[ARG1:%.*]] = spv._address_of [[VAR1]]
     // CHECK: [[ADDRESSARG0:%.*]] = spv._address_of [[VAR0]]
     // CHECK: [[CONST0:%.*]] = spv.constant 0 : i32
diff --git a/mlir/test/Dialect/SPIRV/target-and-abi.mlir b/mlir/test/Dialect/SPIRV/target-and-abi.mlir
new file mode 100644 (file)
index 0000000..19bfe9e
--- /dev/null
@@ -0,0 +1,101 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// expected-error @+1 {{found unsupported 'spv.something' attribute on operation}}
+func @unknown_attr_on_op() attributes {
+  spv.something = 64
+} { return }
+
+// -----
+
+// expected-error @+1 {{found unsupported 'spv.something' attribute on region argument}}
+func @unknown_attr_on_region(%arg: i32 {spv.something}) {
+  return
+}
+
+// -----
+
+// expected-error @+1 {{found unsupported 'spv.something' attribute on region result}}
+func @unknown_attr_on_region() -> (i32 {spv.something}) {
+  %0 = constant 10.0 : f32
+  return %0: f32
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.entry_point_abi
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one integer elements attribute: 'local_size'}}
+func @spv_entry_point() attributes {
+  spv.entry_point_abi = 64
+} { return }
+
+// -----
+
+// expected-error @+1 {{'spv.entry_point_abi' attribute must be a dictionary attribute containing one integer elements attribute: 'local_size'}}
+func @spv_entry_point() attributes {
+  spv.entry_point_abi = {local_size = 64}
+} { return }
+
+// -----
+
+func @spv_entry_point() attributes {
+  // CHECK: {spv.entry_point_abi = {local_size = dense<[64, 1, 1]> : vector<3xi32>}}
+  spv.entry_point_abi = {local_size = dense<[64, 1, 1]>: vector<3xi32>}
+} { return }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.interface_var_abi
+//===----------------------------------------------------------------------===//
+
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+func @interface_var(
+  %arg0 : f32 {spv.interface_var_abi = 64}
+) { return }
+
+// -----
+
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+func @interface_var(
+  %arg0 : f32 {spv.interface_var_abi = {binding = 0: i32}}
+) { return }
+
+// -----
+
+// CHECK: {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32}}
+func @interface_var(
+  %arg0 : f32 {spv.interface_var_abi = {binding = 0 : i32,
+                                        descriptor_set = 0 : i32,
+                                        storage_class = 12 : i32}}
+) { return }
+
+// -----
+
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+func @interface_var() -> (f32 {spv.interface_var_abi = 64})
+{
+  %0 = constant 10.0 : f32
+  return %0: f32
+}
+
+// -----
+
+// expected-error @+1 {{'spv.interface_var_abi' attribute must be a dictionary attribute containing three integer attributes: 'descriptor_set', 'binding', and 'storage_class'}}
+func @interface_var() -> (f32 {spv.interface_var_abi = {binding = 0: i32}})
+{
+  %0 = constant 10.0 : f32
+  return %0: f32
+}
+
+// -----
+
+// CHECK: {spv.interface_var_abi = {binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32}}
+func @interface_var() -> (f32 {spv.interface_var_abi = {
+    binding = 0 : i32, descriptor_set = 0 : i32, storage_class = 12 : i32}})
+{
+  %0 = constant 10.0 : f32
+  return %0: f32
+}