[spirv] Add spv.SubgroupBallotKHROp
authorLei Zhang <antiagainst@google.com>
Tue, 3 Dec 2019 12:49:20 +0000 (04:49 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Dec 2019 12:49:56 +0000 (04:49 -0800)
PiperOrigin-RevId: 283522284

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td [new file with mode: 0644]
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir [new file with mode: 0644]
mlir/test/Dialect/SPIRV/ops.mlir
mlir/utils/spirv/gen_spirv_dialect.py

index e1897a9..bfb7497 100644 (file)
@@ -953,6 +953,8 @@ class SPV_ScalarOrVectorOf<Type type> :
 def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
 def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
 
+def SPV_I32Vec4 : VectorOfLengthAndType<[4], [I32]>;
+
 // TODO(antiagainst): Use a more appropriate way to model optional operands
 class SPV_Optional<Type type> : Variadic<type>;
 
@@ -1107,6 +1109,7 @@ def SPV_OC_OpReturn                 : I32EnumAttrCase<"OpReturn", 253>;
 def SPV_OC_OpReturnValue            : I32EnumAttrCase<"OpReturnValue", 254>;
 def SPV_OC_OpUnreachable            : I32EnumAttrCase<"OpUnreachable", 255>;
 def SPV_OC_OpModuleProcessed        : I32EnumAttrCase<"OpModuleProcessed", 330>;
+def SPV_OC_OpSubgroupBallotKHR      : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 
 def SPV_OpcodeAttr :
     I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -1146,10 +1149,9 @@ def SPV_OpcodeAttr :
       SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier,
       SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
       SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
-      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed
+      SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
+      SPV_OC_OpSubgroupBallotKHR
       ]> {
-    let returnType = "::mlir::spirv::Opcode";
-    let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
     let cppNamespace = "::mlir::spirv";
 }
 
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td
new file mode 100644 (file)
index 0000000..5f60e6b
--- /dev/null
@@ -0,0 +1,74 @@
+//===-- SPIRVGroupOps.td - MLIR SPIR-V (Sub)Group Ops ------*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains group and subgroup ops for the SPIR-V dialect. It
+// corresponds to "3.32.21. Group and Subgroup Instructions" of the SPIR-V
+// specification.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_GROUP_OPS
+#define SPIRV_GROUP_OPS
+
+// -----
+
+def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
+  let summary = "See extension SPV_KHR_shader_ballot";
+
+  let description = [{
+    Computes a bitfield value combining the Predicate value from all invocations
+    in the current Subgroup that execute the same dynamic instance of this
+    instruction. The bit is set to one if the corresponding invocation is active
+    and the predicate is evaluated to true; otherwise, it is set to zero.
+
+    Predicate must be a Boolean type.
+
+    Result Type must be a 4 component vector of 32 bit integer types.
+
+    Result is a set of bitfields where the first invocation is represented in bit
+    0 of the first vector component and the last (up to SubgroupSize) is the
+    higher bit number of the last bitmask needed to represent all bits of the
+    subgroup invocations.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    subgroup-ballot-op ::= ssa-id `=` `spv.SubgroupBallotKHR`
+                                ssa-use `:` `vector` `<` 4 `x` `i32` `>`
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.SubgroupBallotKHR %predicate : vector<4xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_Bool:$predicate
+  );
+
+  let results = (outs
+    SPV_I32Vec4:$result
+  );
+
+  let verifier = [{ return success(); }];
+}
+
+// -----
+
+#endif // SPIRV_GROUP_OPS
index 41d729d..178db0a 100644 (file)
@@ -34,11 +34,10 @@ include "mlir/Dialect/SPIRV/SPIRVArithmeticOps.td"
 include "mlir/Dialect/SPIRV/SPIRVBitOps.td"
 include "mlir/Dialect/SPIRV/SPIRVCastOps.td"
 include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td"
+include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
+include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
-// Pull in ops for defining the SPIR-V module structure
 include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
-// Pull in ops for extended instruction set for GLSL
-include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
 
 // -----
 
index e8896fa..6e115f7 100644 (file)
@@ -2527,6 +2527,28 @@ static LogicalResult verify(spirv::StoreOp storeOp) {
 }
 
 //===----------------------------------------------------------------------===//
+// spv.SubgroupBallotKHROp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser,
+                                            OperationState &state) {
+  OpAsmParser::OperandType operandInfo;
+  Type resultType;
+  IntegerType i1Type = parser.getBuilder().getI1Type();
+  if (parser.parseOperand(operandInfo) || parser.parseColonType(resultType) ||
+      parser.resolveOperand(operandInfo, i1Type, state.operands))
+    return failure();
+
+  return parser.addTypeToList(resultType, state.types);
+}
+
+static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) {
+  printer << spirv::SubgroupBallotKHROp::getOperationName() << ' ';
+  printer.printOperand(ballotOp.predicate());
+  printer << " : " << ballotOp.getType();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.Undef
 //===----------------------------------------------------------------------===//
 
@@ -2595,11 +2617,10 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
   state.addTypes(ptrType);
 
   // Resolve the initializer operand
-  SmallVector<Value *, 1> init;
   if (initInfo) {
-    if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), init))
+    if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
+                              state.operands))
       return failure();
-    state.addOperands(init);
   }
 
   auto attr = parser.getBuilder().getI32IntegerAttr(
diff --git a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir
new file mode 100644 (file)
index 0000000..020a4cb
--- /dev/null
@@ -0,0 +1,10 @@
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+spv.module "Logical" "GLSL450" {
+  // CHECK-LABEL: @subgroup_ballot
+  func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+    // CHECK: %{{.*}} = spv.SubgroupBallotKHR %{{.*}}: vector<4xi32>
+    %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
+    spv.ReturnValue %0: vector<4xi32>
+  }
+}
index b27ebbb..784af94 100644 (file)
@@ -1214,6 +1214,17 @@ spv.module "Logical" "GLSL450" {
 // -----
 
 //===----------------------------------------------------------------------===//
+// spv.SubgroupBallotKHR
+//===----------------------------------------------------------------------===//
+
+func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+  %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
+  return %0: vector<4xi32>
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
 // spv.undef
 //===----------------------------------------------------------------------===//
 
index 5ef5667..d1530f7 100755 (executable)
@@ -426,7 +426,11 @@ def get_op_definition(instruction, doc, existing_info):
   # Make sure we have ', ' to separate the category arguments from traits
   category_args = category_args.rstrip(', ') + ', '
 
-  summary, text = doc.split('\n', 1)
+  if '\n' in doc:
+    summary, text = doc.split('\n', 1)
+  else:
+    summary = doc
+    text = ''
   wrapper = textwrap.TextWrapper(
       width=76, initial_indent='    ', subsequent_indent='    ')