Adds support for SPIR-V composite speciailization constants to spv._reference_of.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D88732
let summary = "Reference a specialization constant.";
let description = [{
- Specialization constant in module scope are defined using symbol names.
+ Specialization constants in module scope are defined using symbol names.
This op generates an SSA value that can be used to refer to the symbol
within function scope for use in ops that expect an SSA value.
This operation has no corresponding SPIR-V instruction; it's merely used
//===----------------------------------------------------------------------===//
static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
- auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(
- SymbolTable::lookupNearestSymbolFrom(referenceOfOp.getParentOp(),
- referenceOfOp.spec_const()));
- if (!specConstOp) {
- return referenceOfOp.emitOpError("expected spv.specConstant symbol");
- }
- if (referenceOfOp.reference().getType() !=
- specConstOp.default_value().getType()) {
+ auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
+ referenceOfOp.getParentOp(), referenceOfOp.spec_const());
+ Type constType;
+
+ auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
+ if (specConstOp)
+ constType = specConstOp.default_value().getType();
+
+ auto specConstCompositeOp =
+ dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
+ if (specConstCompositeOp)
+ constType = specConstCompositeOp.type();
+
+ if (!specConstOp && !specConstCompositeOp)
+ return referenceOfOp.emitOpError(
+ "expected spv.specConstant or spv.SpecConstantComposite symbol");
+
+ if (referenceOfOp.reference().getType() != constType)
return referenceOfOp.emitOpError("result type mismatch with the referenced "
"specialization constant's type");
- }
+
return success();
}
return specConstMap.lookup(id);
}
+ /// Gets the composite specialization constant with the given result <id>.
+ spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id) {
+ return specConstCompositeMap.lookup(id);
+ }
+
/// Creates a spirv::SpecConstantOp.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
Attribute defaultValue);
/// (and type) here. Later when it's used, we materialize the constant.
DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
- // Result <id> to variable mapping.
+ // Result <id> to spec constant mapping.
DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
+ // Result <id> to composite spec constant mapping.
+ DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap;
+
// Result <id> to variable mapping.
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
<< operands[0];
}
- auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(operands[1]));
+ auto resultID = operands[1];
+ auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
SmallVector<Attribute, 4> elements;
elements.reserve(operands.size() - 2);
elements.push_back(opBuilder.getSymbolRefAttr(elementInfo));
}
- opBuilder.create<spirv::SpecConstantCompositeOp>(
+ auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
unknownLoc, TypeAttr::get(resultType), symName,
opBuilder.getArrayAttr(elements));
+ specConstCompositeMap[resultID] = op;
return success();
}
opBuilder.getSymbolRefAttr(constOp.getOperation()));
return referenceOfOp.reference();
}
+ if (auto constCompositeOp = getSpecConstantComposite(id)) {
+ auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
+ unknownLoc, constCompositeOp.type(),
+ opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
+ return referenceOfOp.reference();
+ }
if (auto undef = getUndefType(id)) {
return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
}
// CHECK: spv.specConstant @sc_float spec_id(5) = 1.000000e+00 : f32
spv.specConstant @sc_float spec_id(5) = 1. : f32
+ // CHECK: spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32>
+ spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32>
+
// CHECK-LABEL: @use
spv.func @use() -> (i32) "None" {
// We materialize a `spv._reference_of` op at every use of a
%1 = spv.IAdd %0, %0 : i32
spv.ReturnValue %1 : i32
}
+
+ // CHECK-LABEL: @use
+ spv.func @use_composite() -> (i32) "None" {
+ // We materialize a `spv._reference_of` op at every use of a
+ // specialization constant in the deserializer. So two ops here.
+ // CHECK: %[[USE1:.*]] = spv._reference_of @scc : !spv.array<2 x i32>
+ // CHECK: %[[ITM0:.*]] = spv.CompositeExtract %[[USE1]][0 : i32] : !spv.array<2 x i32>
+ // CHECK: %[[USE2:.*]] = spv._reference_of @scc : !spv.array<2 x i32>
+ // CHECK: %[[ITM1:.*]] = spv.CompositeExtract %[[USE2]][1 : i32] : !spv.array<2 x i32>
+ // CHECK: spv.IAdd %[[ITM0]], %[[ITM1]]
+
+ %0 = spv._reference_of @scc : !spv.array<2 x i32>
+ %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
+ %2 = spv.CompositeExtract %0[1 : i32] : !spv.array<2 x i32>
+ %3 = spv.IAdd %1, %2 : i32
+ spv.ReturnValue %3 : i32
+ }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+
+ spv.specConstant @sc_f32_1 = 1.5 : f32
+ spv.specConstant @sc_f32_2 = 2.5 : f32
+ spv.specConstant @sc_f32_3 = 3.5 : f32
+
+ spv.specConstant @sc_i32_1 = 1 : i32
+
+ // CHECK: spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+ spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+
+ // CHECK: spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+ spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+
+ // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32>
+ spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32>
}
// -----
spv.specConstant @sc2 = 42 : i64
spv.specConstant @sc3 = 1.5 : f32
+ spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i1, i64, f32>
+
// CHECK-LABEL: @reference
spv.func @reference() -> i1 "None" {
// CHECK: spv._reference_of @sc1 : i1
spv.ReturnValue %0 : i1
}
+ // CHECK-LABEL: @reference_composite
+ spv.func @reference_composite() -> i1 "None" {
+ // CHECK: spv._reference_of @scc : !spv.struct<i1, i64, f32>
+ %0 = spv._reference_of @scc : !spv.struct<i1, i64, f32>
+ %1 = spv.CompositeExtract %0[0 : i32] : !spv.struct<i1, i64, f32>
+ spv.ReturnValue %1 : i1
+ }
+
// CHECK-LABEL: @initialize
spv.func @initialize() -> i64 "None" {
// CHECK: spv._reference_of @sc2 : i64
// -----
+spv.specConstant @sc = 5 : i32
+spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32>
+
+func @reference_of_composite() {
+ // CHECK: spv._reference_of @scc : !spv.array<1 x i32>
+ %0 = spv._reference_of @scc : !spv.array<1 x i32>
+ %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<1 x i32>
+ return
+}
+
+// -----
+
spv.module Logical GLSL450 {
spv.func @foo() -> () "None" {
- // expected-error @+1 {{expected spv.specConstant symbol}}
+ // expected-error @+1 {{expected spv.specConstant or spv.SpecConstantComposite symbol}}
%0 = spv._reference_of @sc : i32
spv.Return
}
// -----
+spv.module Logical GLSL450 {
+ spv.specConstant @sc = 42 : i32
+ spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32>
+ spv.func @foo() -> () "None" {
+ // expected-error @+1 {{result type mismatch with the referenced specialization constant's type}}
+ %0 = spv._reference_of @scc : f32
+ spv.Return
+ }
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spv.specConstant
//===----------------------------------------------------------------------===//