[spirv] Add spv.GroupNonUniformBallot
authorLei Zhang <antiagainst@google.com>
Wed, 4 Dec 2019 00:43:40 +0000 (16:43 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 4 Dec 2019 00:44:09 +0000 (16:44 -0800)
This CL also did the following cleanup:
- Moved the test for spv.SubgroupBallotKHR to its own file
- Wrapped generated canonicalization patterns in anonymous namespace
- Updated header comments in SPVOps.td

PiperOrigin-RevId: 283650091

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td [new file with mode: 0644]
mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir [new file with mode: 0644]
mlir/test/Dialect/SPIRV/group-ops.mlir [new file with mode: 0644]
mlir/test/Dialect/SPIRV/non-uniform-ops.mlir [new file with mode: 0644]

index bfb7497..2ee8f3b 100644 (file)
@@ -953,7 +953,9 @@ class SPV_ScalarOrVectorOf<Type type> :
 def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
 def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
 
-def SPV_I32Vec4 : VectorOfLengthAndType<[4], [I32]>;
+class SPV_Vec4<Type type> : VectorOfLengthAndType<[4], [type]>;
+def SPV_IntVec4 : SPV_Vec4<SPV_Integer>;
+def SPV_I32Vec4 : SPV_Vec4<I32>;
 
 // TODO(antiagainst): Use a more appropriate way to model optional operands
 class SPV_Optional<Type type> : Variadic<type>;
@@ -1109,6 +1111,7 @@ def SPV_OC_OpReturn                 : I32EnumAttrCase<"OpReturn", 253>;
 def SPV_OC_OpReturnValue            : I32EnumAttrCase<"OpReturnValue", 254>;
 def SPV_OC_OpUnreachable            : I32EnumAttrCase<"OpUnreachable", 255>;
 def SPV_OC_OpModuleProcessed        : I32EnumAttrCase<"OpModuleProcessed", 330>;
+def SPV_OC_OpGroupNonUniformBallot  : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
 def SPV_OC_OpSubgroupBallotKHR      : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
 
 def SPV_OpcodeAttr :
@@ -1150,7 +1153,7 @@ def SPV_OpcodeAttr :
       SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
       SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
       SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
-      SPV_OC_OpSubgroupBallotKHR
+      SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR
       ]> {
     let cppNamespace = "::mlir::spirv";
 }
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVNonUniformOps.td
new file mode 100644 (file)
index 0000000..a37f5b5
--- /dev/null
@@ -0,0 +1,78 @@
+//===-- SPIRVNonUniformOps.td - MLIR SPIR-V NonUniform Ops -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains non-uniform ops for the SPIR-V dialect. It corresponds to
+// "3.32.24. Non-Uniform Instructions" of the SPIR-V specification.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_NON_UNIFORM_OPS
+#define SPIRV_NON_UNIFORM_OPS
+
+// -----
+
+def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
+  let summary = [{
+    Returns a bitfield value combining the Predicate value from all
+    invocations in the group that execute the same dynamic instance of this
+    instruction. The bit is set to one if the corresponding invocation is
+    active and the Predicate for that invocation evaluated to true;
+    otherwise, it is set to zero.
+  }];
+
+  let description = [{
+    Result Type  must be a vector of four components of integer type scalar,
+    whose Signedness operand is 0.
+
+    Result is a set of bitfields where the first invocation is represented
+    in the lowest bit of the first vector component and the last (up to the
+    size of the group) is the higher bit number of the last bitmask needed
+    to represent all bits of the group invocations.
+
+    Execution must be Workgroup or Subgroup Scope.
+
+    Predicate must be a Boolean type.
+
+    ### Custom assembly form
+
+    ``` {.ebnf}
+    scope ::= `"Workgroup"` | `"Subgroup"`
+    non-uniform-ballot-op ::= ssa-id `=` `spv.GroupNonUniformBallot` scope
+                              ssa-use `:` `vector` `<` 4 `x` `integer-type` `>`
+    ```
+
+    For example:
+
+    ```
+    %0 = spv.GroupNonUniformBallot "SubGroup" %predicate : vector<4xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    SPV_ScopeAttr:$execution_scope,
+    SPV_Bool:$predicate
+  );
+
+  let results = (outs
+    SPV_IntVec4:$result
+  );
+}
+
+// -----
+
+#endif // SPIRV_NON_UNIFORM_OPS
+
index 178db0a..149c235 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
-// Note that for each op in this file, we use a tool to automatically generate
-// certain sections in its definition: basic structure, summary, description.
-// So modifications to these sections will not be respected. Modifications to
-// op traits, arguments, results, and sections after the results are retained.
-// Besides, ops in this file must be separated via the '// -----' marker.
+// Note that for each op in this file and the included files for specific op
+// categories, we use a tool to automatically generate certain sections in its
+// definition: basic structure, summary, description. So modifications to these
+// sections will not be respected. Modifications to op traits, arguments,
+// results, and sections after the results are retained. Besides, ops must be
+// separated via the '// -----' marker.
 
 #ifndef SPIRV_OPS
 #define SPIRV_OPS
@@ -37,6 +38,7 @@ include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td"
 include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
 include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
 include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
+include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td"
 include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
 
 // -----
index 6e115f7..89abbe8 100644 (file)
@@ -385,7 +385,9 @@ static inline bool isMergeBlock(Block &block) {
 // TableGen'erated canonicalizers
 //===----------------------------------------------------------------------===//
 
+namespace {
 #include "SPIRVCanonicalization.inc"
+}
 
 //===----------------------------------------------------------------------===//
 // Common parsers and printers
@@ -1552,6 +1554,44 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
 }
 
 //===----------------------------------------------------------------------===//
+// spv.GroupNonUniformBallotOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
+                                                OperationState &state) {
+  spirv::Scope executionScope;
+  OpAsmParser::OperandType operandInfo;
+  Type resultType;
+  IntegerType i1Type = parser.getBuilder().getI1Type();
+  if (parseEnumAttribute(executionScope, parser, state,
+                         kExecutionScopeAttrName) ||
+      parser.parseOperand(operandInfo) || parser.parseColonType(resultType) ||
+      parser.resolveOperand(operandInfo, i1Type, state.operands))
+    return failure();
+
+  return parser.addTypeToList(resultType, state.types);
+}
+
+static void print(spirv::GroupNonUniformBallotOp ballotOp,
+                  OpAsmPrinter &printer) {
+  printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
+          << stringifyScope(ballotOp.execution_scope()) << "\" ";
+  printer.printOperand(ballotOp.predicate());
+  printer << " : " << ballotOp.getType();
+}
+
+static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
+  // TODO(antiagainst): check the result integer type's signedness bit is 0.
+
+  spirv::Scope scope = ballotOp.execution_scope();
+  if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+    return ballotOp.emitOpError(
+        "execution scope must be 'Workgroup' or 'Subgroup'");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // spv.IAdd
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/non-uniform-ops.mlir
new file mode 100644 (file)
index 0000000..282811e
--- /dev/null
@@ -0,0 +1,10 @@
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+spv.module "Logical" "GLSL450" {
+  // CHECK-LABEL: @group_non_uniform_ballot
+  func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
+    // CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32>
+  %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
+    spv.ReturnValue %0: vector<4xi32>
+  }
+}
diff --git a/mlir/test/Dialect/SPIRV/group-ops.mlir b/mlir/test/Dialect/SPIRV/group-ops.mlir
new file mode 100644 (file)
index 0000000..ba5e792
--- /dev/null
@@ -0,0 +1,11 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBallotKHR
+//===----------------------------------------------------------------------===//
+
+func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+  // CHECK: %{{.*}} = spv.SubgroupBallotKHR %{{.*}} : vector<4xi32>
+  %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
+  return %0: vector<4xi32>
+}
diff --git a/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir b/mlir/test/Dialect/SPIRV/non-uniform-ops.mlir
new file mode 100644 (file)
index 0000000..483a731
--- /dev/null
@@ -0,0 +1,19 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformBallot
+//===----------------------------------------------------------------------===//
+
+func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+  // CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32>
+  %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
+  return %0: vector<4xi32>
+}
+
+// -----
+
+func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+  // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+  %0 = spv.GroupNonUniformBallot "Device" %predicate : vector<4xi32>
+  return %0: vector<4xi32>
+}