[spirv] Add support for specialization constant
authorLei Zhang <antiagainst@google.com>
Thu, 1 Aug 2019 21:12:58 +0000 (14:12 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 1 Aug 2019 21:13:37 +0000 (14:13 -0700)
This CL extends the existing spv.constant op to also support
specialization constant by adding an extra unit attribute
on it.

PiperOrigin-RevId: 261194869

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/spec_constant.mlir [new file with mode: 0644]

index 40251d6..1a722f8 100644 (file)
@@ -72,58 +72,62 @@ class SPV_OpCode<string name, int val> {
 
 // Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
 
-def SPV_OC_OpNop               : I32EnumAttrCase<"OpNop", 0>;
-def SPV_OC_OpName              : I32EnumAttrCase<"OpName", 5>;
-def SPV_OC_OpMemoryModel       : I32EnumAttrCase<"OpMemoryModel", 14>;
-def SPV_OC_OpEntryPoint        : I32EnumAttrCase<"OpEntryPoint", 15>;
-def SPV_OC_OpExecutionMode     : I32EnumAttrCase<"OpExecutionMode", 16>;
-def SPV_OC_OpTypeVoid          : I32EnumAttrCase<"OpTypeVoid", 19>;
-def SPV_OC_OpTypeBool          : I32EnumAttrCase<"OpTypeBool", 20>;
-def SPV_OC_OpTypeInt           : I32EnumAttrCase<"OpTypeInt", 21>;
-def SPV_OC_OpTypeFloat         : I32EnumAttrCase<"OpTypeFloat", 22>;
-def SPV_OC_OpTypeVector        : I32EnumAttrCase<"OpTypeVector", 23>;
-def SPV_OC_OpTypeArray         : I32EnumAttrCase<"OpTypeArray", 28>;
-def SPV_OC_OpTypePointer       : I32EnumAttrCase<"OpTypePointer", 32>;
-def SPV_OC_OpTypeFunction      : I32EnumAttrCase<"OpTypeFunction", 33>;
-def SPV_OC_OpConstantTrue      : I32EnumAttrCase<"OpConstantTrue", 41>;
-def SPV_OC_OpConstantFalse     : I32EnumAttrCase<"OpConstantFalse", 42>;
-def SPV_OC_OpConstant          : I32EnumAttrCase<"OpConstant", 43>;
-def SPV_OC_OpConstantComposite : I32EnumAttrCase<"OpConstantComposite", 44>;
-def SPV_OC_OpConstantNull      : I32EnumAttrCase<"OpConstantNull", 46>;
-def SPV_OC_OpFunction          : I32EnumAttrCase<"OpFunction", 54>;
-def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
-def SPV_OC_OpFunctionEnd       : I32EnumAttrCase<"OpFunctionEnd", 56>;
-def SPV_OC_OpVariable          : I32EnumAttrCase<"OpVariable", 59>;
-def SPV_OC_OpLoad              : I32EnumAttrCase<"OpLoad", 61>;
-def SPV_OC_OpStore             : I32EnumAttrCase<"OpStore", 62>;
-def SPV_OC_OpAccessChain       : I32EnumAttrCase<"OpAccessChain", 65>;
-def SPV_OC_OpDecorate          : I32EnumAttrCase<"OpDecorate", 71>;
-def SPV_OC_OpCompositeExtract  : I32EnumAttrCase<"OpCompositeExtract", 81>;
-def SPV_OC_OpIAdd              : I32EnumAttrCase<"OpIAdd", 128>;
-def SPV_OC_OpFAdd              : I32EnumAttrCase<"OpFAdd", 129>;
-def SPV_OC_OpISub              : I32EnumAttrCase<"OpISub", 130>;
-def SPV_OC_OpFSub              : I32EnumAttrCase<"OpFSub", 131>;
-def SPV_OC_OpIMul              : I32EnumAttrCase<"OpIMul", 132>;
-def SPV_OC_OpFMul              : I32EnumAttrCase<"OpFMul", 133>;
-def SPV_OC_OpUDiv              : I32EnumAttrCase<"OpUDiv", 134>;
-def SPV_OC_OpSDiv              : I32EnumAttrCase<"OpSDiv", 135>;
-def SPV_OC_OpFDiv              : I32EnumAttrCase<"OpFDiv", 136>;
-def SPV_OC_OpUMod              : I32EnumAttrCase<"OpUMod", 137>;
-def SPV_OC_OpSRem              : I32EnumAttrCase<"OpSRem", 138>;
-def SPV_OC_OpSMod              : I32EnumAttrCase<"OpSMod", 139>;
-def SPV_OC_OpFRem              : I32EnumAttrCase<"OpFRem", 140>;
-def SPV_OC_OpFMod              : I32EnumAttrCase<"OpFMod", 141>;
-def SPV_OC_OpIEqual            : I32EnumAttrCase<"OpIEqual", 170>;
-def SPV_OC_OpINotEqual         : I32EnumAttrCase<"OpINotEqual", 171>;
-def SPV_OC_OpUGreaterThan      : I32EnumAttrCase<"OpUGreaterThan", 172>;
-def SPV_OC_OpSGreaterThan      : I32EnumAttrCase<"OpSGreaterThan", 173>;
-def SPV_OC_OpUGreaterThanEqual : I32EnumAttrCase<"OpUGreaterThanEqual", 174>;
-def SPV_OC_OpSGreaterThanEqual : I32EnumAttrCase<"OpSGreaterThanEqual", 175>;
-def SPV_OC_OpULessThan         : I32EnumAttrCase<"OpULessThan", 176>;
-def SPV_OC_OpSLessThan         : I32EnumAttrCase<"OpSLessThan", 177>;
-def SPV_OC_OpULessThanEqual    : I32EnumAttrCase<"OpULessThanEqual", 178>;
-def SPV_OC_OpSLessThanEqual    : I32EnumAttrCase<"OpSLessThanEqual", 179>;
-def SPV_OC_OpReturn            : I32EnumAttrCase<"OpReturn", 253>;
+def SPV_OC_OpNop                   : I32EnumAttrCase<"OpNop", 0>;
+def SPV_OC_OpName                  : I32EnumAttrCase<"OpName", 5>;
+def SPV_OC_OpMemoryModel           : I32EnumAttrCase<"OpMemoryModel", 14>;
+def SPV_OC_OpEntryPoint            : I32EnumAttrCase<"OpEntryPoint", 15>;
+def SPV_OC_OpExecutionMode         : I32EnumAttrCase<"OpExecutionMode", 16>;
+def SPV_OC_OpTypeVoid              : I32EnumAttrCase<"OpTypeVoid", 19>;
+def SPV_OC_OpTypeBool              : I32EnumAttrCase<"OpTypeBool", 20>;
+def SPV_OC_OpTypeInt               : I32EnumAttrCase<"OpTypeInt", 21>;
+def SPV_OC_OpTypeFloat             : I32EnumAttrCase<"OpTypeFloat", 22>;
+def SPV_OC_OpTypeVector            : I32EnumAttrCase<"OpTypeVector", 23>;
+def SPV_OC_OpTypeArray             : I32EnumAttrCase<"OpTypeArray", 28>;
+def SPV_OC_OpTypePointer           : I32EnumAttrCase<"OpTypePointer", 32>;
+def SPV_OC_OpTypeFunction          : I32EnumAttrCase<"OpTypeFunction", 33>;
+def SPV_OC_OpConstantTrue          : I32EnumAttrCase<"OpConstantTrue", 41>;
+def SPV_OC_OpConstantFalse         : I32EnumAttrCase<"OpConstantFalse", 42>;
+def SPV_OC_OpConstant              : I32EnumAttrCase<"OpConstant", 43>;
+def SPV_OC_OpConstantComposite     : I32EnumAttrCase<"OpConstantComposite", 44>;
+def SPV_OC_OpConstantNull          : I32EnumAttrCase<"OpConstantNull", 46>;
+def SPV_OC_OpSpecConstantTrue      : I32EnumAttrCase<"OpSpecConstantTrue", 48>;
+def SPV_OC_OpSpecConstantFalse     : I32EnumAttrCase<"OpSpecConstantFalse", 49>;
+def SPV_OC_OpSpecConstant          : I32EnumAttrCase<"OpSpecConstant", 50>;
+def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite", 51>;
+def SPV_OC_OpFunction              : I32EnumAttrCase<"OpFunction", 54>;
+def SPV_OC_OpFunctionParameter     : I32EnumAttrCase<"OpFunctionParameter", 55>;
+def SPV_OC_OpFunctionEnd           : I32EnumAttrCase<"OpFunctionEnd", 56>;
+def SPV_OC_OpVariable              : I32EnumAttrCase<"OpVariable", 59>;
+def SPV_OC_OpLoad                  : I32EnumAttrCase<"OpLoad", 61>;
+def SPV_OC_OpStore                 : I32EnumAttrCase<"OpStore", 62>;
+def SPV_OC_OpAccessChain           : I32EnumAttrCase<"OpAccessChain", 65>;
+def SPV_OC_OpDecorate              : I32EnumAttrCase<"OpDecorate", 71>;
+def SPV_OC_OpCompositeExtract      : I32EnumAttrCase<"OpCompositeExtract", 81>;
+def SPV_OC_OpIAdd                  : I32EnumAttrCase<"OpIAdd", 128>;
+def SPV_OC_OpFAdd                  : I32EnumAttrCase<"OpFAdd", 129>;
+def SPV_OC_OpISub                  : I32EnumAttrCase<"OpISub", 130>;
+def SPV_OC_OpFSub                  : I32EnumAttrCase<"OpFSub", 131>;
+def SPV_OC_OpIMul                  : I32EnumAttrCase<"OpIMul", 132>;
+def SPV_OC_OpFMul                  : I32EnumAttrCase<"OpFMul", 133>;
+def SPV_OC_OpUDiv                  : I32EnumAttrCase<"OpUDiv", 134>;
+def SPV_OC_OpSDiv                  : I32EnumAttrCase<"OpSDiv", 135>;
+def SPV_OC_OpFDiv                  : I32EnumAttrCase<"OpFDiv", 136>;
+def SPV_OC_OpUMod                  : I32EnumAttrCase<"OpUMod", 137>;
+def SPV_OC_OpSRem                  : I32EnumAttrCase<"OpSRem", 138>;
+def SPV_OC_OpSMod                  : I32EnumAttrCase<"OpSMod", 139>;
+def SPV_OC_OpFRem                  : I32EnumAttrCase<"OpFRem", 140>;
+def SPV_OC_OpFMod                  : I32EnumAttrCase<"OpFMod", 141>;
+def SPV_OC_OpIEqual                : I32EnumAttrCase<"OpIEqual", 170>;
+def SPV_OC_OpINotEqual             : I32EnumAttrCase<"OpINotEqual", 171>;
+def SPV_OC_OpUGreaterThan          : I32EnumAttrCase<"OpUGreaterThan", 172>;
+def SPV_OC_OpSGreaterThan          : I32EnumAttrCase<"OpSGreaterThan", 173>;
+def SPV_OC_OpUGreaterThanEqual     : I32EnumAttrCase<"OpUGreaterThanEqual", 174>;
+def SPV_OC_OpSGreaterThanEqual     : I32EnumAttrCase<"OpSGreaterThanEqual", 175>;
+def SPV_OC_OpULessThan             : I32EnumAttrCase<"OpULessThan", 176>;
+def SPV_OC_OpSLessThan             : I32EnumAttrCase<"OpSLessThan", 177>;
+def SPV_OC_OpULessThanEqual        : I32EnumAttrCase<"OpULessThanEqual", 178>;
+def SPV_OC_OpSLessThanEqual        : I32EnumAttrCase<"OpSLessThanEqual", 179>;
+def SPV_OC_OpReturn                : I32EnumAttrCase<"OpReturn", 253>;
 
 def SPV_OpcodeAttr :
     I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
@@ -132,14 +136,15 @@ def SPV_OpcodeAttr :
       SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
       SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
       SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
-      SPV_OC_OpConstantNull, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
-      SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
-      SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract,
-      SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
-      SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
-      SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpIEqual,
-      SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
-      SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
+      SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
+      SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
+      SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
+      SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
+      SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub,
+      SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv,
+      SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem,
+      SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
+      SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
       SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
       SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn
       ]> {
index 5cf8e13..054da98 100644 (file)
@@ -152,23 +152,24 @@ def SPV_ConstantOp : SPV_Op<"constant", [NoSideEffect]> {
     ### Custom assembly form
 
     ``` {.ebnf}
-    spv-constant-op ::= ssa-id `=` `spv.constant` attribute-value
+    spv-constant-op ::= ssa-id `=` `spv.constant` (`spec`)? attribute-value
                         (`:` spirv-type)?
     ```
 
     For example:
 
     ```
-    %0 = spv.constant true
-    %1 = spv.constant dense<vector<2xf32>, [2, 3]>
-    %2 = spv.constant [dense<vector<2xf32>, 3.0>] : !spv.array<1xvector<2xf32>>
+    %0 = spv.constant spec true
+    %1 = spv.constant dense<[2, 3]> : vector<2xf32>
+    %2 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
     ```
 
     TODO(antiagainst): support constant structs
   }];
 
   let arguments = (ins
-    AnyAttr:$value
+    AnyAttr:$value,
+    UnitAttr:$is_spec_const
   );
 
   let results = (outs
index 5e62db4..2e59809 100644 (file)
@@ -33,6 +33,7 @@ using namespace mlir;
 // TODO(antiagainst): generate these strings using ODS.
 static constexpr const char kAlignmentAttrName[] = "alignment";
 static constexpr const char kIndicesAttrName[] = "indices";
+static constexpr const char kIsSpecConstName[] = "is_spec_const";
 static constexpr const char kValueAttrName[] = "value";
 static constexpr const char kValuesAttrName[] = "values";
 static constexpr const char kFnNameAttrName[] = "fn";
@@ -466,6 +467,9 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
 //===----------------------------------------------------------------------===//
 
 static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
+  if (succeeded(parser->parseOptionalKeyword("spec")))
+    state->addAttribute(kIsSpecConstName, parser->getBuilder().getUnitAttr());
+
   Attribute value;
   if (parser->parseAttribute(value, kValueAttrName, state->attributes))
     return failure();
@@ -482,7 +486,8 @@ static ParseResult parseConstantOp(OpAsmParser *parser, OperationState *state) {
 }
 
 static void print(spirv::ConstantOp constOp, OpAsmPrinter *printer) {
-  *printer << spirv::ConstantOp::getOperationName() << " " << constOp.value();
+  *printer << spirv::ConstantOp::getOperationName()
+           << (constOp.is_spec_const() ? " spec " : " ") << constOp.value();
   if (constOp.getType().isa<spirv::ArrayType>()) {
     *printer << " : " << constOp.getType();
   }
index 2ca8f45..2aa3d5e 100644 (file)
@@ -115,16 +115,20 @@ private:
   // Constant
   //===--------------------------------------------------------------------===//
 
-  /// Processes a SPIR-V OpConstant instruction with the given `operands`.
-  LogicalResult processConstant(ArrayRef<uint32_t> operands);
+  /// Processes a SPIR-V Op{|Spec}Constant instruction with the given
+  /// `operands`. `isSpec` indicates whether this is a specialization constant.
+  LogicalResult processConstant(ArrayRef<uint32_t> operands, bool isSpec);
 
-  /// Processes a SPIR-V OpConstant{True|False} instruction with the given
-  /// `operands`.
-  LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands);
+  /// Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the
+  /// given `operands`. `isSpec` indicates whether this is a specialization
+  /// constant.
+  LogicalResult processConstantBool(bool isTrue, ArrayRef<uint32_t> operands,
+                                    bool isSpec);
 
-  /// Processes a SPIR-V OpConstantComposite instruction with the given
-  /// `operands`.
-  LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
+  /// Processes a SPIR-V Op{|Spec}ConstantComposite instruction with the given
+  /// `operands`. `isSpec` indicates whether this is a specialization constant.
+  LogicalResult processConstantComposite(ArrayRef<uint32_t> operands,
+                                         bool isSpec);
 
   /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
   LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
@@ -610,14 +614,17 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
 // Constant
 //===----------------------------------------------------------------------===//
 
-LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
+LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands,
+                                            bool isSpec) {
+  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
+
   if (operands.size() < 2) {
-    return emitError(unknownLoc,
-                     "OpConstant must have type <id> and result <id>");
+    return emitError(unknownLoc)
+           << opname << " must have type <id> and result <id>";
   }
   if (operands.size() < 3) {
-    return emitError(unknownLoc,
-                     "OpConstant must have at least 1 more parameter");
+    return emitError(unknownLoc)
+           << opname << " must have at least 1 more parameter";
   }
 
   Type resultType = getType(operands[0]);
@@ -631,22 +638,24 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
       if (operands.size() == 4) {
         return success();
       }
-      return emitError(unknownLoc,
-                       "OpConstant should have 2 parameters for 64-bit values");
+      return emitError(unknownLoc)
+             << opname << " should have 2 parameters for 64-bit values";
     }
     if (bitwidth <= 32) {
       if (operands.size() == 3) {
         return success();
       }
 
-      return emitError(unknownLoc, "OpConstant should have 1 parameter for "
-                                   "values with no more than 32 bits");
+      return emitError(unknownLoc)
+             << opname
+             << " should have 1 parameter for values with no more than 32 bits";
     }
     return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
            << bitwidth;
   };
 
   spirv::ConstantOp op;
