def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>;
def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>;
def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>;
+def SPV_OC_OpExtInstImport : I32EnumAttrCase<"OpExtInstImport", 11>;
+def SPV_OC_OpExtInst : I32EnumAttrCase<"OpExtInst", 12>;
def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
def SPV_OpcodeAttr :
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
- SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpExtension, SPV_OC_OpMemoryModel,
- SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability,
- SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat,
- SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, SPV_OC_OpTypeStruct,
- SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
- SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
- SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
- SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
- SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
- SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
- SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract,
- SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
- SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
- SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect,
- SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
+ SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpExtension, SPV_OC_OpExtInstImport,
+ SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint,
+ SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid,
+ SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector,
+ SPV_OC_OpTypeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
+ SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse,
+ SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull,
+ SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
+ SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
+ SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
+ SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
+ SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd,
+ SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
+ SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
+ SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual,
+ SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
// Specifies whether this op has a direct corresponding SPIR-V binary
// instruction opcode. The (de)serializer use this field to determine whether
// to auto-generate an entry in the (de)serialization dispatch table for this
- // op. If set, this field also futher enables `autogenSerialization` (see
- // below for details).
+ // op.
bit hasOpcode = 1;
// Name of the corresponding SPIR-V op. Only valid to use when hasOpcode is 1.
// these methods is required.
//
// Note:
- //
// 1) If hasOpcode is set but autogenSerialization is not set, the
// (de)serializer dispatch method still calls the above method for
// (de)serializing this op.
- //
- // 2) If hasOpcode is not set, then this field is not interpreted; this op's
- // (de)serialization method will not be auto-generated regardless. Neither
- // does the handling in the (de)serialization dispatch table. Both
- // (de)serializing this op and its dispatch should be handled manually.
+ // 2) If hasOpcode is not set, but autogenSerialization is set, the
+ // above methods for (de)serialization are generated, but there is no
+ // entry added in the dispatch tables to invoke these methods. The
+ // dispatch needs to be handled manually. SPV_ExtInstOps are an
+ // example of this.
bit autogenSerialization = 1;
}
let verifier = [{ return success(); }];
}
+class SPV_ExtInstOp<string mnemonic, string setPrefix, string setName,
+ int opcode, list<OpTrait> traits = []> :
+ SPV_Op<setPrefix # "." # mnemonic, traits> {
+
+ // Extended instruction sets have no direct opcode (they share the
+ // same `OpExtInst` instruction). So the hasOpcode field is set to
+ // false. So no entry corresponding to these ops are added in the
+ // dispatch functions for (de)serialization. The methods for
+ // (de)serialization are still automatically generated (since
+ // autogenSerialization remains 1). A separate method is generated
+ // for dispatching extended instruction set ops.
+ let hasOpcode = 0;
+
+ // Opcode within extended instruction set.
+ int extendedInstOpcode = opcode;
+
+ // Name used to import the extended instruction set.
+ string extendedInstSetName = setName;
+}
+
#endif // SPIRV_BASE
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
using llvm::raw_string_ostream;
using llvm::Record;
using llvm::RecordKeeper;
+using llvm::SmallVector;
using llvm::SMLoc;
+using llvm::StringMap;
using llvm::StringRef;
using llvm::Twine;
using mlir::tblgen::Attribute;
static void emitGetOpcodeFunction(const Record *record, Operator const &op,
raw_ostream &os) {
os << formatv("template <> constexpr inline ::mlir::spirv::Opcode "
- "getOpcode<{0}>()",
- op.getQualCppClassName())
- << " {\n "
- << formatv("return ::mlir::spirv::Opcode::{0};\n}\n",
+ "getOpcode<{0}>() {{\n",
+ op.getQualCppClassName());
+ os << formatv(" return ::mlir::spirv::Opcode::{0};\n",
record->getValueAsString("spirvOpName"));
+ os << "}\n";
}
+/// Forward declaration of function to return the SPIR-V opcode corresponding to
+/// an operation. This function will be generated for all SPV_Op instances that
+/// have hasOpcode = 1.
static void declareOpcodeFn(raw_ostream &os) {
os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode "
"getOpcode();\n";
}
+/// Generates code to serialize attributes of a SPV_Op `op` into `os`. The
+/// generates code extracts the attribute with name `attrName` from
+/// `operandList` of `op`.
static void emitAttributeSerialization(const Attribute &attr,
- ArrayRef<SMLoc> loc, llvm::StringRef op,
- llvm::StringRef operandList,
- llvm::StringRef attrName,
- raw_ostream &os) {
- os << " auto attr = " << op << ".getAttr(\"" << attrName << "\");\n";
- os << " if (attr) {\n";
+ ArrayRef<SMLoc> loc, StringRef tabs,
+ StringRef opVar, StringRef operandList,
+ StringRef attrName, raw_ostream &os) {
+ os << tabs << formatv("auto attr = {0}.getAttr(\"{1}\");\n", opVar, attrName);
+ os << tabs << "if (attr) {\n";
if (attr.getAttrDefName() == "I32ArrayAttr") {
// Serialize all the elements of the array
- os << " for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
- os << " " << operandList
- << ".push_back(static_cast<uint32_t>(attrElem.cast<IntegerAttr>()."
- "getValue().getZExtValue()));\n";
- os << " }\n";
+ os << tabs << " for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
+ os << tabs
+ << formatv(" {0}.push_back(static_cast<uint32_t>("
+ "attrElem.cast<IntegerAttr>().getValue().getZExtValue()));\n",
+ operandList);
+ os << tabs << " }\n";
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
- os << " " << operandList
- << ".push_back(static_cast<uint32_t>(attr.cast<IntegerAttr>().getValue()"
- ".getZExtValue()));\n";
+ os << tabs
+ << formatv(" {0}.push_back(static_cast<uint32_t>("
+ "attr.cast<IntegerAttr>().getValue().getZExtValue()));\n",
+ operandList);
} else {
PrintFatalError(
loc,
"unhandled attribute type in SPIR-V serialization generation : '") +
attr.getAttrDefName() + llvm::Twine("'"));
}
- os << " }\n";
+ os << tabs << "}\n";
}
-static void emitSerializationFunction(const Record *attrClass,
- const Record *record, const Operator &op,
- raw_ostream &os) {
- // If the record has 'autogenSerialization' set to 0, nothing to do
- if (!record->getValueAsBit("autogenSerialization")) {
- return;
- }
- os << formatv("template <> LogicalResult\nSerializer::processOp<{0}>(\n"
- " {0} op)",
- op.getQualCppClassName())
- << " {\n";
- os << " SmallVector<uint32_t, 4> operands;\n";
- os << " SmallVector<StringRef, 2> elidedAttrs;\n";
-
- // Serialize result information
- if (op.getNumResults() == 1) {
- os << " uint32_t resultTypeID = 0;\n";
- os << " if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) "
- "{\n";
- os << " return failure();\n";
- os << " }\n";
- os << " operands.push_back(resultTypeID);\n";
- // Create an SSA result <id> for the op
- os << " auto resultID = getNextID();\n";
- os << " valueIDMap[op.getResult()] = resultID;\n";
- os << " operands.push_back(resultID);\n";
- } else if (op.getNumResults() != 0) {
- PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result");
- }
-
- // Process arguments
+/// Generates code to serialize the operands of a SPV_Op `op` into `os`. The
+/// generated querries the SSA-ID if operand is a SSA-Value, or serializes the
+/// attributes. The `operands` vector is updated appropriately. `elidedAttrs`
+/// updated as well to include the serialized attributes.
+static void emitOperandSerialization(const Operator &op, ArrayRef<SMLoc> loc,
+ StringRef tabs, StringRef opVar,
+ StringRef operands, StringRef elidedAttrs,
+ raw_ostream &os) {
auto operandNum = 0;
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
- os << " {\n";
+ os << tabs << "{\n";
if (argument.is<NamedTypeConstraint *>()) {
- os << " for (auto arg : op.getODSOperands(" << operandNum << ")) {\n";
- os << " auto argID = findValueID(arg);\n";
- os << " if (!argID) {\n";
- os << " emitError(op.getLoc(), \"operand " << operandNum
- << " has a use before def\");\n";
- os << " }\n";
- os << " operands.push_back(argID);\n";
+ os << tabs
+ << formatv(" for (auto arg : {0}.getODSOperands({1})) {{\n", opVar,
+ operandNum);
+ os << tabs << " auto argID = findValueID(arg);\n";
+ os << tabs << " if (!argID) {\n";
+ os << tabs
+ << formatv(
+ " emitError({0}.getLoc(), \"operand {1} has a use before "
+ "def\");\n",
+ opVar, operandNum);
+ os << tabs << " }\n";
+ os << tabs << formatv(" {0}.push_back(argID);\n", operands);
os << " }\n";
operandNum++;
} else {
auto attr = argument.get<NamedAttribute *>();
+ auto newtabs = tabs.str() + " ";
emitAttributeSerialization(
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
- record->getLoc(), "op", "operands", attr->name, os);
- os << " elidedAttrs.push_back(\"" << attr->name << "\");\n";
+ loc, newtabs, opVar, operands, attr->name, os);
+ os << newtabs
+ << formatv("{0}.push_back(\"{1}\");\n", elidedAttrs, attr->name);
}
- os << " }\n";
+ os << tabs << "}\n";
}
+}
- os << formatv(" encodeInstructionInto("
- "functions, spirv::getOpcode<{0}>(), operands);\n",
- op.getQualCppClassName());
+/// Generates code to serializes the result of SPV_Op `op` into `os`. The
+/// generated gets the ID for the type of the result (if any), the SSA-ID of
+/// the result and updates `resultID` with the SSA-ID.
+static void emitResultSerialization(const Operator &op, ArrayRef<SMLoc> loc,
+ StringRef tabs, StringRef opVar,
+ StringRef operands, StringRef resultID,
+ raw_ostream &os) {
+ if (op.getNumResults() == 1) {
+ StringRef resultTypeID("resultTypeID");
+ os << tabs << formatv("uint32_t {0} = 0;\n", resultTypeID);
+ os << tabs
+ << formatv(
+ "if (failed(processType({0}.getLoc(), {0}.getType(), {1}))) {{\n",
+ opVar, resultTypeID);
+ os << tabs << " return failure();\n";
+ os << tabs << "}\n";
+ os << tabs << formatv("{0}.push_back({1});\n", operands, resultTypeID);
+ // Create an SSA result <id> for the op
+ os << tabs << formatv("{0} = getNextID();\n", resultID);
+ os << tabs
+ << formatv("valueIDMap[{0}.getResult()] = {1};\n", opVar, resultID);
+ os << tabs << formatv("{0}.push_back({1});\n", operands, resultID);
+ } else if (op.getNumResults() != 0) {
+ PrintFatalError(loc, "SPIR-V ops can only have zero or one result");
+ }
+}
+/// Generates code to serialize attributes of SPV_Op `op` that become
+/// decorations on the `resultID` of the serialized operation `opVar` in the
+/// SPIR-V binary.
+static void emitDecorationSerialization(const Operator &op, StringRef tabs,
+ StringRef opVar, StringRef elidedAttrs,
+ StringRef resultID, raw_ostream &os) {
if (op.getNumResults() == 1) {
// All non-argument attributes translated into OpDecorate instruction
- os << " for (auto attr : op.getAttrs()) {\n";
- os << " if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return "
- "attr.first.is(elided); })) {\n";
- os << " continue;\n";
- os << " }\n";
- os << " if (failed(processDecoration(op.getLoc(), resultID, attr))) {\n";
- os << " return failure();";
- os << " }\n";
- os << " }\n";
+ os << tabs << formatv("for (auto attr : {0}.getAttrs()) {{\n", opVar);
+ os << tabs
+ << formatv(" if (llvm::any_of({0}, [&](StringRef elided)", elidedAttrs);
+ os << " {return attr.first.is(elided);})) {\n";
+ os << tabs << " continue;\n";
+ os << tabs << " }\n";
+ os << tabs
+ << formatv(
+ " if (failed(processDecoration({0}.getLoc(), {1}, attr))) {{\n",
+ opVar, resultID);
+ os << tabs << " return failure();\n";
+ os << tabs << " }\n";
+ os << tabs << "}\n";
+ }
+}
+
+/// Generates code to serialize an SPV_Op `op` into `os`.
+static void emitSerializationFunction(const Record *attrClass,
+ const Record *record, const Operator &op,
+ raw_ostream &os) {
+ // If the record has 'autogenSerialization' set to 0, nothing to do
+ if (!record->getValueAsBit("autogenSerialization")) {
+ return;
+ }
+ StringRef opVar("op"), operands("operands"), elidedAttrs("elidedAttrs"),
+ resultID("resultID");
+ os << formatv(
+ "template <> LogicalResult\nSerializer::processOp<{0}>({0} {1}) {{\n",
+ op.getQualCppClassName(), opVar);
+ os << formatv(" SmallVector<uint32_t, 4> {0};\n", operands);
+ os << formatv(" SmallVector<StringRef, 2> {0};\n", elidedAttrs);
+
+ // Serialize result information.
+ if (op.getNumResults() == 1) {
+ os << formatv(" uint32_t {0} = 0;\n", resultID);
+ emitResultSerialization(op, record->getLoc(), " ", opVar, operands,
+ resultID, os);
+ }
+
+ // Process arguments.
+ emitOperandSerialization(op, record->getLoc(), " ", opVar, operands,
+ elidedAttrs, os);
+
+ if (record->isSubClassOf("SPV_ExtInstOp")) {
+ os << formatv(" encodeExtensionInstruction({0}, \"{1}\", {2}, {3});\n",
+ opVar, record->getValueAsString("extendedInstSetName"),
+ record->getValueAsInt("extendedInstOpcode"), operands);
+ } else {
+ os << formatv(" encodeInstructionInto("
+ "functions, spirv::getOpcode<{0}>(), {1});\n",
+ op.getQualCppClassName(), operands);
}
+ // Process decorations.
+ emitDecorationSerialization(op, " ", opVar, elidedAttrs, resultID, os);
+
os << " return success();\n";
os << "}\n\n";
}
-static void initDispatchSerializationFn(raw_ostream &os) {
- os << "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
- "*op) {\n ";
+/// Generates the prologue for the function that dispatches the serialization of
+/// the operation `opVar` based on its opcode.
+static void initDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
+ os << formatv(
+ "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
+ "*{0}) {{\n ",
+ opVar);
}
-static void emitSerializationDispatch(const Operator &op, raw_ostream &os) {
- os << formatv(" if (isa<{0}>(op)) ", op.getQualCppClassName()) << "{\n";
- os << " ";
- os << formatv("return processOp<{0}>(cast<{0}>(op));\n",
- op.getQualCppClassName());
- os << " } else";
+/// Generates the body of the dispatch function. This function generates the
+/// check that if satisfied, will call the serialization function generated for
+/// the `op`.
+static void emitSerializationDispatch(const Operator &op, StringRef tabs,
+ StringRef opVar, raw_ostream &os) {
+ os << tabs
+ << formatv("if (isa<{0}>({1})) {{\n", op.getQualCppClassName(), opVar);
+ os << tabs
+ << formatv(" return processOp(cast<{0}>({1}));\n",
+ op.getQualCppClassName(), opVar);
+ os << tabs << "} else";
}
-static void finalizeDispatchSerializationFn(raw_ostream &os) {
+/// Generates the epilogue for the function that dispatches the serialization of
+/// the operation.
+static void finalizeDispatchSerializationFn(StringRef opVar, raw_ostream &os) {
os << " {\n";
- os << " return op->emitError(\"unhandled operation serialization\");\n";
+ os << formatv(
+ " return {0}->emitError(\"unhandled operation serialization\");\n",
+ opVar);
os << " }\n";
os << " return success();\n";
os << "}\n\n";
}
-static void emitAttributeDeserialization(
- const Attribute &attr, ArrayRef<SMLoc> loc, llvm::StringRef attrList,
- llvm::StringRef attrName, llvm::StringRef operandsList,
- llvm::StringRef wordIndex, llvm::StringRef wordCount, raw_ostream &os) {
+/// Generates code to deserialize the attribute of a SPV_Op into `os`. The
+/// generated code reads the `words` of the serialized instruction at
+/// position `wordIndex` and adds the deserialized attribute into `attrList`.
+static void emitAttributeDeserialization(const Attribute &attr,
+ ArrayRef<SMLoc> loc, StringRef tabs,
+ StringRef attrList, StringRef attrName,
+ StringRef words, StringRef wordIndex,
+ raw_ostream &os) {
if (attr.getAttrDefName() == "I32ArrayAttr") {
- os << " SmallVector<Attribute, 4> attrListElems;\n";
- os << " while (" << wordIndex << " < " << wordCount << ") {\n";
- os << " attrListElems.push_back(opBuilder.getI32IntegerAttr("
- << operandsList << "[" << wordIndex << "++]));\n";
- os << " }\n";
- os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\""
- << attrName << "\", opBuilder.getArrayAttr(attrListElems)));\n";
+ os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
+ os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words);
+ os << tabs
+ << formatv(
+ " "
+ "attrListElems.push_back(opBuilder.getI32IntegerAttr({0}[{1}++]))"
+ ";\n",
+ words, wordIndex);
+ os << tabs << "}\n";
+ os << tabs
+ << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
+ "opBuilder.getArrayAttr(attrListElems)));\n",
+ attrList, attrName);
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
- os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\""
- << attrName << "\", opBuilder.getI32IntegerAttr(" << operandsList << "["
- << wordIndex << "++])));\n";
+ os << tabs
+ << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
+ "opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
+ attrList, attrName, words, wordIndex);
} else {
PrintFatalError(
loc, llvm::Twine(
}
}
-static void emitDeserializationFunction(const Record *attrClass,
- const Record *record,
- const Operator &op, raw_ostream &os) {
- // If the record has 'autogenSerialization' set to 0, nothing to do
- if (!record->getValueAsBit("autogenSerialization")) {
- return;
- }
- os << formatv("template <> "
- "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
- "uint32_t> words)",
- op.getQualCppClassName());
- os << " {\n";
- os << " SmallVector<Type, 1> resultTypes;\n";
- os << " size_t wordIndex = 0; (void)wordIndex;\n";
-
+/// Generates the code to deserialize the result of an SPV_Op `op` into
+/// `os`. The generated code gets the type of the result specified at
+/// `words`[`wordIndex`], the SSA ID for the result at position `wordIndex` + 1
+/// and updates the `resultType` and `valueID` with the parsed type and SSA ID,
+/// respectively.
+static void emitResultDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
+ StringRef tabs, StringRef words,
+ StringRef wordIndex,
+ StringRef resultTypes, StringRef valueID,
+ raw_ostream &os) {
// Deserialize result information if it exists
- bool hasResult = false;
if (op.getNumResults() == 1) {
- os << " {\n";
- os << " if (wordIndex >= words.size()) {\n";
- os << " "
- << formatv("return emitError(unknownLoc, \"expected result type <id> "
- "while deserializing {0}\");\n",
- op.getQualCppClassName());
- os << " }\n";
- os << " auto ty = getType(words[wordIndex]);\n";
- os << " if (!ty) {\n";
- os << " return emitError(unknownLoc, \"unknown type result <id> : "
- "\") << words[wordIndex];\n";
- os << " }\n";
- os << " resultTypes.push_back(ty);\n";
- os << " wordIndex++;\n";
- os << " }\n";
- os << " if (wordIndex >= words.size()) {\n";
- os << " "
- << formatv("return emitError(unknownLoc, \"expected result <id> while "
- "deserializing {0}\");\n",
- op.getQualCppClassName());
- os << " }\n";
- os << " uint32_t valueID = words[wordIndex++];\n";
- hasResult = true;
+ os << tabs << "{\n";
+ os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words);
+ os << tabs
+ << formatv(
+ " return emitError(unknownLoc, \"expected result type <id> "
+ "while deserializing {0}\");\n",
+ op.getQualCppClassName());
+ os << tabs << " }\n";
+ os << tabs << formatv(" auto ty = getType({0}[{1}]);\n", words, wordIndex);
+ os << tabs << " if (!ty) {\n";
+ os << tabs
+ << formatv(
+ " return emitError(unknownLoc, \"unknown type result <id> : "
+ "\") << {0}[{1}];\n",
+ words, wordIndex);
+ os << tabs << " }\n";
+ os << tabs << formatv(" {0}.push_back(ty);\n", resultTypes);
+ os << tabs << formatv(" {0}++;\n", wordIndex);
+ os << tabs << formatv(" if ({0} >= {1}.size()) {{\n", wordIndex, words);
+ os << tabs
+ << formatv(
+ " return emitError(unknownLoc, \"expected result <id> while "
+ "deserializing {0}\");\n",
+ op.getQualCppClassName());
+ os << tabs << " }\n";
+ os << tabs << "}\n";
+ os << tabs << formatv("{0} = {1}[{2}++];\n", valueID, words, wordIndex);
} else if (op.getNumResults() != 0) {
- PrintFatalError(record->getLoc(),
- "SPIR-V ops can have only zero or one result");
+ PrintFatalError(loc, "SPIR-V ops can have only zero or one result");
}
+}
+/// Generates the code to deserialize the operands of an SPV_Op `op` into
+/// `os`. The generated code reads the `words` of the binary instruction, from
+/// position `wordIndex` to the end, and either gets the Value corresponding to
+/// the ID encoded, or deserializes the attributes encoded. The parsed
+/// operand(attribute) is added to the `operands` list or `attributes` list.
+static void emitOperandDeserialization(const Operator &op, ArrayRef<SMLoc> loc,
+ StringRef tabs, StringRef words,
+ StringRef wordIndex, StringRef operands,
+ StringRef attributes, raw_ostream &os) {
// Process operands/attributes
- os << " SmallVector<Value *, 4> operands;\n";
- os << " SmallVector<NamedAttribute, 4> attributes;\n";
unsigned operandNum = 0;
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
if (valueArg->isVariadic()) {
if (i != e - 1) {
- PrintFatalError(record->getLoc(),
+ PrintFatalError(loc,
"SPIR-V ops can have Variadic<..> argument only if "
"it's the last argument");
}
- os << " for (; wordIndex < words.size(); ++wordIndex)";
+ os << tabs
+ << formatv("for (; {0} < {1}.size(); ++{0})", wordIndex, words);
} else {
- os << " if (wordIndex < words.size())";
+ os << tabs << formatv("if ({0} < {1}.size())", wordIndex, words);
}
os << " {\n";
- os << " auto arg = getValue(words[wordIndex]);\n";
- os << " if (!arg) {\n";
- os << " return emitError(unknownLoc, \"unknown result <id> : \") << "
- "words[wordIndex];\n";
- os << " }\n";
- os << " operands.push_back(arg);\n";
+ os << tabs
+ << formatv(" auto arg = getValue({0}[{1}]);\n", words, wordIndex);
+ os << tabs << " if (!arg) {\n";
+ os << tabs
+ << formatv(
+ " return emitError(unknownLoc, \"unknown result <id> : \") "
+ "<< {0}[{1}];\n",
+ words, wordIndex);
+ os << tabs << " }\n";
+ os << tabs << formatv(" {0}.push_back(arg);\n", operands);
if (!valueArg->isVariadic()) {
- os << " wordIndex++;\n";
+ os << tabs << formatv(" {0}++;\n", wordIndex);
}
operandNum++;
- os << " }\n";
+ os << tabs << "}\n";
} else {
- os << " if (wordIndex < words.size()) {\n";
+ os << tabs << formatv("if ({0} < {1}.size()) {{\n", wordIndex, words);
auto attr = argument.get<NamedAttribute *>();
+ auto newtabs = tabs.str() + " ";
emitAttributeDeserialization(
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
- record->getLoc(), "attributes", attr->name, "words", "wordIndex",
- "words.size()", os);
+ loc, newtabs, attributes, attr->name, words, wordIndex, os);
os << " }\n";
}
}
- os << " if (wordIndex != words.size()) {\n";
- os << " return emitError(unknownLoc, \"found more operands than expected "
- "when deserializing "
- << op.getQualCppClassName()
- << ", only \") << wordIndex << \" of \" << words.size() << \" "
- "processed\";\n";
- os << " }\n\n";
+ os << tabs << formatv("if ({0} != {1}.size()) {{\n", wordIndex, words);
+ os << tabs
+ << formatv(
+ " return emitError(unknownLoc, \"found more operands than "
+ "expected when deserializing {0}, only \") << {1} << \" of \" << "
+ "{2}.size() << \" processed\";\n",
+ op.getQualCppClassName(), wordIndex, words);
+ os << tabs << "}\n\n";
+}
+/// Generates code to update the `attributes` vector with the attributes
+/// obtained from parsing the decorations in the SPIR-V binary associated with
+/// an <id> `valueID`
+static void emitDecorationDeserialization(const Operator &op, StringRef tabs,
+ StringRef valueID,
+ StringRef attributes,
+ raw_ostream &os) {
// Import decorations parsed
if (op.getNumResults() == 1) {
- os << " if (decorations.count(valueID)) {\n"
- << " auto attrs = decorations[valueID].getAttrs();\n"
- << " attributes.append(attrs.begin(), attrs.end());\n"
- << " }\n";
+ os << tabs << formatv("if (decorations.count({0})) {{\n", valueID);
+ os << tabs
+ << formatv(" auto attrs = decorations[{0}].getAttrs();\n", valueID);
+ os << tabs
+ << formatv(" {0}.append(attrs.begin(), attrs.end());\n", attributes);
+ os << tabs << "}\n";
}
+}
- os << formatv(" auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, "
- "operands, attributes); (void)op;\n",
- op.getQualCppClassName());
- if (hasResult) {
- os << " valueMap[valueID] = op.getResult();\n\n";
+/// Generates code to deserialize an SPV_Op `op` into `os`.
+static void emitDeserializationFunction(const Record *attrClass,
+ const Record *record,
+ const Operator &op, raw_ostream &os) {
+ // If the record has 'autogenSerialization' set to 0, nothing to do
+ if (!record->getValueAsBit("autogenSerialization")) {
+ return;
+ }
+ StringRef resultTypes("resultTypes"), valueID("valueID"), words("words"),
+ wordIndex("wordIndex"), opVar("op"), operands("operands"),
+ attributes("attributes");
+ os << formatv("template <> "
+ "LogicalResult\nDeserializer::processOp<{0}>(ArrayRef<"
+ "uint32_t> {1}) {{\n",
+ op.getQualCppClassName(), words);
+ os << formatv(" SmallVector<Type, 1> {0};\n", resultTypes);
+ os << formatv(" size_t {0} = 0; (void){0};\n", wordIndex);
+ os << formatv(" uint32_t {0} = 0; (void){0};\n", valueID);
+
+ // Deserialize result information
+ emitResultDeserialization(op, record->getLoc(), " ", words, wordIndex,
+ resultTypes, valueID, os);
+
+ os << formatv(" SmallVector<Value *, 4> {0};\n", operands);
+ os << formatv(" SmallVector<NamedAttribute, 4> {0};\n", attributes);
+ // Operand deserialization
+ emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex,
+ operands, attributes, os);
+
+ os << formatv(
+ " auto {1} = opBuilder.create<{0}>(unknownLoc, {2}, {3}, {4}); "
+ "(void){1};\n",
+ op.getQualCppClassName(), opVar, resultTypes, operands, attributes);
+ if (op.getNumResults() == 1) {
+ os << formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar);
}
+ // Decorations
+ emitDecorationDeserialization(op, " ", valueID, attributes, os);
os << " return success();\n";
os << "}\n\n";
}
-static void initDispatchDeserializationFn(raw_ostream &os) {
- os << "LogicalResult "
- "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode "
- "opcode, ArrayRef<uint32_t> words) {\n";
- os << " switch (opcode) {\n";
+/// Generates the prologue for the function that dispatches the deserialization
+/// based on the `opcode`.
+static void initDispatchDeserializationFn(StringRef opcode, StringRef words,
+ raw_ostream &os) {
+ os << formatv(
+ "LogicalResult "
+ "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode {0}, "
+ "ArrayRef<uint32_t> {1}) {{\n",
+ opcode, words);
+ os << formatv(" switch ({0}) {{\n", opcode);
}
+/// Generates the body of the dispatch function, by generating the case label
+/// for an opcode and the call to the method to perform the deserialization.
static void emitDeserializationDispatch(const Operator &op, const Record *def,
+ StringRef tabs, StringRef words,
raw_ostream &os) {
- os << formatv(" case spirv::Opcode::{0}:\n",
+ os << tabs
+ << formatv("case spirv::Opcode::{0}:\n",
def->getValueAsString("spirvOpName"));
- os << formatv(" return processOp<{0}>(words);\n",
- op.getQualCppClassName());
+ os << tabs
+ << formatv(" return processOp<{0}>({1});\n", op.getQualCppClassName(),
+ words);
}
-static void finalizeDispatchDeserializationFn(raw_ostream &os) {
+/// Generates the epilogue for the function that dispatches the deserialization
+/// of the operation.
+static void finalizeDispatchDeserializationFn(StringRef opcode,
+ raw_ostream &os) {
os << " default:\n";
os << " ;\n";
os << " }\n";
- os << " return emitError(unknownLoc, \"unhandled deserialization of \") << "
- "spirv::stringifyOpcode(opcode);\n";
+ os << formatv(
+ " return emitError(unknownLoc, \"unhandled deserialization of \") << "
+ "spirv::stringifyOpcode({0});\n",
+ opcode);
os << "}\n";
}
+static void initExtendedSetDeserializationDispatch(StringRef extensionSetName,
+ StringRef instructionID,
+ StringRef words,
+ raw_ostream &os) {
+ os << formatv("LogicalResult "
+ "Deserializer::dispatchToExtensionSetAutogenDeserialization("
+ "StringRef {0}, uint32_t {1}, ArrayRef<uint32_t> {2}) {{\n",
+ extensionSetName, instructionID, words);
+}
+
+static void
+emitExtendedSetDeserializationDispatch(const RecordKeeper &recordKeeper,
+ raw_ostream &os) {
+ StringRef extensionSetName("extensionSetName"),
+ instructionID("instructionID"), words("words");
+
+ // First iterate over all ops derived from SPV_ExtensionSetOps to get all
+ // extensionSets.
+
+ // For each of the extensions a separate raw_string_ostream is used to
+ // generate code into. These are then concatenated at the end. Since
+ // raw_string_ostream needs a string&, use a vector to store all the string
+ // that are captured by reference within raw_string_ostream.
+ StringMap<raw_string_ostream> extensionSets;
+ SmallVector<std::string, 1> extensionSetNames;
+
+ initExtendedSetDeserializationDispatch(extensionSetName, instructionID, words,
+ os);
+ auto defs = recordKeeper.getAllDerivedDefinitions("SPV_ExtInstOp");
+ for (const auto *def : defs) {
+ if (!def->getValueAsBit("autogenSerialization")) {
+ continue;
+ }
+ Operator op(def);
+ auto setName = def->getValueAsString("extendedInstSetName");
+ if (!extensionSets.count(setName)) {
+ extensionSetNames.push_back("");
+ extensionSets.try_emplace(setName, extensionSetNames.back());
+ auto &setos = extensionSets.find(setName)->second;
+ setos << formatv(" if ({0} == \"{1}\") {{\n", extensionSetName, setName);
+ setos << formatv(" switch ({0}) {{\n", instructionID);
+ }
+ auto &setos = extensionSets.find(setName)->second;
+ setos << formatv(" case {0}:\n",
+ def->getValueAsInt("extendedInstOpcode"));
+ setos << formatv(" return processOp<{0}>({1});\n",
+ op.getQualCppClassName(), words);
+ }
+
+ // Append the dispatch code for all the extended sets.
+ for (auto &extensionSet : extensionSets) {
+ os << extensionSet.second.str();
+ os << " default:\n";
+ os << formatv(
+ " return emitError(unknownLoc, \"unhandled deserializations of "
+ "\") << {0} << \" from extension set \" << {1};\n",
+ instructionID, extensionSetName);
+ os << " }\n";
+ os << " }\n";
+ }
+
+ os << formatv(" return emitError(unknownLoc, \"unhandled deserialization of "
+ "extended instruction set {0}\");\n",
+ extensionSetName);
+ os << "}\n";
+}
+
+/// Emits all the autogenerated serialization/deserializations functions for the
+/// SPV_Ops.
static bool emitSerializationFns(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os);
serFn(serFnString), deserFn(deserFnString), utils(utilsString);
auto attrClass = recordKeeper.getClass("Attr");
+ // Emit the serialization and deserialization functions simulataneously.
declareOpcodeFn(utils);
- initDispatchSerializationFn(dSerFn);
- initDispatchDeserializationFn(dDesFn);
+ StringRef opVar("op");
+ StringRef opcode("opcode"), words("words");
+
+ // Handle the SPIR-V ops.
+ initDispatchSerializationFn(opVar, dSerFn);
+ initDispatchDeserializationFn(opcode, words, dDesFn);
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
for (const auto *def : defs) {
- if (!def->getValueAsBit("hasOpcode")) {
- continue;
- }
Operator op(def);
- emitGetOpcodeFunction(def, op, utils);
emitSerializationFunction(attrClass, def, op, serFn);
- emitSerializationDispatch(op, dSerFn);
emitDeserializationFunction(attrClass, def, op, deserFn);
- emitDeserializationDispatch(op, def, dDesFn);
+ if (def->getValueAsBit("hasOpcode") || def->isSubClassOf("SPV_ExtInstOp")) {
+ emitSerializationDispatch(op, " ", opVar, dSerFn);
+ }
+ if (def->getValueAsBit("hasOpcode")) {
+ emitGetOpcodeFunction(def, op, utils);
+ emitDeserializationDispatch(op, def, " ", words, dDesFn);
+ }
}
- finalizeDispatchSerializationFn(dSerFn);
- finalizeDispatchDeserializationFn(dDesFn);
+ finalizeDispatchSerializationFn(opVar, dSerFn);
+ finalizeDispatchDeserializationFn(opcode, dDesFn);
+
+ emitExtendedSetDeserializationDispatch(recordKeeper, dDesFn);
os << "#ifdef GET_SPIRV_SERIALIZATION_UTILS\n";
os << utils.str();
static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
raw_ostream &os) {
auto enumName = enumAttr.getEnumClassName();
- os << formatv("template <> inline StringRef attributeName<{0}>()", enumName)
- << " {\n";
+ os << formatv("template <> inline StringRef attributeName<{0}>() {{\n",
+ enumName);
os << " "
<< formatv("static constexpr const char attrName[] = \"{0}\";\n",
mlir::convertToSnakeCase(enumName));
raw_ostream &os) {
auto enumName = enumAttr.getEnumClassName();
auto strToSymFnName = enumAttr.getStringToSymbolFnName();
- os << formatv("template <> inline SymbolizeFnTy<{0}> symbolizeEnum<{0}>()",
- enumName)
- << " {\n";
+ os << formatv(
+ "template <> inline SymbolizeFnTy<{0}> symbolizeEnum<{0}>() {{\n",
+ enumName);
os << " return " << strToSymFnName << ";\n";
os << "}\n";
}