[mlir][spirv] Add basic definitions for supporting availability
authorLei Zhang <antiagainst@google.com>
Fri, 27 Dec 2019 21:24:33 +0000 (16:24 -0500)
committerLei Zhang <antiagainst@google.com>
Fri, 27 Dec 2019 21:25:09 +0000 (16:25 -0500)
SPIR-V has a few mechanisms to control op availability: version,
extension, and capabilities. These mechanisms are considered as
different availability classes.

This commit introduces basic definitions for modelling SPIR-V
availability classes. Specifically, an `Availability` class is
added to SPIRVBase.td, along with two subclasses: MinVersion
and MaxVersion for versioning. SPV_Op is extended to take a
list of `Availability`. Each `Availability` instance carries
information for generating op interfaces for the corresponding
availability class and also the concrete availability
requirements.

With the availability spec on ops, we can now auto-generate the
op interfaces of all SPIR-V availability classes and also
synthesize the op's implementations of these interfaces. The
interface generation is done via new TableGen backends
-gen-avail-interface-{decls|defs}. The op's implementation is
done via -gen-spirv-avail-impls.

Differential Revision: https://reviews.llvm.org/D71930

15 files changed:
mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt
mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td [new file with mode: 0644]
mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
mlir/lib/Dialect/SPIRV/CMakeLists.txt
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/CMakeLists.txt
mlir/test/Dialect/CMakeLists.txt [new file with mode: 0644]
mlir/test/Dialect/SPIRV/CMakeLists.txt [new file with mode: 0644]
mlir/test/Dialect/SPIRV/TestAvailability.cpp [new file with mode: 0644]
mlir/test/Dialect/SPIRV/availability.mlir [new file with mode: 0644]
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

index fc7180d..5246478 100644 (file)
@@ -1,8 +1,3 @@
-set(LLVM_TARGET_DEFINITIONS SPIRVLowering.td)
-mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls)
-mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs)
-add_public_tablegen_target(MLIRSPIRVLoweringStructGen)
-
 add_mlir_dialect(SPIRVOps SPIRVOps)
 
 set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
@@ -11,9 +6,20 @@ mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
 
 set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
+mlir_tablegen(SPIRVAvailability.h.inc -gen-avail-interface-decls)
+mlir_tablegen(SPIRVAvailability.cpp.inc -gen-avail-interface-defs)
+mlir_tablegen(SPIRVOpAvailabilityImpl.inc -gen-spirv-avail-impls)
+add_public_tablegen_target(MLIRSPIRVAvailabilityIncGen)
+
+set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
 mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)
 add_public_tablegen_target(MLIRSPIRVSerializationGen)
 
 set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
 mlir_tablegen(SPIRVOpUtils.inc -gen-spirv-op-utils)
 add_public_tablegen_target(MLIRSPIRVOpUtilsGen)