+  UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
   if (auto intType = resultType.dyn_cast<IntegerType>()) {
     auto bitwidth = intType.getWidth();
     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
@@ -668,7 +677,8 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
     }
 
     auto attr = opBuilder.getIntegerAttr(intType, value);
-    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr);
+    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr,
+                                             isSpecConst);
   } else if (auto floatType = resultType.dyn_cast<FloatType>()) {
     auto bitwidth = floatType.getWidth();
     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
@@ -693,7 +703,8 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
     }
 
     auto attr = opBuilder.getFloatAttr(floatType, value);
-    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr);
+    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr,
+                                             isSpecConst);
   } else {
     return emitError(unknownLoc, "OpConstant can only generate values of "
                                  "scalar integer or floating-point type");
@@ -704,23 +715,27 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
 }
 
 LogicalResult Deserializer::processConstantBool(bool isTrue,
-                                                ArrayRef<uint32_t> operands) {
+                                                ArrayRef<uint32_t> operands,
+                                                bool isSpec) {
   if (operands.size() != 2) {
-    return emitError(unknownLoc, "OpConstant")
+    return emitError(unknownLoc, "Op")
+           << (isSpec ? "Spec" : "") << "Constant"
            << (isTrue ? "True" : "False")
            << " must have type <id> and result <id>";
   }
 
   auto attr = opBuilder.getBoolAttr(isTrue);
-  auto op = opBuilder.create<spirv::ConstantOp>(unknownLoc,
-                                                opBuilder.getI1Type(), attr);
+  UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
+  auto op = opBuilder.create<spirv::ConstantOp>(
+      unknownLoc, opBuilder.getI1Type(), attr, isSpecConst);
 
   valueMap[operands[1]] = op.getResult();
   return success();
 }
 
 LogicalResult
-Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
+Deserializer::processConstantComposite(ArrayRef<uint32_t> operands,
+                                       bool isSpec) {
   if (operands.size() < 2) {
     return emitError(unknownLoc,
                      "OpConstantComposite must have type <id> and result <id>");
@@ -757,12 +772,15 @@ Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   }
 
   spirv::ConstantOp op;
+  UnitAttr isSpecConst = isSpec ? opBuilder.getUnitAttr() : UnitAttr();
   if (auto vectorType = resultType.dyn_cast<VectorType>()) {
     auto attr = opBuilder.getDenseElementsAttr(vectorType, elements);
-    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
+    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
+                                             isSpecConst);
   } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
     auto attr = opBuilder.getArrayAttr(elements);
-    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
+    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
+                                             isSpecConst);
   } else {
     return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
            << resultType;
@@ -788,7 +806,9 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   if (resultType.isa<IntegerType>() || resultType.isa<FloatType>() ||
       resultType.isa<VectorType>()) {
     auto attr = opBuilder.getZeroAttr(resultType);
-    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr);
+    UnitAttr isSpecConst;
+    op = opBuilder.create<spirv::ConstantOp>(unknownLoc, resultType, attr,
+                                             isSpecConst);
   } else {
     return emitError(unknownLoc, "unsupported OpConstantNull type: ")
            << resultType;
@@ -859,13 +879,21 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
   case spirv::Opcode::OpTypePointer:
     return processType(opcode, operands);
   case spirv::Opcode::OpConstant:
-    return processConstant(operands);
+    return processConstant(operands, /*isSpec=*/false);
+  case spirv::Opcode::OpSpecConstant:
+    return processConstant(operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantComposite:
-    return processConstantComposite(operands);
+    return processConstantComposite(operands, /*isSpec=*/false);
+  case spirv::Opcode::OpSpecConstantComposite:
+    return processConstantComposite(operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantTrue:
-    return processConstantBool(true, operands);
+    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
+  case spirv::Opcode::OpSpecConstantTrue:
+    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantFalse:
-    return processConstantBool(false, operands);
+    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
+  case spirv::Opcode::OpSpecConstantFalse:
+    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantNull:
     return processConstantNull(operands);
   case spirv::Opcode::OpDecorate:
index 35c4088..188b08d 100644 (file)
@@ -168,15 +168,17 @@ private:
   /// and `valueAttr`. `constType` is needed here because we can interpret the
   /// `valueAttr` as a different type than the type of `valueAttr` itself; for
   /// example, ArrayAttr, whose type is NoneType, is used for spirv::ArrayType
-  /// constants.
-  uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
+  /// constants. If `isSpec` is true, then the constant will be serialized as
+  /// a specialization constant.
+  uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr,
+                           bool isSpec);
 
   /// Prepares bool ElementsAttr serialization. This method updates `opcode`
   /// with a proper OpConstant* instruction and pushes literal values for the
   /// constant to `operands`.
   LogicalResult prepareBoolVectorConstant(Location loc,
                                           DenseIntElementsAttr elementsAttr,
-                                          spirv::Opcode &opcode,
+                                          bool isSpec, spirv::Opcode &opcode,
                                           SmallVectorImpl<uint32_t> &operands);
 
   /// Prepares int ElementsAttr serialization. This method updates `opcode` with
@@ -184,7 +186,7 @@ private:
   /// constant to `operands`.
   LogicalResult prepareIntVectorConstant(Location loc,
                                          DenseIntElementsAttr elementsAttr,
-                                         spirv::Opcode &opcode,
+                                         bool isSpec, spirv::Opcode &opcode,
                                          SmallVectorImpl<uint32_t> &operands);
 
   /// Prepares float ElementsAttr serialization. This method updates `opcode`
@@ -192,14 +194,14 @@ private:
   /// constant to `operands`.
   LogicalResult prepareFloatVectorConstant(Location loc,
                                            DenseFPElementsAttr elementsAttr,
-                                           spirv::Opcode &opcode,
+                                           bool isSpec, spirv::Opcode &opcode,
                                            SmallVectorImpl<uint32_t> &operands);
 
-  uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr);
+  uint32_t prepareConstantBool(Location loc, BoolAttr boolAttr, bool isSpec);
 
-  uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr);
+  uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr, bool isSpec);
 
