From 0c8f9b8099fd0500cd885bc699924e20371014ff Mon Sep 17 00:00:00 2001 From: ergawy Date: Fri, 2 Oct 2020 14:56:17 -0400 Subject: [PATCH] [MLIR][SPIRV] Add initial support for OpSpecConstantComposite. This commit adds support to SPIR-V's composite specialization constants. These are specialization constants which are composed of other spec constants (whehter scalar or composite), regular constatns, or undef values. This commit adds support for parsing, printing, verification, and (De)serialization. A few TODOs are still in order: - Supporting more types of constituents; currently, only scalar spec constatns are supported. - Extending `spv._reference_of` to support composite spec constatns. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D88568 --- .../mlir/Dialect/SPIRV/SPIRVStructureOps.td | 54 ++++++++- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 90 +++++++++++++++ .../Dialect/SPIRV/Serialization/Deserializer.cpp | 37 ++++++ .../lib/Dialect/SPIRV/Serialization/Serializer.cpp | 42 +++++++ .../Dialect/SPIRV/Serialization/spec-constant.mlir | 22 +++- mlir/test/Dialect/SPIRV/structure-ops.mlir | 127 +++++++++++++++++++++ 6 files changed, 369 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td index 2ac28ef..0e866f0 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -491,6 +491,8 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> { ```mlir %0 = spv._reference_of @spec_const : f32 ``` + + TODO Add support for composite specialization constants. }]; let arguments = (ins @@ -541,8 +543,6 @@ def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope, Symbol]> { spv.specConstant @spec_const1 = true spv.specConstant @spec_const2 spec_id(5) = 42 : i32 ``` - - TODO: support composite spec constants with another op }]; let arguments = (ins @@ -557,6 +557,56 @@ def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope, Symbol]> { let autogenSerialization = 0; } +def SPV_SpecConstantCompositeOp : SPV_Op<"specConstantComposite", [InModuleScope, Symbol]> { + let summary = "Declare a new composite specialization constant."; + + let description = [{ + This op declares a SPIR-V composite specialization constant. This covers + the `OpSpecConstantComposite` SPIR-V instruction. Scalar constants are + covered by `spv.specConstant`. + + A constituent of a spec constant composite can be: + - A symbol referring of another spec constant. + - The SSA ID of a non-specialization constant (i.e. defined through + `spv.specConstant`). + - The SSA ID of a `spv.undef`. + + ``` + spv-spec-constant-composite-op ::= `spv.specConstantComposite` symbol-ref-id ` (` + symbol-ref-id (`, ` symbol-ref-id)* + `) :` composite-type + ``` + + where `composite-type` is some non-scalar type that can be represented in the `spv` + dialect: `spv.struct`, `spv.array`, or `vector`. + + #### Example: + + ```mlir + spv.specConstant @sc1 = 1 : i32 + spv.specConstant @sc2 = 2.5 : f32 + spv.specConstant @sc3 = 3.5 : f32 + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct + ``` + + TODO Add support for constituents that are: + - regular constants. + - undef. + - spec constant composite. + }]; + + let arguments = (ins + TypeAttr:$type, + StrAttr:$sym_name, + SymbolRefArrayAttr:$constituents + ); + + let results = (outs); + + let hasOpcode = 0; + + let autogenSerialization = 0; +} // ----- #endif // SPIRV_STRUCTURE_OPS diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index a011771..363785e 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -53,6 +53,7 @@ static constexpr const char kTypeAttrName[] = "type"; static constexpr const char kUnequalSemanticsAttrName[] = "unequal_semantics"; static constexpr const char kValueAttrName[] = "value"; static constexpr const char kValuesAttrName[] = "values"; +static constexpr const char kCompositeSpecConstituentsName[] = "constituents"; //===----------------------------------------------------------------------===// // Common utility functions @@ -3287,6 +3288,95 @@ static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) { return success(); } +//===----------------------------------------------------------------------===// +// spv.specConstantComposite +//===----------------------------------------------------------------------===// + +static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser, + OperationState &state) { + + StringAttr compositeName; + if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(), + state.attributes)) + return failure(); + + if (parser.parseLParen()) + return failure(); + + SmallVector constituents; + + do { + // The name of the constituent attribute isn't important + const char *attrName = "spec_const"; + FlatSymbolRefAttr specConstRef; + NamedAttrList attrs; + + if (parser.parseAttribute(specConstRef, Type(), attrName, attrs)) + return failure(); + + constituents.push_back(specConstRef); + } while (!parser.parseOptionalComma()); + + if (parser.parseRParen()) + return failure(); + + state.addAttribute(kCompositeSpecConstituentsName, + parser.getBuilder().getArrayAttr(constituents)); + + Type type; + if (parser.parseColonType(type)) + return failure(); + + state.addAttribute(kTypeAttrName, TypeAttr::get(type)); + + return success(); +} + +static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) { + printer << spirv::SpecConstantCompositeOp::getOperationName() << " "; + printer.printSymbolName(op.sym_name()); + printer << " ("; + auto constituents = op.constituents().getValue(); + + if (!constituents.empty()) + llvm::interleaveComma(constituents, printer); + + printer << ") : " << op.type(); +} + +static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) { + auto cType = constOp.type().dyn_cast(); + auto constituents = constOp.constituents().getValue(); + + if (!cType) + return constOp.emitError( + "result type must be a composite type, but provided ") + << constOp.type(); + + if (cType.isa()) + return constOp.emitError("unsupported composite type ") << cType; + else if (constituents.size() != cType.getNumElements()) + return constOp.emitError("has incorrect number of operands: expected ") + << cType.getNumElements() << ", but provided " + << constituents.size(); + + for (auto index : llvm::seq(0, constituents.size())) { + auto constituent = constituents[index].dyn_cast(); + + auto constituentSpecConstOp = + dyn_cast(SymbolTable::lookupNearestSymbolFrom( + constOp.getParentOp(), constituent.getValue())); + + if (constituentSpecConstOp.default_value().getType() != + cType.getElementType(index)) + return constOp.emitError("has incorrect types of operands: expected ") + << cType.getElementType(index) << ", but provided " + << constituentSpecConstOp.default_value().getType(); + } + + return success(); +} + namespace mlir { namespace spirv { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index b5eea43..153540d 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -249,6 +249,8 @@ private: /// `operands`. LogicalResult processConstantComposite(ArrayRef operands); + LogicalResult processSpecConstantComposite(ArrayRef operands); + /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. LogicalResult processConstantNull(ArrayRef operands); @@ -1546,6 +1548,39 @@ Deserializer::processConstantComposite(ArrayRef operands) { return success(); } +LogicalResult +Deserializer::processSpecConstantComposite(ArrayRef operands) { + if (operands.size() < 2) { + return emitError(unknownLoc, + "OpConstantComposite must have type and result "); + } + if (operands.size() < 3) { + return emitError(unknownLoc, + "OpConstantComposite must have at least 1 parameter"); + } + + Type resultType = getType(operands[0]); + if (!resultType) { + return emitError(unknownLoc, "undefined result type from ") + << operands[0]; + } + + auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(operands[1])); + + SmallVector elements; + elements.reserve(operands.size() - 2); + for (unsigned i = 2, e = operands.size(); i < e; ++i) { + auto elementInfo = getSpecConstant(operands[i]); + elements.push_back(opBuilder.getSymbolRefAttr(elementInfo)); + } + + opBuilder.create( + unknownLoc, TypeAttr::get(resultType), symName, + opBuilder.getArrayAttr(elements)); + + return success(); +} + LogicalResult Deserializer::processConstantNull(ArrayRef operands) { if (operands.size() != 2) { return emitError(unknownLoc, @@ -2276,6 +2311,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return processConstant(operands, /*isSpec=*/true); case spirv::Opcode::OpConstantComposite: return processConstantComposite(operands); + case spirv::Opcode::OpSpecConstantComposite: + return processSpecConstantComposite(operands); case spirv::Opcode::OpConstantTrue: return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); case spirv::Opcode::OpSpecConstantTrue: diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 1eda166..426c838 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -200,6 +200,9 @@ private: LogicalResult processSpecConstantOp(spirv::SpecConstantOp op); + LogicalResult + processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op); + /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA /// value to use with other operations. The SPIR-V spec recommends that /// OpUndef be generated at module level. The serialization generates an @@ -645,6 +648,42 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) { return failure(); } +LogicalResult +Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) { + uint32_t typeID = 0; + if (failed(processType(op.getLoc(), op.type(), typeID))) { + return failure(); + } + + auto resultID = getNextID(); + + SmallVector operands; + operands.push_back(typeID); + operands.push_back(resultID); + + auto constituents = op.constituents(); + + for (auto index : llvm::seq(0, constituents.size())) { + auto constituent = constituents[index].dyn_cast(); + + auto constituentName = constituent.getValue(); + auto constituentID = getSpecConstID(constituentName); + + if (!constituentID) { + return op.emitError("unknown result for specialization constant ") + << constituentName; + } + + operands.push_back(constituentID); + } + + encodeInstructionInto(typesGlobalValues, + spirv::Opcode::OpSpecConstantComposite, operands); + specConstIDMap[op.sym_name()] = resultID; + + return processName(resultID, op.sym_name()); +} + LogicalResult Serializer::processUndefOp(spirv::UndefOp op) { auto undefType = op.getType(); auto &id = undefValIDMap[undefType]; @@ -1765,6 +1804,9 @@ LogicalResult Serializer::processOperation(Operation *opInst) { .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) + .Case([&](spirv::SpecConstantCompositeOp op) { + return processSpecConstantCompositeOp(op); + }) .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) diff --git a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir index 03cc85b..0df9301 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s +// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s spv.module Logical GLSL450 requires #spv.vce { // CHECK: spv.specConstant @sc_true = true @@ -25,3 +25,23 @@ spv.module Logical GLSL450 requires #spv.vce { spv.ReturnValue %1 : i32 } } + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { + + spv.specConstant @sc_f32_1 = 1.5 : f32 + spv.specConstant @sc_f32_2 = 2.5 : f32 + spv.specConstant @sc_f32_3 = 3.5 : f32 + + spv.specConstant @sc_i32_1 = 1 : i32 + + // CHECK: spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32> + spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32> + + // CHECK: spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct + spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct + + // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32> + spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32> +} diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir index 98da480..765eba9 100644 --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -596,3 +596,130 @@ func @use_in_function() -> () { spv.specConstant @sc = false return } + +// ----- + +//===----------------------------------------------------------------------===// +// spv.specConstantComposite +//===----------------------------------------------------------------------===// + +spv.module Logical GLSL450 { + // expected-error @+1 {{result type must be a composite type}} + spv.specConstantComposite @scc2 (@sc1, @sc2, @sc3) : i32 +} + +//===----------------------------------------------------------------------===// +// spv.specConstantComposite (spv.array) +//===----------------------------------------------------------------------===// + +// ----- + +spv.module Logical GLSL450 { + spv.specConstant @sc1 = 1.5 : f32 + spv.specConstant @sc2 = 2.5 : f32 + spv.specConstant @sc3 = 3.5 : f32 + // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32> + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32> +} + +// ----- + +spv.module Logical GLSL450 { + spv.specConstant @sc1 = false + spv.specConstant @sc2 spec_id(5) = 42 : i64 + spv.specConstant @sc3 = 1.5 : f32 + // expected-error @+1 {{has incorrect number of operands: expected 4, but provided 3}} + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<4 x f32> + +} + +// ----- + +spv.module Logical GLSL450 { + spv.specConstant @sc1 = 1 : i32 + spv.specConstant @sc2 = 2.5 : f32 + spv.specConstant @sc3 = 3.5 : f32 + // expected-error @+1 {{has incorrect types of operands: expected 'f32', but provided 'i32'}} + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32> +} + +//===----------------------------------------------------------------------===// +// spv.specConstantComposite (spv.struct) +//===----------------------------------------------------------------------===// + +// ----- + +spv.module Logical GLSL450 { + spv.specConstant @sc1 = 1 : i32 + spv.specConstant @sc2 = 2.5 : f32 + spv.specConstant @sc3 = 3.5 : f32 + // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct +} + +// ----- + +spv.module Logical GLSL450 { + spv.specConstant @sc1 = 1 : i32 + spv.specConstant @sc2 = 2.5 : f32 + spv.specConstant @sc3 = 3.5 : f32 + // expected-error @+1 {{has incorrect number of operands: expected 2, but provided 3}} + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct +} + +// ----- + +spv.module Logical GLSL450 { + spv.specConstant @sc1 = 1.5 : f32 + spv.specConstant @sc2 = 2.5 : f32 + spv.specConstant @sc3 = 3.5 : f32 + // expected-error @+1 {{has incorrect types of operands: expected 'i32', but provided 'f32'}} + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct +} + +//===----------------------------------------------------------------------===// +// spv.specConstantComposite (vector) +//===----------------------------------------------------------------------===// + +// ----- + +spv.module Logical GLSL450 { + spv.specConstant @sc1 = 1.5 : f32 + spv.specConstant @sc2 = 2.5 : f32 + spv.specConstant @sc3 = 3.5 : f32 + // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3xf32> + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3 x f32> +} + +// ----- + +spv.module Logical GLSL450 { + spv.specConstant @sc1 = false + spv.specConstant @sc2 spec_id(5) = 42 : i64 + spv.specConstant @sc3 = 1.5 : f32 + // expected-error @+1 {{has incorrect number of operands: expected 4, but provided 3}} + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<4xf32> + +} + +// ----- + +spv.module Logical GLSL450 { + spv.specConstant @sc1 = 1 : i32 + spv.specConstant @sc2 = 2.5 : f32 + spv.specConstant @sc3 = 3.5 : f32 + // expected-error @+1 {{has incorrect types of operands: expected 'f32', but provided 'i32'}} + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3xf32> +} + +//===----------------------------------------------------------------------===// +// spv.specConstantComposite (spv.coopmatrix) +//===----------------------------------------------------------------------===// + +// ----- + +spv.module Logical GLSL450 { + spv.specConstant @sc1 = 1.5 : f32 + // expected-error @+1 {{unsupported composite type}} + spv.specConstantComposite @scc (@sc1) : !spv.coopmatrix<8x16xf32, Device> +} -- 2.7.4