From 1af9633d851d1827b28b8c00c74e55247075f554 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 3 Dec 2019 04:49:20 -0800 Subject: [PATCH] [spirv] Add spv.SubgroupBallotKHROp PiperOrigin-RevId: 283522284 --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 8 ++- mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td | 74 ++++++++++++++++++++++ mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 5 +- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 27 +++++++- .../Dialect/SPIRV/Serialization/group-ops.mlir | 10 +++ mlir/test/Dialect/SPIRV/ops.mlir | 11 ++++ mlir/utils/spirv/gen_spirv_dialect.py | 6 +- 7 files changed, 131 insertions(+), 10 deletions(-) create mode 100644 mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td create mode 100644 mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index e1897a9..bfb7497 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -953,6 +953,8 @@ class SPV_ScalarOrVectorOf : def SPV_ScalarOrVector : AnyTypeOf<[SPV_Scalar, SPV_Vector]>; def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>; +def SPV_I32Vec4 : VectorOfLengthAndType<[4], [I32]>; + // TODO(antiagainst): Use a more appropriate way to model optional operands class SPV_Optional : Variadic; @@ -1107,6 +1109,7 @@ def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>; def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>; +def SPV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>; def SPV_OpcodeAttr : I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ @@ -1146,10 +1149,9 @@ def SPV_OpcodeAttr : SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, - SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed + SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed, + SPV_OC_OpSubgroupBallotKHR ]> { - let returnType = "::mlir::spirv::Opcode"; - let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())"; let cppNamespace = "::mlir::spirv"; } diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td new file mode 100644 index 0000000..5f60e6b --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVGroupOps.td @@ -0,0 +1,74 @@ +//===-- SPIRVGroupOps.td - MLIR SPIR-V (Sub)Group Ops ------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains group and subgroup ops for the SPIR-V dialect. It +// corresponds to "3.32.21. Group and Subgroup Instructions" of the SPIR-V +// specification. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_GROUP_OPS +#define SPIRV_GROUP_OPS + +// ----- + +def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> { + let summary = "See extension SPV_KHR_shader_ballot"; + + let description = [{ + Computes a bitfield value combining the Predicate value from all invocations + in the current Subgroup that execute the same dynamic instance of this + instruction. The bit is set to one if the corresponding invocation is active + and the predicate is evaluated to true; otherwise, it is set to zero. + + Predicate must be a Boolean type. + + Result Type must be a 4 component vector of 32 bit integer types. + + Result is a set of bitfields where the first invocation is represented in bit + 0 of the first vector component and the last (up to SubgroupSize) is the + higher bit number of the last bitmask needed to represent all bits of the + subgroup invocations. + + ### Custom assembly form + + ``` {.ebnf} + subgroup-ballot-op ::= ssa-id `=` `spv.SubgroupBallotKHR` + ssa-use `:` `vector` `<` 4 `x` `i32` `>` + ``` + + For example: + + ``` + %0 = spv.SubgroupBallotKHR %predicate : vector<4xi32> + ``` + }]; + + let arguments = (ins + SPV_Bool:$predicate + ); + + let results = (outs + SPV_I32Vec4:$result + ); + + let verifier = [{ return success(); }]; +} + +// ----- + +#endif // SPIRV_GROUP_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 41d729d..178db0a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -34,11 +34,10 @@ include "mlir/Dialect/SPIRV/SPIRVArithmeticOps.td" include "mlir/Dialect/SPIRV/SPIRVBitOps.td" include "mlir/Dialect/SPIRV/SPIRVCastOps.td" include "mlir/Dialect/SPIRV/SPIRVControlFlowOps.td" +include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td" +include "mlir/Dialect/SPIRV/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td" -// Pull in ops for defining the SPIR-V module structure include "mlir/Dialect/SPIRV/SPIRVStructureOps.td" -// Pull in ops for extended instruction set for GLSL -include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td" // ----- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index e8896fa..6e115f7 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -2527,6 +2527,28 @@ static LogicalResult verify(spirv::StoreOp storeOp) { } //===----------------------------------------------------------------------===// +// spv.SubgroupBallotKHROp +//===----------------------------------------------------------------------===// + +static ParseResult parseSubgroupBallotKHROp(OpAsmParser &parser, + OperationState &state) { + OpAsmParser::OperandType operandInfo; + Type resultType; + IntegerType i1Type = parser.getBuilder().getI1Type(); + if (parser.parseOperand(operandInfo) || parser.parseColonType(resultType) || + parser.resolveOperand(operandInfo, i1Type, state.operands)) + return failure(); + + return parser.addTypeToList(resultType, state.types); +} + +static void print(spirv::SubgroupBallotKHROp ballotOp, OpAsmPrinter &printer) { + printer << spirv::SubgroupBallotKHROp::getOperationName() << ' '; + printer.printOperand(ballotOp.predicate()); + printer << " : " << ballotOp.getType(); +} + +//===----------------------------------------------------------------------===// // spv.Undef //===----------------------------------------------------------------------===// @@ -2595,11 +2617,10 @@ static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) { state.addTypes(ptrType); // Resolve the initializer operand - SmallVector init; if (initInfo) { - if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), init)) + if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), + state.operands)) return failure(); - state.addOperands(init); } auto attr = parser.getBuilder().getI32IntegerAttr( diff --git a/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir new file mode 100644 index 0000000..020a4cb --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/group-ops.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s + +spv.module "Logical" "GLSL450" { + // CHECK-LABEL: @subgroup_ballot + func @subgroup_ballot(%predicate: i1) -> vector<4xi32> { + // CHECK: %{{.*}} = spv.SubgroupBallotKHR %{{.*}}: vector<4xi32> + %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32> + spv.ReturnValue %0: vector<4xi32> + } +} diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index b27ebbb..784af94 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -1214,6 +1214,17 @@ spv.module "Logical" "GLSL450" { // ----- //===----------------------------------------------------------------------===// +// spv.SubgroupBallotKHR +//===----------------------------------------------------------------------===// + +func @subgroup_ballot(%predicate: i1) -> vector<4xi32> { + %0 = spv.SubgroupBallotKHR %predicate: vector<4xi32> + return %0: vector<4xi32> +} + +// ----- + +//===----------------------------------------------------------------------===// // spv.undef //===----------------------------------------------------------------------===// diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index 5ef5667..d1530f7 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -426,7 +426,11 @@ def get_op_definition(instruction, doc, existing_info): # Make sure we have ', ' to separate the category arguments from traits category_args = category_args.rstrip(', ') + ', ' - summary, text = doc.split('\n', 1) + if '\n' in doc: + summary, text = doc.split('\n', 1) + else: + summary = doc + text = '' wrapper = textwrap.TextWrapper( width=76, initial_indent=' ', subsequent_indent=' ') -- 2.7.4