Enable (de)serialization support for spirv::AccessChainOp
authorMahesh Ravishankar <ravishankarm@google.com>
Mon, 29 Jul 2019 17:45:17 +0000 (10:45 -0700)
committerjpienaar <jpienaar@google.com>
Tue, 30 Jul 2019 13:17:19 +0000 (06:17 -0700)
Automatic generation of spirv::AccessChainOp (de)serialization needs
the (de)serialization emitters to handle argument specified as
Variadic<...>. To handle this correctly, this argument can only be
the last entry in the arguments list.
Add a test to (de)serialize spirv::AccessChainOp

PiperOrigin-RevId: 260532598

mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td
mlir/test/Dialect/SPIRV/Serialization/access_chain.mlir [new file with mode: 0644]
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp

index e2facd3..022df0f 100644 (file)
@@ -95,8 +95,6 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
   let results = (outs
     SPV_AnyPtr:$component_ptr
   );
-
-  let autogenSerialization = 0;
 }
 
 // -----
diff --git a/mlir/test/Dialect/SPIRV/Serialization/access_chain.mlir b/mlir/test/Dialect/SPIRV/Serialization/access_chain.mlir
new file mode 100644 (file)
index 0000000..3602b62
--- /dev/null
@@ -0,0 +1,15 @@
+// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
+
+func @foo() {
+  spv.module "Logical" "VulkanKHR" {
+    func @access_chain(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>,
+                       %arg1 : i32, %arg2 : i32) {
+      // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
+      // CHECK-NEXT: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
+      %1 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+      %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
+      spv.Return
+    }
+  }
+  return
+}
\ No newline at end of file
index 4595f25..80b5499 100644 (file)
@@ -126,15 +126,13 @@ static void emitSerializationFunction(const Record *record, const Operator &op,
     auto argument = op.getArg(i);
     os << "  {\n";
     if (argument.is<NamedTypeConstraint *>()) {
-      os << "    if (" << operandNum
-         << " < op.getOperation()->getNumOperands()) {\n";
-      os << "      auto arg = findValueID(op.getOperation()->getOperand("
-         << operandNum << "));\n";
-      os << "      if (!arg) {\n";
+      os << "    for (auto arg : op.getODSOperands(" << i << ")) {\n";
+      os << "      auto argID = findValueID(arg);\n";
+      os << "      if (!argID) {\n";
       os << "        emitError(op.getLoc(), \"operand " << operandNum
          << " has a use before def\");\n";
       os << "      }\n";
-      os << "      operands.push_back(arg);\n";
+      os << "      operands.push_back(argID);\n";
       os << "    }\n";
       operandNum++;
     } else {
@@ -243,32 +241,53 @@ static void emitDeserializationFunction(const Record *record,
                     "SPIR-V ops can have only zero or one result");
   }
 
-  // Process arguments/attributes
+  // Process operands/attributes
   os << "  SmallVector<Value *, 4> operands;\n";
   os << "  SmallVector<NamedAttribute, 4> attributes;\n";
   unsigned operandNum = 0;
   for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
     auto argument = op.getArg(i);
-    os << "  if (wordIndex < words.size()) {\n";
-    if (argument.is<NamedTypeConstraint *>()) {
+    if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
+      if (valueArg->isVariadic()) {
+        if (i != e - 1) {
+          PrintFatalError(record->getLoc(),
+                          "SPIR-V ops can have Variadic<..> argument only if "
+                          "it's the last argument");
+        }
+        os << "  for (; wordIndex < words.size(); ++wordIndex)";
+      } else {
+        os << "  if (wordIndex < words.size())";
+      }
+      os << " {\n";
       os << "    auto arg = getValue(words[wordIndex]);\n";
       os << "    if (!arg) {\n";
       os << "      return emitError(unknownLoc, \"unknown result <id> : \") << "
             "words[wordIndex];\n";
       os << "    }\n";
       os << "    operands.push_back(arg);\n";
-      os << "    wordIndex++;\n";
+      if (!valueArg->isVariadic()) {
+        os << "    wordIndex++;\n";
+      }
       operandNum++;
+      os << "  }\n";
     } else {
+      os << "  if (wordIndex < words.size()) {\n";
       auto attr = argument.get<NamedAttribute *>();
       emitAttributeDeserialization(
           (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
           record->getLoc(), "attributes", attr->name, "words", "wordIndex",
           "words.size()", os);
+      os << "  }\n";
     }
-    os << "  }\n";
   }
 
+  os << "  if (wordIndex != words.size()) {\n";
+  os << "    return emitError(unknownLoc, \"found more operands than expected "
+        "when deserializing "
+     << op.getQualCppClassName()
+     << ", only \") << wordIndex << \" of \" << words.size() << \" "
+        "processed\";\n";
+  os << "  }\n";
   os << formatv("  auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, "
                 "operands, attributes); (void)op;\n",
                 op.getQualCppClassName());