+
+set(LLVM_TARGET_DEFINITIONS SPIRVLowering.td)
+mlir_tablegen(SPIRVLowering.h.inc -gen-struct-attr-decls)
+mlir_tablegen(SPIRVLowering.cpp.inc -gen-struct-attr-defs)
+add_public_tablegen_target(MLIRSPIRVLoweringStructGen)
index c2ea100..17be79d 100644 (file)
@@ -120,6 +120,13 @@ def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> {
     ```
   }];
 
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_3>,
+    Extension<[]>,
+    Capability<[SPV_C_Kernel]>
+  ];
+
   let arguments = (ins
     SPV_AnyPtr:$pointer,
     SPV_ScopeAttr:$memory_scope,
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVAvailability.td
new file mode 100644 (file)
index 0000000..8ec74ac
--- /dev/null
@@ -0,0 +1,86 @@
+//===- SPIRVAvailability.td - Op Availability Base file ----*- tablegen -*-===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_AVAILABILITY
+#define SPIRV_AVAILABILITY
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// Op availaility definitions
+//===----------------------------------------------------------------------===//
+
+// The base class for defining op availability dimensions.
+class Availability {
+  // The following are fields for controlling the generated C++ OpInterface.
+
+  // The name for the generated C++ OpInterface subclass.
+  string interfaceName = ?;
+  // The documentation for the generated C++ OpInterface subclass.
+  string interfaceDescription = "";
+
+  // The following are fields for controlling the query function signature.
+
+  // The query function's return type in the generated C++ OpInterface subclass.
+  string queryFnRetType = ?;
+  // The query function's name in the generated C++ OpInterface subclass.
+  string queryFnName = ?;
+
+  // The following are fields for controlling the query function implementation.
+
+  // The logic for merging two availability requirements. This is used to derive
+  // the final availability requirement when, for example, an op has two
+  // operands and these two operands have different availability requirements.
+  //
+  // The code should use `$overall` as the placeholder for the final requirement
+  // and `$instance` for the current availability requirement instance.
+  code mergeAction = ?;
+  // The initializer for the final availability requirement.
+  string initializer = ?;
+  // An availability instance's type.
+  string instanceType = ?;
+
+  // The following are fields for a concrete availability instance.
+
+  // The availability requirement carried by a concrete instance.
+  string instance = ?;
+}
+
+class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
+    : Availability {
+  let interfaceName = name;
+
+  let queryFnRetType = scheme.returnType;
+  let queryFnName = "getMinVersion";
+
+  let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
+                      "std::max($overall, $instance))";
+  let initializer = "static_cast<" # scheme.returnType # ">(uint32_t(0))";
+  let instanceType = scheme.cppNamespace # "::" # scheme.className;
+
+  let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
+                 min.symbol;
+}
+
+class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
+    : Availability {
+  let interfaceName = name;
+
+  let queryFnRetType = scheme.returnType;
+  let queryFnName = "getMaxVersion";
+
+  let mergeAction = "$overall = static_cast<" # scheme.returnType # ">("
+                      "std::min($overall, $instance))";
+  let initializer = "static_cast<" # scheme.returnType # ">(~uint32_t(0))";
+  let instanceType = scheme.cppNamespace # "::" # scheme.className;
+
+  let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
+                 max.symbol;
+}
+
+#endif // SPIRV_AVAILABILITY
index 5751a32..acbbbfc 100644 (file)
@@ -16,6 +16,7 @@
 #define SPIRV_BASE
 
 include "mlir/IR/OpBase.td"
+include "mlir/Dialect/SPIRV/SPIRVAvailability.td"
 
 //===----------------------------------------------------------------------===//
 // SPIR-V dialect definitions
@@ -46,6 +47,142 @@ def SPV_Dialect : Dialect {
 }
 
 //===----------------------------------------------------------------------===//
+// SPIR-V availability definitions
+//===----------------------------------------------------------------------===//
+
+def SPV_V_1_0 : I32EnumAttrCase<"V_1_0", 0>;
+def SPV_V_1_1 : I32EnumAttrCase<"V_1_1", 1>;
+def SPV_V_1_2 : I32EnumAttrCase<"V_1_2", 2>;
+def SPV_V_1_3 : I32EnumAttrCase<"V_1_3", 3>;
+def SPV_V_1_4 : I32EnumAttrCase<"V_1_4", 4>;
+def SPV_V_1_5 : I32EnumAttrCase<"V_1_5", 5>;
+
+def SPV_VersionAttr : I32EnumAttr<"Version", "valid SPIR-V version", [
+    SPV_V_1_0, SPV_V_1_1, SPV_V_1_2, SPV_V_1_3, SPV_V_1_4, SPV_V_1_5]> {
+  let cppNamespace = "::mlir::spirv";
+}
+
+class MinVersion<I32EnumAttrCase min> : MinVersionBase<
+    "QueryMinVersionInterface", SPV_VersionAttr, min> {
+  let interfaceDescription = [{
+    Querying interface for minimal required SPIR-V version.
+
+    This interface provides a `getMinVersion()` method to query the minimal
+    required version for the implementing SPIR-V operation. The returned value
+    is a `mlir::spirv::Version` enumerant.
+  }];
+}
+
+class MaxVersion<I32EnumAttrCase max> : MaxVersionBase<
+    "QueryMaxVersionInterface", SPV_VersionAttr, max> {
+  let interfaceDescription = [{
+    Querying interface for maximal supported SPIR-V version.
+
+    This interface provides a `getMaxVersion()` method to query the maximal
+    supported version for the implementing SPIR-V operation. The returned value
+    is a `mlir::spirv::Version` enumerant.
+  }];
+}
+
+class Extension<list<StrEnumAttrCase> extensions> : Availability {
+  let interfaceName = "QueryExtensionInterface";
+  let interfaceDescription = [{
+    Querying interface for required SPIR-V extensions.
+
+    This interface provides a `getExtensions()` method to query the required
+    extensions for the implementing SPIR-V operation. The returned value
+    is a nested vector whose element is `mlir::spirv::Extension`s. The outer
+    vector's elements (which are vectors) should be interpreted as conjunction
+    while the innner vector's elements (which are `mlir::spirv::Extension`s)
+    should be interpreted as disjunction. For example, given
+
+    ```
+    {{Extension::A, Extension::B}, {Extension::C}, {{Extension::D, Extension::E}}
+    ```
+
+    The operation instance is available when (`Extension::A` OR `Extension::B`)
+    AND (`Extension::C`) AND (`Extension::D` OR `Extension::E`) is enabled.
+  }];
+
+  // TODO(antiagainst): Using SmallVector<SmallVector<...>> is an anti-pattern.
+  // Find a better way for this.
+  let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
+                          "::mlir::spirv::Extension, 1>, 1>";
+  let queryFnName = "getExtensions";
+
+  let mergeAction = !if(
+      !empty(extensions), "", "$overall.emplace_back($instance)");
+  let initializer = "{}";
+  let instanceType = "::llvm::SmallVector<::mlir::spirv::Extension, 1>";
+
+  // Compose all capabilities as an C++ initializer list
+  let instance = "std::initializer_list<::mlir::spirv::Extension>{" #
+                 StrJoin<!foreach(
+                   ext, extensions,
+                   "::mlir::spirv::Extension::" # ext.symbol)>.result #
+                 "}";
+}
+
+class Capability<list<I32EnumAttrCase> capabilities> : Availability {
+  let interfaceName = "QueryCapabilityInterface";
+  let interfaceDescription = [{
+    Querying interface for required SPIR-V capabilities.
+
+    This interface provides a `getCapabilities()` method to query the required
+    capabilities for the implementing SPIR-V operation. The returned value
+    is a neted vector whose element is `mlir::spirv::Capability`s. The outer
+    vector's elements (which are vectors) should be interpreted as conjunction
+    while the innner vector's elements (which are `mlir::spirv::Capability`s)
+    should be interpreted as disjunction. For example, given
+
+    ```
+    {{Capability::A, Capability::B}, {Capability::C}, {{Capability::D, Capability::E}}
+    ```
+
+    The operation instance is available when (`Capability::A` OR `Capability::B`)
+    AND (`Capability::C`) AND (`Capability::D` OR `Capability::E`) is enabled.
+  }];
+
+  let queryFnRetType = "::llvm::SmallVector<::llvm::SmallVector<"
+                          "::mlir::spirv::Capability, 1>, 1>";
+  let queryFnName = "getCapabilities";
+
+  let mergeAction = !if(
+      !empty(capabilities), "", "$overall.emplace_back($instance)");
+  let initializer = "{}";
+  let instanceType = "::llvm::SmallVector<::mlir::spirv::Capability, 1>";
+
+  // Compose all capabilities as an C++ initializer list
+  let instance = "std::initializer_list<::mlir::spirv::Capability>{" #
+                 StrJoin<!foreach(
+                   cap, capabilities,
+                   "::mlir::spirv::Capability::" # cap.symbol)>.result #
+                 "}";
+}
+
+// TODO(antiagainst): the following interfaces definitions are duplicating with
+// the above. Remove them once we are able to support dialect-specific contents
+// in ODS.
+def QueryMinVersionInterface : OpInterface<"QueryMinVersionInterface"> {
+  let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMinVersion">];
+}
+def QueryMaxVersionInterface : OpInterface<"QueryMaxVersionInterface"> {
+  let methods = [InterfaceMethod<"", "::mlir::spirv::Version", "getMaxVersion">];
+}
+def QueryExtensionInterface : OpInterface<"QueryExtensionInterface"> {
+  let methods = [InterfaceMethod<
+    "",
+    "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Extension, 1>, 1>",
+    "getExtensions">];
+}
+def QueryCapabilityInterface : OpInterface<"QueryCapabilityInterface"> {
+  let methods = [InterfaceMethod<
+    "",
+    "::llvm::SmallVector<::llvm::SmallVector<::mlir::spirv::Capability, 1>, 1>",
+    "getCapabilities">];
+}
+
+//===----------------------------------------------------------------------===//
 // SPIR-V extension definitions
 //===----------------------------------------------------------------------===//
 
@@ -1216,7 +1353,22 @@ def SPV_OpcodeAttr :
 
 // Base class for all SPIR-V ops.
 class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
-    Op<SPV_Dialect, mnemonic, traits> {
+    Op<SPV_Dialect, mnemonic, !listconcat(traits, [
+         // TODO(antiagainst): We don't need all of the following traits for
+         // every op; only the suitabble ones should be added automatically
+         // after ODS supports dialect-specific contents.
+         DeclareOpInterfaceMethods<QueryMinVersionInterface>,
+         DeclareOpInterfaceMethods<QueryMaxVersionInterface>,
+         DeclareOpInterfaceMethods<QueryExtensionInterface>,
+         DeclareOpInterfaceMethods<QueryCapabilityInterface>
+       ])> {
+  // Availability specification for this op itself.
+  list<Availability> availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[]>
+  ];
 
   // For each SPIR-V op, the following static functions need to be defined
   // in SPVOps.cpp:
index f3a9a61..1ac0ae1 100644 (file)
@@ -53,6 +53,13 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
     ```
   }];
 
+  let availability = [
+    MinVersion<SPV_V_1_3>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_GroupNonUniformBallot]>
+  ];
+
   let arguments = (ins
     SPV_ScopeAttr:$execution_scope,
     SPV_Bool:$predicate
index 2fa417b..3806418 100644 (file)
@@ -21,18 +21,23 @@ class OpBuilder;
 
 namespace spirv {
 
+// TableGen'erated operation interfaces for querying versions, extensions, and
+// capabilities.
+#include "mlir/Dialect/SPIRV/SPIRVAvailability.h.inc"
+
+// TablenGen'erated operation declarations.
 #define GET_OP_CLASSES
 #include "mlir/Dialect/SPIRV/SPIRVOps.h.inc"
 
-/// Following methods are auto-generated.
-///
-/// Get the name used in the Op to refer to an enum value of the given
-/// `EnumClass`.
-/// template <typename EnumClass> StringRef attributeName();
-///
-/// Get the function that can be used to symbolize an enum value.
-/// template <typename EnumClass>
-/// Optional<EnumClass> (*)(StringRef) symbolizeEnum();
+// TableGen'erated helper functions.
+//
+// Get the name used in the Op to refer to an enum value of the given
+// `EnumClass`.
+// template <typename EnumClass> StringRef attributeName();
+//
+// Get the function that can be used to symbolize an enum value.
+// template <typename EnumClass>
+// Optional<EnumClass> (*)(StringRef) symbolizeEnum();
 #include "mlir/Dialect/SPIRV/SPIRVOpUtils.inc"
 
 } // end namespace spirv
index 2c3b1b9..d3af53e 100644 (file)
@@ -15,6 +15,7 @@ add_llvm_library(MLIRSPIRV
   )
 
 add_dependencies(MLIRSPIRV
+  MLIRSPIRVAvailabilityIncGen
   MLIRSPIRVCanonicalizationIncGen
   MLIRSPIRVEnumsIncGen
   MLIRSPIRVLoweringStructGen
index f42c077..1de7bce 100644 (file)
@@ -3063,8 +3063,16 @@ static LogicalResult verify(spirv::VariableOp varOp) {
 namespace mlir {
 namespace spirv {
 
+// TableGen'erated operation interfaces for querying versions, extensions, and
+// capabilities.
+#include "mlir/Dialect/SPIRV/SPIRVAvailability.cpp.inc"
+
+// TablenGen'erated operation definitions.
 #define GET_OP_CLASSES
 #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
 
+// TableGen'erated operation availability interface implementations.
+#include "mlir/Dialect/SPIRV/SPIRVOpAvailabilityImpl.inc"
+
 } // namespace spirv
 } // namespace mlir
index 9579254..571a0d8 100644 (file)
@@ -1,3 +1,4 @@
+add_subdirectory(Dialect)
 add_subdirectory(EDSC)
 add_subdirectory(mlir-cpu-runner)
 add_subdirectory(SDBM)
diff --git a/mlir/test/Dialect/CMakeLists.txt b/mlir/test/Dialect/CMakeLists.txt
new file mode 100644 (file)
index 0000000..cc1766c
--- /dev/null
@@ -0,0 +1 @@
+add_subdirectory(SPIRV)
diff --git a/mlir/test/Dialect/SPIRV/CMakeLists.txt b/mlir/test/Dialect/SPIRV/CMakeLists.txt
new file mode 100644 (file)
index 0000000..25ea962
--- /dev/null
@@ -0,0 +1,14 @@
+add_llvm_library(MLIRSPIRVTestPasses
+  TestAvailability.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+  )
+
+target_link_libraries(MLIRSPIRVTestPasses
+  MLIRIR
+  MLIRPass
+  MLIRSPIRV
+  MLIRSupport
+  )
diff --git a/mlir/test/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/Dialect/SPIRV/TestAvailability.cpp
new file mode 100644 (file)
index 0000000..bb16421
--- /dev/null
@@ -0,0 +1,73 @@
+//===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// A pass for testing SPIR-V op availability.
+struct TestAvailability : public FunctionPass<TestAvailability> {
+  void runOnFunction() override;
+};
+} // end anonymous namespace
+
+void TestAvailability::runOnFunction() {
+  auto f = getFunction();
+  llvm::outs() << f.getName() << "\n";
+
+  Dialect *spvDialect = getContext().getRegisteredDialect("spv");
+
+  f.getOperation()->walk([&](Operation *op) {
+    if (op->getDialect() != spvDialect)
+      return WalkResult::advance();
+
+    auto &os = llvm::outs();
+
+    if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
+      os << " min version: "
+         << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n";
+
+    if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
+      os << " max version: "
+         << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n";
+
+    if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
+      os << " extensions: [";
+      for (const auto &exts : extension.getExtensions()) {
+        os << " [";
+        interleaveComma(exts, os, [&](spirv::Extension ext) {
+          os << spirv::stringifyExtension(ext);
+        });
+        os << "]";
+      }
+      os << " ]\n";
+    }
+
+    if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
+      os << " capabilities: [";
+      for (const auto &caps : capability.getCapabilities()) {
+        os << " [";
+        interleaveComma(caps, os, [&](spirv::Capability cap) {
+          os << spirv::stringifyCapability(cap);
+        });
+        os << "]";
+      }
+      os << " ]\n";
+    }
+    os.flush();
+
+    return WalkResult::advance();
+  });
+}
+
+static PassRegistration<TestAvailability> pass("test-spirv-op-availability",
+                                               "Test SPIR-V op availability");
diff --git a/mlir/test/Dialect/SPIRV/availability.mlir b/mlir/test/Dialect/SPIRV/availability.mlir
new file mode 100644 (file)
index 0000000..ed4d29c
--- /dev/null
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -disable-pass-threading -test-spirv-op-availability %s | FileCheck %s
+
+// CHECK-LABEL: iadd
+func @iadd(%arg: i32) -> i32 {
+  // CHECK: min version: V_1_0
+  // CHECK: max version: V_1_5
+  // CHECK: extensions: [ ]
+  // CHECK: capabilities: [ ]
+  %0 = spv.IAdd %arg, %arg: i32
+  return %0: i32
+}
+
+// CHECK: atomic_compare_exchange_weak
+func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i32) -> i32 {
+  // CHECK: min version: V_1_0
+  // CHECK: max version: V_1_3
+  // CHECK: extensions: [ ]
+  // CHECK: capabilities: [ [Kernel] ]
+  %0 = spv.AtomicCompareExchangeWeak "Workgroup" "Release" "Acquire" %ptr, %value, %comparator: !spv.ptr<i32, Workgroup>
+  return %0: i32
+}
+
+// CHECK-LABEL: subgroup_ballot
+func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+  // CHECK: min version: V_1_3
+  // CHECK: max version: V_1_5
+  // CHECK: extensions: [ ]
+  // CHECK: capabilities: [ [GroupNonUniformBallot] ]
+  %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
+  return %0: vector<4xi32>
+}
index b30d7e3..1281569 100644 (file)
@@ -41,6 +41,7 @@ set(LIBS
   MLIRROCDLIR
   MLIRSPIRV
   MLIRStandardToSPIRVTransforms
+  MLIRSPIRVTestPasses
   MLIRSPIRVTransforms
   MLIRStandardOps
   MLIRStandardToLLVM
index d65b216..639f014 100644 (file)
@@ -13,6 +13,7 @@
 
 #include "mlir/Support/StringExtras.h"
 #include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
 #include "llvm/ADT/Sequence.h"
@@ -44,6 +45,233 @@ using mlir::tblgen::NamedTypeConstraint;
 using mlir::tblgen::Operator;
 
 //===----------------------------------------------------------------------===//
+// Availability Wrapper Class
+//===----------------------------------------------------------------------===//
+
+namespace {
+// Wrapper class with helper methods for accessing availability defined in
+// TableGen.
+class Availability {
+public:
+  explicit Availability(const Record *def);
+
+  // Returns the name of the direct TableGen class for this availability
+  // instance.
+  StringRef getClass() const;
+
+  // Returns the generated C++ interface's class name.
+  StringRef getInterfaceClassName() const;
+
+  // Returns the generated C++ interface's description.
+  StringRef getInterfaceDescription() const;
+
+  // Returns the name of the query function insided the generated C++ interface.
+  StringRef getQueryFnName() const;
+
+  // Returns the return type of the query function insided the generated C++
+  // interface.
+  StringRef getQueryFnRetType() const;
+
+  // Returns the code for merging availability requirements.
+  StringRef getMergeActionCode() const;
+
+  // Returns the initializer expression for initializing the final availability
+  // requirements.
+  StringRef getMergeInitializer() const;
+
+  // Returns the C++ type for an availability instance.
+  StringRef getMergeInstanceType() const;
+
+  // Returns the concrete availability instance carried in this case.
+  StringRef getMergeInstance() const;
+
+private:
+  // The TableGen definition of this availability.
+  const llvm::Record *def;
+};
+} // namespace
+
+Availability::Availability(const llvm::Record *def) : def(def) {
+  assert(def->isSubClassOf("Availability") &&
+         "must be subclass of TableGen 'Availability' class");
+}
+
+StringRef Availability::getClass() const {
+  SmallVector<Record *, 1> parentClass;
+  def->getDirectSuperClasses(parentClass);
+  if (parentClass.size() != 1) {
+    PrintFatalError(def->getLoc(),
+                    "expected to only have one direct superclass");
+  }
+  return parentClass.front()->getName();
+}
+
+StringRef Availability::getInterfaceClassName() const {
+  return def->getValueAsString("interfaceName");
+}
+
+StringRef Availability::getInterfaceDescription() const {
+  return def->getValueAsString("interfaceDescription");
+}
+
+StringRef Availability::getQueryFnRetType() const {
+  return def->getValueAsString("queryFnRetType");
+}
+
+StringRef Availability::getQueryFnName() const {
+  return def->getValueAsString("queryFnName");
+}
+
+StringRef Availability::getMergeActionCode() const {
+  return def->getValueAsString("mergeAction");
+}
+
+StringRef Availability::getMergeInitializer() const {
+  return def->getValueAsString("initializer");
+}
+
+StringRef Availability::getMergeInstanceType() const {
+  return def->getValueAsString("instanceType");
+}
+
+StringRef Availability::getMergeInstance() const {
+  return def->getValueAsString("instance");
+}
+
+//===----------------------------------------------------------------------===//
+// Availability Interface Definitions AutoGen
+//===----------------------------------------------------------------------===//
+
+static void emitInterfaceDef(const Availability &availability,
+                             raw_ostream &os) {
+  StringRef methodName = availability.getQueryFnName();
+  os << availability.getQueryFnRetType() << " "
+     << availability.getInterfaceClassName() << "::" << methodName << "() {\n"
+     << "  return getImpl()->" << methodName << "(getOperation());\n"
+     << "}\n";
+}
+
+static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
+                              raw_ostream &os) {
+  llvm::emitSourceFileHeader("Availability Interface Definitions", os);
+
+  auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
+  SmallVector<const Record *, 1> handledClasses;
+  for (const Record *def : defs) {
+    SmallVector<Record *, 1> parent;
+    def->getDirectSuperClasses(parent);
+    if (parent.size() != 1) {
+      PrintFatalError(def->getLoc(),
+                      "expected to only have one direct superclass");
+    }
+    if (llvm::is_contained(handledClasses, parent.front()))
+      continue;
+
+    Availability availability(def);
+    emitInterfaceDef(availability, os);
+    handledClasses.push_back(parent.front());
+  }
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Availability Interface Declarations AutoGen
+//===----------------------------------------------------------------------===//
+
+static void emitConceptDecl(const Availability &availability, raw_ostream &os) {
+  os << "  class Concept {\n"
+     << "  public:\n"
+     << "    virtual ~Concept() = default;\n"
+     << "    virtual " << availability.getQueryFnRetType() << " "
+     << availability.getQueryFnName() << "(Operation *tblgen_opaque_op) = 0;\n"
+     << "  };\n";
+}
+
+static void emitModelDecl(const Availability &availability, raw_ostream &os) {
+  os << "  template<typename ConcreteOp>\n";
+  os << "  class Model : public Concept {\n"
+     << "  public:\n"
+     << "    " << availability.getQueryFnRetType() << " "
+     << availability.getQueryFnName()
+     << "(Operation *tblgen_opaque_op) final {\n"
+     << "      auto op = llvm::cast<ConcreteOp>(tblgen_opaque_op);\n"
+     << "      (void)op;\n"
+     // Forward to the method on the concrete operation type.
+     << "      return op." << availability.getQueryFnName() << "();\n"
+     << "    }\n"
+     << "  };\n";
+}
+
+static void emitInterfaceDecl(const Availability &availability,
+                              raw_ostream &os) {
+  StringRef interfaceName = availability.getInterfaceClassName();
+  std::string interfaceTraitsName = formatv("{0}Traits", interfaceName);
+
+  // Emit the traits struct containing the concept and model declarations.
+  os << "namespace detail {\n"
+     << "struct " << interfaceTraitsName << " {\n";
+  emitConceptDecl(availability, os);
+  os << '\n';
+  emitModelDecl(availability, os);
+  os << "};\n} // end namespace detail\n\n";
+
+  // Emit the main interface class declaration.
+  os << "/*\n" << availability.getInterfaceDescription().trim() << "\n*/\n";
+  os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
+                      "public:\n"
+                      "  using OpInterface<{1}, detail::{2}>::OpInterface;\n",
+                      interfaceName, interfaceName, interfaceTraitsName);
+
+  // Emit query function declaration.
+  os << "  " << availability.getQueryFnRetType() << " "
+     << availability.getQueryFnName() << "();\n";
+  os << "};\n\n";
+}
+
+static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
+                               raw_ostream &os) {
+  llvm::emitSourceFileHeader("Availability Interface Declarations", os);
+
+  auto defs = recordKeeper.getAllDerivedDefinitions("Availability");
+  SmallVector<const Record *, 4> handledClasses;
+  for (const Record *def : defs) {
+    SmallVector<Record *, 1> parent;
+    def->getDirectSuperClasses(parent);
+    if (parent.size() != 1) {
+      PrintFatalError(def->getLoc(),
+                      "expected to only have one direct superclass");
+    }
+    if (llvm::is_contained(handledClasses, parent.front()))
+      continue;
+
+    Availability avail(def);
+    emitInterfaceDecl(avail, os);
+    handledClasses.push_back(parent.front());
+  }
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Availability Interface Hook Registration
+//===----------------------------------------------------------------------===//
+
+// Registers the operation interface generator to mlir-tblgen.
+static mlir::GenRegistration
+    genInterfaceDecls("gen-avail-interface-decls",
+                      "Generate availability interface declarations",
+                      [](const RecordKeeper &records, raw_ostream &os) {
+                        return emitInterfaceDecls(records, os);
+                      });
+
+// Registers the operation interface generator to mlir-tblgen.
+static mlir::GenRegistration
+    genInterfaceDefs("gen-avail-interface-defs",
+                     "Generate op interface definitions",
+                     [](const RecordKeeper &records, raw_ostream &os) {
+                       return emitInterfaceDefs(records, os);
+                     });
+
+//===----------------------------------------------------------------------===//
 // Serialization AutoGen
 //===----------------------------------------------------------------------===//
 
@@ -651,6 +879,17 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper,
 }
 
 //===----------------------------------------------------------------------===//