-  uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr);
+  uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr, bool isSpec);
 
   //===--------------------------------------------------------------------===//
   // Operations
@@ -317,7 +319,8 @@ LogicalResult Serializer::processMemoryModel() {
 }
 
 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
-  if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
+  if (auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value(),
+                                      op.is_spec_const())) {
     valueIDMap[op.getResult()] = resultID;
     return success();
   }
@@ -484,7 +487,8 @@ Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
     }
     operands.push_back(elementTypeID);
     if (auto elementCountID = prepareConstantInt(
-            loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
+            loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()),
+            /*isSpec=*/false)) {
       operands.push_back(elementCountID);
       return success();
     }
@@ -535,15 +539,15 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
 //===----------------------------------------------------------------------===//
 
 uint32_t Serializer::prepareConstant(Location loc, Type constType,
-                                     Attribute valueAttr) {
+                                     Attribute valueAttr, bool isSpec) {
   if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
-    return prepareConstantFp(loc, floatAttr);
+    return prepareConstantFp(loc, floatAttr, isSpec);
   }
   if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
-    return prepareConstantInt(loc, intAttr);
+    return prepareConstantInt(loc, intAttr, isSpec);
   }
   if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) {
-    return prepareConstantBool(loc, boolAttr);
+    return prepareConstantBool(loc, boolAttr, isSpec);
   }
 
   // This is a composite literal. We need to handle each component separately
