[MLIR][SPIRV] Add initial support for OpSpecConstantComposite.
authorergawy <kareem.ergawy@gmail.com>
Fri, 2 Oct 2020 18:56:17 +0000 (14:56 -0400)
committerLei Zhang <antiagainst@google.com>
Fri, 2 Oct 2020 19:18:16 +0000 (15:18 -0400)
This commit adds support to SPIR-V's composite specialization constants.
These are specialization constants which are composed of other spec
constants (whehter scalar or composite), regular constatns, or undef
values.

This commit adds support for parsing, printing, verification, and
(De)serialization.

A few TODOs are still in order:
- Supporting more types of constituents; currently, only scalar spec constatns are supported.
- Extending `spv._reference_of` to support composite spec constatns.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D88568

mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir
mlir/test/Dialect/SPIRV/structure-ops.mlir

index 2ac28ef..0e866f0 100644 (file)
@@ -491,6 +491,8 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> {
     ```mlir
     %0 = spv._reference_of @spec_const : f32
     ```
+
+    TODO Add support for composite specialization constants.
   }];
 
   let arguments = (ins
@@ -541,8 +543,6 @@ def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope, Symbol]> {
     spv.specConstant @spec_const1 = true
     spv.specConstant @spec_const2 spec_id(5) = 42 : i32
     ```
-
-    TODO: support composite spec constants with another op
   }];
 
   let arguments = (ins
@@ -557,6 +557,56 @@ def SPV_SpecConstantOp : SPV_Op<"specConstant", [InModuleScope, Symbol]> {
   let autogenSerialization = 0;
 }
 
+def SPV_SpecConstantCompositeOp : SPV_Op<"specConstantComposite", [InModuleScope, Symbol]> {
+  let summary = "Declare a new composite specialization constant.";
+
+  let description = [{
+    This op declares a SPIR-V composite specialization constant. This covers
+    the `OpSpecConstantComposite` SPIR-V instruction. Scalar constants are
+    covered by `spv.specConstant`.
+
+    A constituent of a spec constant composite can be:
+    - A symbol referring of another spec constant.
+    - The SSA ID of a non-specialization constant (i.e. defined through
+      `spv.specConstant`).
+    - The SSA ID of a `spv.undef`.
+
+    ```
+    spv-spec-constant-composite-op ::= `spv.specConstantComposite` symbol-ref-id ` (`
+                                       symbol-ref-id (`, ` symbol-ref-id)*
+                                       `) :` composite-type
+    ```
+
+     where `composite-type` is some non-scalar type that can be represented in the `spv`
+     dialect: `spv.struct`, `spv.array`, or `vector`.
+
+     #### Example:
+
+     ```mlir
+     spv.specConstant @sc1 = 1   : i32
+     spv.specConstant @sc2 = 2.5 : f32
+     spv.specConstant @sc3 = 3.5 : f32
+     spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+     ```
+
+    TODO Add support for constituents that are:
+    - regular constants.
+    - undef.
+    - spec constant composite.
+  }];
+
+  let arguments = (ins
+    TypeAttr:$type,
+    StrAttr:$sym_name,
+    SymbolRefArrayAttr:$constituents
+  );
+
+  let results = (outs);
+
+  let hasOpcode = 0;
+
+  let autogenSerialization = 0;
+}
 // -----
 
 #endif // SPIRV_STRUCTURE_OPS
index a011771..363785e 100644 (file)
@@ -53,6 +53,7 @@ static constexpr const char kTypeAttrName[] = "type";
 static constexpr const char kUnequalSemanticsAttrName[] = "unequal_semantics";
 static constexpr const char kValueAttrName[] = "value";
 static constexpr const char kValuesAttrName[] = "values";
+static constexpr const char kCompositeSpecConstituentsName[] = "constituents";
 
 //===----------------------------------------------------------------------===//
 // Common utility functions
@@ -3287,6 +3288,95 @@ static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser,
+                                                OperationState &state) {
+
+  StringAttr compositeName;
+  if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
+                             state.attributes))
+    return failure();
+
+  if (parser.parseLParen())
+    return failure();
+
+  SmallVector<Attribute, 4> constituents;
+
+  do {
+    // The name of the constituent attribute isn't important
+    const char *attrName = "spec_const";
+    FlatSymbolRefAttr specConstRef;
+    NamedAttrList attrs;
+
+    if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
+      return failure();
+
+    constituents.push_back(specConstRef);
+  } while (!parser.parseOptionalComma());
+
+  if (parser.parseRParen())
+    return failure();
+
+  state.addAttribute(kCompositeSpecConstituentsName,
+                     parser.getBuilder().getArrayAttr(constituents));
+
+  Type type;
+  if (parser.parseColonType(type))
+    return failure();
+
+  state.addAttribute(kTypeAttrName, TypeAttr::get(type));
+
+  return success();
+}
+
+static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) {
+  printer << spirv::SpecConstantCompositeOp::getOperationName() << " ";
+  printer.printSymbolName(op.sym_name());
+  printer << " (";
+  auto constituents = op.constituents().getValue();
+
+  if (!constituents.empty())
+    llvm::interleaveComma(constituents, printer);
+
+  printer << ") : " << op.type();
+}
+
+static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
+  auto cType = constOp.type().dyn_cast<spirv::CompositeType>();
+  auto constituents = constOp.constituents().getValue();
+
+  if (!cType)
+    return constOp.emitError(
+               "result type must be a composite type, but provided ")
+           << constOp.type();
+
+  if (cType.isa<spirv::CooperativeMatrixNVType>())
+    return constOp.emitError("unsupported composite type  ") << cType;
+  else if (constituents.size() != cType.getNumElements())
+    return constOp.emitError("has incorrect number of operands: expected ")
+           << cType.getNumElements() << ", but provided "
+           << constituents.size();
+
+  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
+    auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
+
+    auto constituentSpecConstOp =
+        dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
+            constOp.getParentOp(), constituent.getValue()));
+
+    if (constituentSpecConstOp.default_value().getType() !=
+        cType.getElementType(index))
+      return constOp.emitError("has incorrect types of operands: expected ")
+             << cType.getElementType(index) << ", but provided "
+             << constituentSpecConstOp.default_value().getType();
+  }
+
+  return success();
+}
+
 namespace mlir {
 namespace spirv {
 
index b5eea43..153540d 100644 (file)
@@ -249,6 +249,8 @@ private:
   /// `operands`.
   LogicalResult processConstantComposite(ArrayRef<uint32_t> operands);
 
+  LogicalResult processSpecConstantComposite(ArrayRef<uint32_t> operands);
+
   /// Processes a SPIR-V OpConstantNull instruction with the given `operands`.
   LogicalResult processConstantNull(ArrayRef<uint32_t> operands);
 
@@ -1546,6 +1548,39 @@ Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult
+Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
+  if (operands.size() < 2) {
+    return emitError(unknownLoc,
+                     "OpConstantComposite must have type <id> and result <id>");
+  }
+  if (operands.size() < 3) {
+    return emitError(unknownLoc,
+                     "OpConstantComposite must have at least 1 parameter");
+  }
+
+  Type resultType = getType(operands[0]);
+  if (!resultType) {
+    return emitError(unknownLoc, "undefined result type from <id> ")
+           << operands[0];
+  }
+
+  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(operands[1]));
+
+  SmallVector<Attribute, 4> elements;
+  elements.reserve(operands.size() - 2);
+  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
+    auto elementInfo = getSpecConstant(operands[i]);
+    elements.push_back(opBuilder.getSymbolRefAttr(elementInfo));
+  }
+
+  opBuilder.create<spirv::SpecConstantCompositeOp>(
+      unknownLoc, TypeAttr::get(resultType), symName,
+      opBuilder.getArrayAttr(elements));
+
+  return success();
+}
+
 LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
   if (operands.size() != 2) {
     return emitError(unknownLoc,
@@ -2276,6 +2311,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
     return processConstant(operands, /*isSpec=*/true);
   case spirv::Opcode::OpConstantComposite:
     return processConstantComposite(operands);
+  case spirv::Opcode::OpSpecConstantComposite:
+    return processSpecConstantComposite(operands);
   case spirv::Opcode::OpConstantTrue:
     return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
   case spirv::Opcode::OpSpecConstantTrue:
index 1eda166..426c838 100644 (file)
@@ -200,6 +200,9 @@ private:
 
   LogicalResult processSpecConstantOp(spirv::SpecConstantOp op);
 
+  LogicalResult
+  processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op);
+
   /// SPIR-V dialect supports OpUndef using spv.UndefOp that produces a SSA
   /// value to use with other operations. The SPIR-V spec recommends that
   /// OpUndef be generated at module level. The serialization generates an
@@ -645,6 +648,42 @@ LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
   return failure();
 }
 
+LogicalResult
+Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
+  uint32_t typeID = 0;
+  if (failed(processType(op.getLoc(), op.type(), typeID))) {
+    return failure();
+  }
+
+  auto resultID = getNextID();
+
+  SmallVector<uint32_t, 8> operands;
+  operands.push_back(typeID);
+  operands.push_back(resultID);
+
+  auto constituents = op.constituents();
+
+  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
+    auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
+
+    auto constituentName = constituent.getValue();
+    auto constituentID = getSpecConstID(constituentName);
+
+    if (!constituentID) {
+      return op.emitError("unknown result <id> for specialization constant ")
+             << constituentName;
+    }
+
+    operands.push_back(constituentID);
+  }
+
+  encodeInstructionInto(typesGlobalValues,
+                        spirv::Opcode::OpSpecConstantComposite, operands);
+  specConstIDMap[op.sym_name()] = resultID;
+
+  return processName(resultID, op.sym_name());
+}
+
 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
   auto undefType = op.getType();
   auto &id = undefValIDMap[undefType];
@@ -1765,6 +1804,9 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
       .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
+      .Case([&](spirv::SpecConstantCompositeOp op) {
+        return processSpecConstantCompositeOp(op);
+      })
       .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
       .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
 
index 03cc85b..0df9301 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
+// RUN: mlir-translate -test-spirv-roundtrip -split-input-file %s | FileCheck %s
 
 spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
   // CHECK: spv.specConstant @sc_true = true
@@ -25,3 +25,23 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
     spv.ReturnValue %1 : i32
   }
 }
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+
+  spv.specConstant @sc_f32_1 = 1.5 : f32
+  spv.specConstant @sc_f32_2 = 2.5 : f32
+  spv.specConstant @sc_f32_3 = 3.5 : f32
+
+  spv.specConstant @sc_i32_1 = 1   : i32
+
+  // CHECK: spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+  spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
+
+  // CHECK: spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+  spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
+
+  // CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32>
+  spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32>
+}
index 98da480..765eba9 100644 (file)
@@ -596,3 +596,130 @@ func @use_in_function() -> () {
   spv.specConstant @sc = false
   return
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite
+//===----------------------------------------------------------------------===//
+
+spv.module Logical GLSL450 {
+  // expected-error @+1 {{result type must be a composite type}}
+  spv.specConstantComposite @scc2 (@sc1, @sc2, @sc3) : i32
+}
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (spv.array)
+//===----------------------------------------------------------------------===//
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32>
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32>
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = false
+  spv.specConstant @sc2 spec_id(5) = 42 : i64
+  spv.specConstant @sc3 = 1.5 : f32
+  // expected-error @+1 {{has incorrect number of operands: expected 4, but provided 3}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<4 x f32>
+
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1   : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect types of operands: expected 'f32', but provided 'i32'}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.array<3 x f32>
+}
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (spv.struct)
+//===----------------------------------------------------------------------===//
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1   : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1   : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect number of operands: expected 2, but provided 3}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32>
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect types of operands: expected 'i32', but provided 'f32'}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i32, f32, f32>
+}
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (vector)
+//===----------------------------------------------------------------------===//
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // CHECK: spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3xf32>
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3 x f32>
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = false
+  spv.specConstant @sc2 spec_id(5) = 42 : i64
+  spv.specConstant @sc3 = 1.5 : f32
+  // expected-error @+1 {{has incorrect number of operands: expected 4, but provided 3}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<4xf32>
+
+}
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1   : i32
+  spv.specConstant @sc2 = 2.5 : f32
+  spv.specConstant @sc3 = 3.5 : f32
+  // expected-error @+1 {{has incorrect types of operands: expected 'f32', but provided 'i32'}}
+  spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : vector<3xf32>
+}
+
+//===----------------------------------------------------------------------===//
+// spv.specConstantComposite (spv.coopmatrix)
+//===----------------------------------------------------------------------===//
+
+// -----
+
+spv.module Logical GLSL450 {
+  spv.specConstant @sc1 = 1.5 : f32
+  // expected-error @+1 {{unsupported composite type}}
+  spv.specConstantComposite @scc (@sc1) : !spv.coopmatrix<8x16xf32, Device>
+}