From: ergawy Date: Mon, 5 Oct 2020 20:39:39 +0000 (-0400) Subject: [MLIR][SPIRV] Extend _reference_of to support SpecConstantCompositeOp. X-Git-Tag: llvmorg-13-init~10064 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1b31b50d384b5f25221ac268ef781d26f5beacc1;p=platform%2Fupstream%2Fllvm.git [MLIR][SPIRV] Extend _reference_of to support SpecConstantCompositeOp. Adds support for SPIR-V composite speciailization constants to spv._reference_of. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D88732 --- diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td index 0e866f0..c64606b 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td @@ -472,7 +472,7 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> { let summary = "Reference a specialization constant."; let description = [{ - Specialization constant in module scope are defined using symbol names. + Specialization constants in module scope are defined using symbol names. This op generates an SSA value that can be used to refer to the symbol within function scope for use in ops that expect an SSA value. This operation has no corresponding SPIR-V instruction; it's merely used diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 363785e..ad25ecb 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -2568,17 +2568,27 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) { //===----------------------------------------------------------------------===// static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { - auto specConstOp = dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom(referenceOfOp.getParentOp(), - referenceOfOp.spec_const())); - if (!specConstOp) { - return referenceOfOp.emitOpError("expected spv.specConstant symbol"); - } - if (referenceOfOp.reference().getType() != - specConstOp.default_value().getType()) { + auto *specConstSym = SymbolTable::lookupNearestSymbolFrom( + referenceOfOp.getParentOp(), referenceOfOp.spec_const()); + Type constType; + + auto specConstOp = dyn_cast_or_null(specConstSym); + if (specConstOp) + constType = specConstOp.default_value().getType(); + + auto specConstCompositeOp = + dyn_cast_or_null(specConstSym); + if (specConstCompositeOp) + constType = specConstCompositeOp.type(); + + if (!specConstOp && !specConstCompositeOp) + return referenceOfOp.emitOpError( + "expected spv.specConstant or spv.SpecConstantComposite symbol"); + + if (referenceOfOp.reference().getType() != constType) return referenceOfOp.emitOpError("result type mismatch with the referenced " "specialization constant's type"); - } + return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 153540d..33966f8 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -187,6 +187,11 @@ private: return specConstMap.lookup(id); } + /// Gets the composite specialization constant with the given result . + spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id) { + return specConstCompositeMap.lookup(id); + } + /// Creates a spirv::SpecConstantOp. spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, Attribute defaultValue); @@ -461,9 +466,12 @@ private: /// (and type) here. Later when it's used, we materialize the constant. DenseMap> constantMap; - // Result to variable mapping. + // Result to spec constant mapping. DenseMap specConstMap; + // Result to composite spec constant mapping. + DenseMap specConstCompositeMap; + // Result to variable mapping. DenseMap globalVariableMap; @@ -1565,7 +1573,8 @@ Deserializer::processSpecConstantComposite(ArrayRef operands) { << operands[0]; } - auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(operands[1])); + auto resultID = operands[1]; + auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID)); SmallVector elements; elements.reserve(operands.size() - 2); @@ -1574,9 +1583,10 @@ Deserializer::processSpecConstantComposite(ArrayRef operands) { elements.push_back(opBuilder.getSymbolRefAttr(elementInfo)); } - opBuilder.create( + auto op = opBuilder.create( unknownLoc, TypeAttr::get(resultType), symName, opBuilder.getArrayAttr(elements)); + specConstCompositeMap[resultID] = op; return success(); } @@ -2208,6 +2218,12 @@ Value Deserializer::getValue(uint32_t id) { opBuilder.getSymbolRefAttr(constOp.getOperation())); return referenceOfOp.reference(); } + if (auto constCompositeOp = getSpecConstantComposite(id)) { + auto referenceOfOp = opBuilder.create( + unknownLoc, constCompositeOp.type(), + opBuilder.getSymbolRefAttr(constCompositeOp.getOperation())); + return referenceOfOp.reference(); + } if (auto undef = getUndefType(id)) { return opBuilder.create(unknownLoc, undef); } diff --git a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir index 0df9301..2cbfcc6 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir @@ -12,6 +12,9 @@ spv.module Logical GLSL450 requires #spv.vce { // CHECK: spv.specConstant @sc_float spec_id(5) = 1.000000e+00 : f32 spv.specConstant @sc_float spec_id(5) = 1. : f32 + // CHECK: spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32> + spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32> + // CHECK-LABEL: @use spv.func @use() -> (i32) "None" { // We materialize a `spv._reference_of` op at every use of a @@ -24,6 +27,43 @@ spv.module Logical GLSL450 requires #spv.vce { %1 = spv.IAdd %0, %0 : i32 spv.ReturnValue %1 : i32 } + + // CHECK-LABEL: @use + spv.func @use_composite() -> (i32) "None" { + // We materialize a `spv._reference_of` op at every use of a + // specialization constant in the deserializer. So two ops here. + // CHECK: %[[USE1:.*]] = spv._reference_of @scc : !spv.array<2 x i32> + // CHECK: %[[ITM0:.*]] = spv.CompositeExtract %[[USE1]][0 : i32] : !spv.array<2 x i32> + // CHECK: %[[USE2:.*]] = spv._reference_of @scc : !spv.array<2 x i32> + // CHECK: %[[ITM1:.*]] = spv.CompositeExtract %[[USE2]][1 : i32] : !spv.array<2 x i32> + // CHECK: spv.IAdd %[[ITM0]], %[[ITM1]] + + %0 = spv._reference_of @scc : !spv.array<2 x i32> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32> + %2 = spv.CompositeExtract %0[1 : i32] : !spv.array<2 x i32> + %3 = spv.IAdd %1, %2 : i32 + spv.ReturnValue %3 : i32 + } +} + +// ----- + +spv.module Logical GLSL450 requires #spv.vce { + + spv.specConstant @sc_f32_1 = 1.5 : f32 + spv.specConstant @sc_f32_2 = 2.5 : f32 + spv.specConstant @sc_f32_3 = 3.5 : f32 + + spv.specConstant @sc_i32_1 = 1 : i32 + + // CHECK: spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32> + spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32> + + // CHECK: spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct + spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct + + // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32> + spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32> } // ----- diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir index 765eba9..7bb98b9 100644 --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -496,6 +496,8 @@ spv.module Logical GLSL450 { spv.specConstant @sc2 = 42 : i64 spv.specConstant @sc3 = 1.5 : f32 + spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct + // CHECK-LABEL: @reference spv.func @reference() -> i1 "None" { // CHECK: spv._reference_of @sc1 : i1 @@ -503,6 +505,14 @@ spv.module Logical GLSL450 { spv.ReturnValue %0 : i1 } + // CHECK-LABEL: @reference_composite + spv.func @reference_composite() -> i1 "None" { + // CHECK: spv._reference_of @scc : !spv.struct + %0 = spv._reference_of @scc : !spv.struct + %1 = spv.CompositeExtract %0[0 : i32] : !spv.struct + spv.ReturnValue %1 : i1 + } + // CHECK-LABEL: @initialize spv.func @initialize() -> i64 "None" { // CHECK: spv._reference_of @sc2 : i64 @@ -534,9 +544,21 @@ func @reference_of() { // ----- +spv.specConstant @sc = 5 : i32 +spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32> + +func @reference_of_composite() { + // CHECK: spv._reference_of @scc : !spv.array<1 x i32> + %0 = spv._reference_of @scc : !spv.array<1 x i32> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<1 x i32> + return +} + +// ----- + spv.module Logical GLSL450 { spv.func @foo() -> () "None" { - // expected-error @+1 {{expected spv.specConstant symbol}} + // expected-error @+1 {{expected spv.specConstant or spv.SpecConstantComposite symbol}} %0 = spv._reference_of @sc : i32 spv.Return } @@ -555,6 +577,18 @@ spv.module Logical GLSL450 { // ----- +spv.module Logical GLSL450 { + spv.specConstant @sc = 42 : i32 + spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32> + spv.func @foo() -> () "None" { + // expected-error @+1 {{result type mismatch with the referenced specialization constant's type}} + %0 = spv._reference_of @scc : f32 + spv.Return + } +} + +// ----- + //===----------------------------------------------------------------------===// // spv.specConstant //===----------------------------------------------------------------------===//