@@ -566,21 +570,25 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
 
   if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
     if (vectorAttr.getType().getElementType().isInteger(1)) {
-      if (failed(prepareBoolVectorConstant(loc, vectorAttr, opcode, operands)))
+      if (failed(prepareBoolVectorConstant(loc, vectorAttr, isSpec, opcode,
+                                           operands)))
         return 0;
-    } else if (failed(
-                   prepareIntVectorConstant(loc, vectorAttr, opcode, operands)))
+    } else if (failed(prepareIntVectorConstant(loc, vectorAttr, isSpec, opcode,
+                                               operands)))
       return 0;
   } else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
-    if (failed(prepareFloatVectorConstant(loc, vectorAttr, opcode, operands)))
+    if (failed(prepareFloatVectorConstant(loc, vectorAttr, isSpec, opcode,
+                                          operands)))
       return 0;
   } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
-    opcode = spirv::Opcode::OpConstantComposite;
+    opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                    : spirv::Opcode::OpConstantComposite;
     operands.reserve(arrayAttr.size() + 2);
 
     auto elementType = constType.cast<spirv::ArrayType>().getElementType();
     for (Attribute elementAttr : arrayAttr)
-      if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
+      if (auto elementID =
+              prepareConstant(loc, elementType, elementAttr, isSpec)) {
         operands.push_back(elementID);
       } else {
         return 0;
@@ -596,8 +604,8 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
 }
 
 LogicalResult Serializer::prepareBoolVectorConstant(
-    Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
-    SmallVectorImpl<uint32_t> &operands) {
+    Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
+    spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
   auto type = elementsAttr.getType();
   assert(type.hasRank() && type.getRank() == 1 &&
          "spv.constant should have verified only vector literal uses "
@@ -612,13 +620,15 @@ LogicalResult Serializer::prepareBoolVectorConstant(
   // the splat value is zero.
   if (Attribute splatAttr = elementsAttr.getSplatValue()) {
     // We can use OpConstantNull if this bool ElementsAttr is splatting false.
-    if (!splatAttr.cast<BoolAttr>().getValue()) {
+    if (!isSpec && !splatAttr.cast<BoolAttr>().getValue()) {
       opcode = spirv::Opcode::OpConstantNull;
       return success();
     }
 
-    if (auto id = prepareConstantBool(loc, splatAttr.cast<BoolAttr>())) {
-      opcode = spirv::Opcode::OpConstantComposite;
+    if (auto id =
+            prepareConstantBool(loc, splatAttr.cast<BoolAttr>(), isSpec)) {
+      opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                      : spirv::Opcode::OpConstantComposite;
       operands.append(count, id);
       return success();
     }
@@ -628,13 +638,14 @@ LogicalResult Serializer::prepareBoolVectorConstant(
 
   // Otherwise, we need to process each element and compose them with
   // OpConstantComposite.
-  opcode = spirv::Opcode::OpConstantComposite;
+  opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                  : spirv::Opcode::OpConstantComposite;
   for (APInt intValue : elementsAttr) {
     // We are constructing an BoolAttr for each APInt here. But given that
     // we only use ElementsAttr for vectors with no more than 4 elements, it
     // should be fine here.
     auto boolAttr = mlirBuilder.getBoolAttr(intValue.isOneValue());
-    if (auto elementID = prepareConstantBool(loc, boolAttr)) {
+    if (auto elementID = prepareConstantBool(loc, boolAttr, isSpec)) {
       operands.push_back(elementID);
     } else {
       return failure();
@@ -644,8 +655,8 @@ LogicalResult Serializer::prepareBoolVectorConstant(
 }
 
 LogicalResult Serializer::prepareIntVectorConstant(
-    Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
-    SmallVectorImpl<uint32_t> &operands) {
+    Location loc, DenseIntElementsAttr elementsAttr, bool isSpec,
+    spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
   auto type = elementsAttr.getType();
   assert(type.hasRank() && type.getRank() == 1 &&
          "spv.constant should have verified only vector literal uses "
@@ -661,13 +672,15 @@ LogicalResult Serializer::prepareIntVectorConstant(
   // the splat value is zero.
   if (Attribute splatAttr = elementsAttr.getSplatValue()) {
     // We can use OpConstantNull if this int ElementsAttr is splatting 0.
-    if (splatAttr.cast<IntegerAttr>().getValue().isNullValue()) {
+    if (!isSpec && splatAttr.cast<IntegerAttr>().getValue().isNullValue()) {
       opcode = spirv::Opcode::OpConstantNull;
       return success();
     }
 
-    if (auto id = prepareConstantInt(loc, splatAttr.cast<IntegerAttr>())) {
-      opcode = spirv::Opcode::OpConstantComposite;
+    if (auto id =
+            prepareConstantInt(loc, splatAttr.cast<IntegerAttr>(), isSpec)) {
+      opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                      : spirv::Opcode::OpConstantComposite;
       operands.append(count, id);
       return success();
     }
@@ -676,7 +689,8 @@ LogicalResult Serializer::prepareIntVectorConstant(
 
   // Otherwise, we need to process each element and compose them with
   // OpConstantComposite.
-  opcode = spirv::Opcode::OpConstantComposite;
+  opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                  : spirv::Opcode::OpConstantComposite;
   for (APInt intValue : elementsAttr) {
     // We are constructing an IntegerAttr for each APInt here. But given that
     // we only use ElementsAttr for vectors with no more than 4 elements, it
@@ -684,7 +698,7 @@ LogicalResult Serializer::prepareIntVectorConstant(
     // TODO(antiagainst): revisit this if special extensions enabling large
     // vectors are supported.
     auto intAttr = mlirBuilder.getIntegerAttr(elementType, intValue);
-    if (auto elementID = prepareConstantInt(loc, intAttr)) {
+    if (auto elementID = prepareConstantInt(loc, intAttr, isSpec)) {
       operands.push_back(elementID);
     } else {
       return failure();
@@ -694,8 +708,8 @@ LogicalResult Serializer::prepareIntVectorConstant(
 }
 
 LogicalResult Serializer::prepareFloatVectorConstant(
-    Location loc, DenseFPElementsAttr elementsAttr, spirv::Opcode &opcode,
-    SmallVectorImpl<uint32_t> &operands) {
+    Location loc, DenseFPElementsAttr elementsAttr, bool isSpec,
+    spirv::Opcode &opcode, SmallVectorImpl<uint32_t> &operands) {
   auto type = elementsAttr.getType();
   assert(type.hasRank() && type.getRank() == 1 &&
          "spv.constant should have verified only vector literal uses "
@@ -706,13 +720,14 @@ LogicalResult Serializer::prepareFloatVectorConstant(
   operands.reserve(count + 2);
 
   if (Attribute splatAttr = elementsAttr.getSplatValue()) {
-    if (splatAttr.cast<FloatAttr>().getValue().isZero()) {
+    if (!isSpec && splatAttr.cast<FloatAttr>().getValue().isZero()) {
       opcode = spirv::Opcode::OpConstantNull;
       return success();
     }
 
-    if (auto id = prepareConstantFp(loc, splatAttr.cast<FloatAttr>())) {
-      opcode = spirv::Opcode::OpConstantComposite;
+    if (auto id = prepareConstantFp(loc, splatAttr.cast<FloatAttr>(), isSpec)) {
+      opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                      : spirv::Opcode::OpConstantComposite;
       operands.append(count, id);
       return success();
     }
@@ -720,10 +735,11 @@ LogicalResult Serializer::prepareFloatVectorConstant(
     return failure();
   }
 
-  opcode = spirv::Opcode::OpConstantComposite;
+  opcode = isSpec ? spirv::Opcode::OpSpecConstantComposite
+                  : spirv::Opcode::OpConstantComposite;
   for (APFloat floatValue : elementsAttr) {
     auto fpAttr = mlirBuilder.getFloatAttr(elementType, floatValue);
-    if (auto elementID = prepareConstantFp(loc, fpAttr)) {
+    if (auto elementID = prepareConstantFp(loc, fpAttr, isSpec)) {
       operands.push_back(elementID);
     } else {
       return failure();
@@ -732,7 +748,8 @@ LogicalResult Serializer::prepareFloatVectorConstant(
   return success();
 }
 
-uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr) {
+uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
+                                         bool isSpec) {
   if (auto id = findConstantID(boolAttr)) {
     return id;
   }
@@ -744,14 +761,18 @@ uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr) {
   }
 
   auto resultID = getNextID();
-  auto opcode = boolAttr.getValue() ? spirv::Opcode::OpConstantTrue
-                                    : spirv::Opcode::OpConstantFalse;
+  auto opcode = boolAttr.getValue()
+                    ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
+                              : spirv::Opcode::OpConstantTrue)
+                    : (isSpec ? spirv::Opcode::OpSpecConstantFalse
+                              : spirv::Opcode::OpConstantFalse);
   encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
 
   return constIDMap[boolAttr] = resultID;
 }
 
-uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
+uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
+                                        bool isSpec) {
   if (auto id = findConstantID(intAttr)) {
     return id;
   }
@@ -767,6 +788,9 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
   unsigned bitwidth = value.getBitWidth();
   bool isSigned = value.isSignedIntN(bitwidth);
 
+  auto opcode =
+      isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
+
   // According to SPIR-V spec, "When the type's bit width is less than 32-bits,
   // the literal's value appears in the low-order bits of the word, and the
   // high-order bits must be 0 for a floating-point type, or 0 for an integer
@@ -778,8 +802,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
     } else {
       word = static_cast<uint32_t>(value.getZExtValue());
     }
-    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
-                          {typeID, resultID, word});
+    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
   }
   // According to SPIR-V spec: "When the type's bit width is larger than one
   // word, the literal’s low-order words appear first."
@@ -793,7 +816,7 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
     } else {
       words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
     }
-    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
+    encodeInstructionInto(typesGlobalValues, opcode,
                           {typeID, resultID, words.word1, words.word2});
   } else {
     std::string valueStr;
@@ -808,7 +831,8 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
   return constIDMap[intAttr] = resultID;
 }
 
-uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) {
+uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
+                                       bool isSpec) {
   if (auto id = findConstantID(floatAttr)) {
     return id;
   }
@@ -823,22 +847,23 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) {
   APFloat value = floatAttr.getValue();
   APInt intValue = value.bitcastToAPInt();
 
+  auto opcode =
+      isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
+
   if (&value.getSemantics() == &APFloat::IEEEsingle()) {
     uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
-    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
-                          {typeID, resultID, word});
+    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
   } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
     struct DoubleWord {
       uint32_t word1;
       uint32_t word2;
     } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
-    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
+    encodeInstructionInto(typesGlobalValues, opcode,
                           {typeID, resultID, words.word1, words.word2});
   } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
     uint32_t word =
         static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
-    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
-                          {typeID, resultID, word});
+    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
   } else {
     std::string valueStr;
     llvm::raw_string_ostream rss(valueStr);
diff --git a/mlir/test/Dialect/SPIRV/Serialization/spec_constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/spec_constant.mlir
new file mode 100644 (file)
index 0000000..87f1b44
--- /dev/null
@@ -0,0 +1,47 @@
+// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
+
+func @spirv_module() -> () {
+  spv.module "Logical" "GLSL450" {
+
+    // CHECK: spv.constant spec true
+    %0 = spv.constant spec true
+    // CHECK: spv.constant spec false
+    %1 = spv.constant spec false
+
+    // CHECK: spv.constant spec -5 : i32
+    %2 = spv.constant spec -5 : i32
+
+    // CHECK: spv.constant spec 1.000000e+00 : f32
+    %3 = spv.constant spec 1. : f32
+
+    // Bool vector
+    // CHECK: spv.constant spec dense<false> : vector<2xi1>
+    %4 = spv.constant spec dense<false> : vector<2xi1>
+    // CHECK: spv.constant spec dense<[true, true, true]> : vector<3xi1>
+    %5 = spv.constant spec dense<true> : vector<3xi1>
+    // CHECK: spv.constant spec dense<[false, true]> : vector<2xi1>
+    %6 = spv.constant spec dense<[false, true]> : vector<2xi1>
+
+    // Integer vector
+    // CHECK: spv.constant spec dense<0> : vector<2xi32>
+    %7 = spv.constant spec dense<0> : vector<2xi32>
+    // CHECK: spv.constant spec dense<1> : vector<3xi32>
+    %8 = spv.constant spec dense<1> : vector<3xi32>
+    // CHECK: spv.constant spec dense<[2, -3, 4]> : vector<3xi32>
+    %9 = spv.constant spec dense<[2, -3, 4]> : vector<3xi32>
+
+    // Fp vector
+    // CHECK: spv.constant spec dense<0.000000e+00> : vector<4xf32>
+    %10 = spv.constant spec dense<0.> : vector<4xf32>
+    // CHECK: spv.constant spec dense<-1.500000e+01> : vector<4xf32>
+    %11 = spv.constant spec dense<-15.> : vector<4xf32>
+    // CHECK: spv.constant spec dense<[7.500000e-01, -2.500000e-01, 1.000000e+01, 4.200000e+01]> : vector<4xf32>
+    %12 = spv.constant spec dense<[0.75, -0.25, 10., 42.]> : vector<4xf32>
+
+    // Array
+    // CHECK: spv.constant spec [dense<3.000000e+00> : vector<2xf32>, dense<[4.000000e+00, 5.000000e+00]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>>
+    %13 = spv.constant spec [dense<3.0> : vector<2xf32>, dense<[4., 5.]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>>
+  }
+  return
+}
+