+// Serialization Hook Registration
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration genSerialization(
+    "gen-spirv-serialization",
+    "Generate SPIR-V (de)serialization utilities and functions",
+    [](const RecordKeeper &records, raw_ostream &os) {
+      return emitSerializationFns(records, os);
+    });
+
+//===----------------------------------------------------------------------===//
 // Op Utils AutoGen
 //===----------------------------------------------------------------------===//
 
@@ -707,19 +946,92 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
 }
 
 //===----------------------------------------------------------------------===//
-// Hook Registration
+// Op Utils Hook Registration
 //===----------------------------------------------------------------------===//
 
-static mlir::GenRegistration genSerialization(
-    "gen-spirv-serialization",
-    "Generate SPIR-V (de)serialization utilities and functions",
-    [](const RecordKeeper &records, raw_ostream &os) {
-      return emitSerializationFns(records, os);
-    });
-
 static mlir::GenRegistration
     genOpUtils("gen-spirv-op-utils",
                "Generate SPIR-V operation utility definitions",
                [](const RecordKeeper &records, raw_ostream &os) {
                  return emitOpUtils(records, os);
                });
+
+//===----------------------------------------------------------------------===//
+// SPIR-V Availability Impl AutoGen
+//===----------------------------------------------------------------------===//
+
+// Returns the availability spec of the given `def`.
+std::vector<Availability> getAvailabilities(const Record &def) {
+  std::vector<Availability> availabilities;
+  if (auto *availListInit = def.getValueAsListInit("availability")) {
+    availabilities.reserve(availListInit->size());
+    for (auto *availInit : *availListInit)
+      availabilities.emplace_back(
+          llvm::cast<llvm::DefInit>(availInit)->getDef());
+  }
+  return availabilities;
+}
+
+static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
+  mlir::tblgen::FmtContext fctx;
+  fctx.addSubst("overall", "overall");
+
+  std::vector<Availability> opAvailabilities =
+      getAvailabilities(srcOp.getDef());
+
+  // First collect all availablity classes this op should implement.
+  // All availablity instances keep information for the generated interface and
+  // the instance's specific requirement. Here we remember a random instance so
+  // we can get the information regarding the generated interface.
+  llvm::StringMap<Availability> availClasses;
+  for (const Availability &avail : opAvailabilities)
+    availClasses.try_emplace(avail.getClass(), avail);
+
+  // Then generate implementation for each availability class.
+  for (const auto &availClass : availClasses) {
+    StringRef availClassName = availClass.getKey();
+    Availability avail = availClass.getValue();
+
+    // Generate the implementation method signature.
+    os << formatv("{0} {1}::{2}() {{\n", avail.getQueryFnRetType(),
+                  srcOp.getCppClassName(), avail.getQueryFnName());
+
+    // Create the variable for the final requirement and initialize it.
+    os << formatv("  {0} overall = {1};\n", avail.getQueryFnRetType(),
+                  avail.getMergeInitializer());
+
+    // Update with the op's specific availability spec.
+    for (const Availability &avail : opAvailabilities)
+      if (avail.getClass() == availClassName) {
+        os << "  "
+           << tgfmt(avail.getMergeActionCode(),
+                    &fctx.addSubst("instance", avail.getMergeInstance()))
+           << ";\n";
+      }
+    os << "  return overall;\n";
+    os << "}\n";
+  }
+}
+
+static bool emitAvailabilityImpl(const RecordKeeper &recordKeeper,
+                                 raw_ostream &os) {
+  llvm::emitSourceFileHeader("SPIR-V Op Availability Implementations", os);
+
+  auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
+  for (const auto *def : defs) {
+    Operator op(def);
+    emitAvailabilityImpl(op, os);
+  }
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// Op Availability Implementation Hook Registration
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration
+    genOpAvailabilityImpl("gen-spirv-avail-impls",
+                          "Generate SPIR-V operation utility definitions",
+                          [](const RecordKeeper &records, raw_ostream &os) {
+                            return emitAvailabilityImpl(records, os);
+                          });