From a0557ea9d6543a1be8451a59bd697cf01523607f Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Mon, 16 Dec 2019 14:21:44 -0800 Subject: [PATCH] Fix (de)serialization generation for SPV_ScopeAttr, SPV_MemorySemanticsAttr. Scope and Memory Semantics attributes need to be serialized as a constant integer value and the needs to be used to specify the value. Fix the auto-generated SPIR-V (de)serialization to handle this. PiperOrigin-RevId: 285849431 --- mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 422183e..f1712ef 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -85,7 +85,13 @@ static void emitAttributeSerialization(const Attribute &attr, StringRef attrName, raw_ostream &os) { os << tabs << formatv("auto attr = {0}.getAttr(\"{1}\");\n", opVar, attrName); os << tabs << "if (attr) {\n"; - if (attr.getAttrDefName() == "I32ArrayAttr") { + if (attr.getAttrDefName() == "SPV_ScopeAttr" || + attr.getAttrDefName() == "SPV_MemorySemanticsAttr") { + os << tabs + << formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), " + "attr.cast()));\n", + operandList, opVar); + } else if (attr.getAttrDefName() == "I32ArrayAttr") { // Serialize all the elements of the array os << tabs << " for (auto attrElem : attr.cast()) {\n"; os << tabs @@ -284,7 +290,13 @@ static void emitAttributeDeserialization(const Attribute &attr, StringRef attrList, StringRef attrName, StringRef words, StringRef wordIndex, raw_ostream &os) { - if (attr.getAttrDefName() == "I32ArrayAttr") { + if (attr.getAttrDefName() == "SPV_ScopeAttr" || + attr.getAttrDefName() == "SPV_MemorySemanticsAttr") { + os << tabs + << formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", " + "getConstantInt({2}[{3}++])));\n", + attrList, attrName, words, wordIndex); + } else if (attr.getAttrDefName() == "I32ArrayAttr") { os << tabs << "SmallVector attrListElems;\n"; os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words); os << tabs -- 2.7.4