def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>;
def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
-def SPV_I32Vec4 : VectorOfLengthAndType<[4], [I32]>;
+class SPV_Vec4<Type type> : VectorOfLengthAndType<[4], [type]>;
+def SPV_IntVec4 : SPV_Vec4<SPV_Integer>;
+def SPV_I32Vec4 : SPV_Vec4<I32>;
// TODO(antiagainst): Use a more appropriate way to model optional operands
class SPV_Optional<Type type> : Variadic<type>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>;
def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>;
+def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>;
def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPV_OpcodeAttr :
SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel,
SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed,
- SPV_OC_OpSubgroupBallotKHR
+ SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpSubgroupBallotKHR
]> {
let cppNamespace = "::mlir::spirv";
}
--- /dev/null
+//===-- SPIRVNonUniformOps.td - MLIR SPIR-V NonUniform Ops -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains non-uniform ops for the SPIR-V dialect. It corresponds to
+// "3.32.24. Non-Uniform Instructions" of the SPIR-V specification.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPIRV_NON_UNIFORM_OPS
+#define SPIRV_NON_UNIFORM_OPS
+
+// -----
+
+def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
+ let summary = [{
+ Returns a bitfield value combining the Predicate value from all
+ invocations in the group that execute the same dynamic instance of this
+ instruction. The bit is set to one if the corresponding invocation is
+ active and the Predicate for that invocation evaluated to true;
+ otherwise, it is set to zero.
+ }];
+
+ let description = [{
+ Result Type must be a vector of four components of integer type scalar,
+ whose Signedness operand is 0.
+
+ Result is a set of bitfields where the first invocation is represented
+ in the lowest bit of the first vector component and the last (up to the
+ size of the group) is the higher bit number of the last bitmask needed
+ to represent all bits of the group invocations.
+
+ Execution must be Workgroup or Subgroup Scope.
+
+ Predicate must be a Boolean type.
+
+ ### Custom assembly form
+
+ ``` {.ebnf}
+ scope ::= `"Workgroup"` | `"Subgroup"`
+ non-uniform-ballot-op ::= ssa-id `=` `spv.GroupNonUniformBallot` scope
+ ssa-use `:` `vector` `<` 4 `x` `integer-type` `>`
+ ```
+
+ For example:
+
+ ```
+ %0 = spv.GroupNonUniformBallot "SubGroup" %predicate : vector<4xi32>
+ ```
+ }];
+
+ let arguments = (ins
+ SPV_ScopeAttr:$execution_scope,
+ SPV_Bool:$predicate
+ );
+
+ let results = (outs
+ SPV_IntVec4:$result
+ );
+}
+
+// -----
+
+#endif // SPIRV_NON_UNIFORM_OPS
+
//
//===----------------------------------------------------------------------===//
-// Note that for each op in this file, we use a tool to automatically generate
-// certain sections in its definition: basic structure, summary, description.
-// So modifications to these sections will not be respected. Modifications to
-// op traits, arguments, results, and sections after the results are retained.
-// Besides, ops in this file must be separated via the '// -----' marker.
+// Note that for each op in this file and the included files for specific op
+// categories, we use a tool to automatically generate certain sections in its
+// definition: basic structure, summary, description. So modifications to these
+// sections will not be respected. Modifications to op traits, arguments,
+// results, and sections after the results are retained. Besides, ops must be
+// separated via the '// -----' marker.
#ifndef SPIRV_OPS
#define SPIRV_OPS
include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td"
include "mlir/Dialect/SPIRV/SPIRVGroupOps.td"
include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td"
+include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td"
include "mlir/Dialect/SPIRV/SPIRVStructureOps.td"
// -----
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//
+namespace {
#include "SPIRVCanonicalization.inc"
+}
//===----------------------------------------------------------------------===//
// Common parsers and printers
}
//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformBallotOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseGroupNonUniformBallotOp(OpAsmParser &parser,
+ OperationState &state) {
+ spirv::Scope executionScope;
+ OpAsmParser::OperandType operandInfo;
+ Type resultType;
+ IntegerType i1Type = parser.getBuilder().getI1Type();
+ if (parseEnumAttribute(executionScope, parser, state,
+ kExecutionScopeAttrName) ||
+ parser.parseOperand(operandInfo) || parser.parseColonType(resultType) ||
+ parser.resolveOperand(operandInfo, i1Type, state.operands))
+ return failure();
+
+ return parser.addTypeToList(resultType, state.types);
+}
+
+static void print(spirv::GroupNonUniformBallotOp ballotOp,
+ OpAsmPrinter &printer) {
+ printer << spirv::GroupNonUniformBallotOp::getOperationName() << " \""
+ << stringifyScope(ballotOp.execution_scope()) << "\" ";
+ printer.printOperand(ballotOp.predicate());
+ printer << " : " << ballotOp.getType();
+}
+
+static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
+ // TODO(antiagainst): check the result integer type's signedness bit is 0.
+
+ spirv::Scope scope = ballotOp.execution_scope();
+ if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
+ return ballotOp.emitOpError(
+ "execution scope must be 'Workgroup' or 'Subgroup'");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// spv.IAdd
//===----------------------------------------------------------------------===//
--- /dev/null
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+
+spv.module "Logical" "GLSL450" {
+ // CHECK-LABEL: @group_non_uniform_ballot
+ func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
+ // CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32>
+ %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
+ spv.ReturnValue %0: vector<4xi32>
+ }
+}
--- /dev/null
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.SubgroupBallotKHR
+//===----------------------------------------------------------------------===//
+
+func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+ // CHECK: %{{.*}} = spv.SubgroupBallotKHR %{{.*}} : vector<4xi32>
+ %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32>
+ return %0: vector<4xi32>
+}
--- /dev/null
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.GroupNonUniformBallot
+//===----------------------------------------------------------------------===//
+
+func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+ // CHECK: %{{.*}} = spv.GroupNonUniformBallot "Workgroup" %{{.*}}: vector<4xi32>
+ %0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xi32>
+ return %0: vector<4xi32>
+}
+
+// -----
+
+func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
+ // expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
+ %0 = spv.GroupNonUniformBallot "Device" %predicate : vector<4xi32>
+ return %0: vector<4xi32>
+}