From 673bb7cbbecc0250067d174bfa636617a040c12e Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Mon, 29 Jul 2019 10:45:17 -0700 Subject: [PATCH] Enable (de)serialization support for spirv::AccessChainOp Automatic generation of spirv::AccessChainOp (de)serialization needs the (de)serialization emitters to handle argument specified as Variadic<...>. To handle this correctly, this argument can only be the last entry in the arguments list. Add a test to (de)serialize spirv::AccessChainOp PiperOrigin-RevId: 260532598 --- mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td | 2 - .../SPIRV/Serialization/access_chain.mlir | 15 +++++++ mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp | 41 ++++++++++++++----- 3 files changed, 45 insertions(+), 13 deletions(-) create mode 100644 mlir/test/Dialect/SPIRV/Serialization/access_chain.mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index e2facd3df166..022df0fd9d15 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -95,8 +95,6 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> { let results = (outs SPV_AnyPtr:$component_ptr ); - - let autogenSerialization = 0; } // ----- diff --git a/mlir/test/Dialect/SPIRV/Serialization/access_chain.mlir b/mlir/test/Dialect/SPIRV/Serialization/access_chain.mlir new file mode 100644 index 000000000000..3602b62524db --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/access_chain.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s + +func @foo() { + spv.module "Logical" "VulkanKHR" { + func @access_chain(%arg0 : !spv.ptr>, Function>, + %arg1 : i32, %arg2 : i32) { + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr>, Function> + // CHECK-NEXT: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr>, Function> + %1 = spv.AccessChain %arg0[%arg1] : !spv.ptr>, Function> + %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr>, Function> + spv.Return + } + } + return +} \ No newline at end of file diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 4595f25087a8..80b5499db86e 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -126,15 +126,13 @@ static void emitSerializationFunction(const Record *record, const Operator &op, auto argument = op.getArg(i); os << " {\n"; if (argument.is()) { - os << " if (" << operandNum - << " < op.getOperation()->getNumOperands()) {\n"; - os << " auto arg = findValueID(op.getOperation()->getOperand(" - << operandNum << "));\n"; - os << " if (!arg) {\n"; + os << " for (auto arg : op.getODSOperands(" << i << ")) {\n"; + os << " auto argID = findValueID(arg);\n"; + os << " if (!argID) {\n"; os << " emitError(op.getLoc(), \"operand " << operandNum << " has a use before def\");\n"; os << " }\n"; - os << " operands.push_back(arg);\n"; + os << " operands.push_back(argID);\n"; os << " }\n"; operandNum++; } else { @@ -243,32 +241,53 @@ static void emitDeserializationFunction(const Record *record, "SPIR-V ops can have only zero or one result"); } - // Process arguments/attributes + // Process operands/attributes os << " SmallVector operands;\n"; os << " SmallVector attributes;\n"; unsigned operandNum = 0; for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { auto argument = op.getArg(i); - os << " if (wordIndex < words.size()) {\n"; - if (argument.is()) { + if (auto valueArg = argument.dyn_cast()) { + if (valueArg->isVariadic()) { + if (i != e - 1) { + PrintFatalError(record->getLoc(), + "SPIR-V ops can have Variadic<..> argument only if " + "it's the last argument"); + } + os << " for (; wordIndex < words.size(); ++wordIndex)"; + } else { + os << " if (wordIndex < words.size())"; + } + os << " {\n"; os << " auto arg = getValue(words[wordIndex]);\n"; os << " if (!arg) {\n"; os << " return emitError(unknownLoc, \"unknown result : \") << " "words[wordIndex];\n"; os << " }\n"; os << " operands.push_back(arg);\n"; - os << " wordIndex++;\n"; + if (!valueArg->isVariadic()) { + os << " wordIndex++;\n"; + } operandNum++; + os << " }\n"; } else { + os << " if (wordIndex < words.size()) {\n"; auto attr = argument.get(); emitAttributeDeserialization( (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), record->getLoc(), "attributes", attr->name, "words", "wordIndex", "words.size()", os); + os << " }\n"; } - os << " }\n"; } + os << " if (wordIndex != words.size()) {\n"; + os << " return emitError(unknownLoc, \"found more operands than expected " + "when deserializing " + << op.getQualCppClassName() + << ", only \") << wordIndex << \" of \" << words.size() << \" " + "processed\";\n"; + os << " }\n"; os << formatv(" auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, " "operands, attributes); (void)op;\n", op.getQualCppClassName()); -- 2.34.1