Add support for (de)serialization of SPIR-V Op Decorations
authorMahesh Ravishankar <ravishankarm@google.com>
Tue, 30 Jul 2019 21:14:28 +0000 (14:14 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 30 Jul 2019 21:15:03 +0000 (14:15 -0700)
All non-argument attributes specified for an operation are treated as
decorations on the result value and (de)serialized using OpDecorate
instruction. An error is generated if an attribute is not an argument,
and the name doesn't correspond to a Decoration enum. Name of the
attributes that represent decoerations are to be the snake-case-ified
version of the Decoration name.
Add utility methods to convert to snake-case and camel-case.

PiperOrigin-RevId: 260792638

mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
mlir/include/mlir/Support/StringExtras.h [new file with mode: 0644]
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
mlir/test/Dialect/SPIRV/Serialization/variables.mlir
mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
mlir/unittests/IR/CMakeLists.txt
mlir/unittests/IR/StringExtrasTest.cpp [new file with mode: 0644]
mlir/utils/spirv/gen_spirv_dialect.py

index a12f339..a820c11 100644 (file)
@@ -196,6 +196,97 @@ def SPV_AddressingModelAttr :
   let cppNamespace = "::mlir::spirv";
 }
 
+def SPV_D_RelaxedPrecision            : I32EnumAttrCase<"RelaxedPrecision", 0>;
+def SPV_D_SpecId                      : I32EnumAttrCase<"SpecId", 1>;
+def SPV_D_Block                       : I32EnumAttrCase<"Block", 2>;
+def SPV_D_BufferBlock                 : I32EnumAttrCase<"BufferBlock", 3>;
+def SPV_D_RowMajor                    : I32EnumAttrCase<"RowMajor", 4>;
+def SPV_D_ColMajor                    : I32EnumAttrCase<"ColMajor", 5>;
+def SPV_D_ArrayStride                 : I32EnumAttrCase<"ArrayStride", 6>;
+def SPV_D_MatrixStride                : I32EnumAttrCase<"MatrixStride", 7>;
+def SPV_D_GLSLShared                  : I32EnumAttrCase<"GLSLShared", 8>;
+def SPV_D_GLSLPacked                  : I32EnumAttrCase<"GLSLPacked", 9>;
+def SPV_D_CPacked                     : I32EnumAttrCase<"CPacked", 10>;
+def SPV_D_BuiltIn                     : I32EnumAttrCase<"BuiltIn", 11>;
+def SPV_D_NoPerspective               : I32EnumAttrCase<"NoPerspective", 13>;
+def SPV_D_Flat                        : I32EnumAttrCase<"Flat", 14>;
+def SPV_D_Patch                       : I32EnumAttrCase<"Patch", 15>;
+def SPV_D_Centroid                    : I32EnumAttrCase<"Centroid", 16>;
+def SPV_D_Sample                      : I32EnumAttrCase<"Sample", 17>;
+def SPV_D_Invariant                   : I32EnumAttrCase<"Invariant", 18>;
+def SPV_D_Restrict                    : I32EnumAttrCase<"Restrict", 19>;
+def SPV_D_Aliased                     : I32EnumAttrCase<"Aliased", 20>;
+def SPV_D_Volatile                    : I32EnumAttrCase<"Volatile", 21>;
+def SPV_D_Constant                    : I32EnumAttrCase<"Constant", 22>;
+def SPV_D_Coherent                    : I32EnumAttrCase<"Coherent", 23>;
+def SPV_D_NonWritable                 : I32EnumAttrCase<"NonWritable", 24>;
+def SPV_D_NonReadable                 : I32EnumAttrCase<"NonReadable", 25>;
+def SPV_D_Uniform                     : I32EnumAttrCase<"Uniform", 26>;
+def SPV_D_UniformId                   : I32EnumAttrCase<"UniformId", 27>;
+def SPV_D_SaturatedConversion         : I32EnumAttrCase<"SaturatedConversion", 28>;
+def SPV_D_Stream                      : I32EnumAttrCase<"Stream", 29>;
+def SPV_D_Location                    : I32EnumAttrCase<"Location", 30>;
+def SPV_D_Component                   : I32EnumAttrCase<"Component", 31>;
+def SPV_D_Index                       : I32EnumAttrCase<"Index", 32>;
+def SPV_D_Binding                     : I32EnumAttrCase<"Binding", 33>;
+def SPV_D_DescriptorSet               : I32EnumAttrCase<"DescriptorSet", 34>;
+def SPV_D_Offset                      : I32EnumAttrCase<"Offset", 35>;
+def SPV_D_XfbBuffer                   : I32EnumAttrCase<"XfbBuffer", 36>;
+def SPV_D_XfbStride                   : I32EnumAttrCase<"XfbStride", 37>;
+def SPV_D_FuncParamAttr               : I32EnumAttrCase<"FuncParamAttr", 38>;
+def SPV_D_FPRoundingMode              : I32EnumAttrCase<"FPRoundingMode", 39>;
+def SPV_D_FPFastMathMode              : I32EnumAttrCase<"FPFastMathMode", 40>;
+def SPV_D_LinkageAttributes           : I32EnumAttrCase<"LinkageAttributes", 41>;
+def SPV_D_NoContraction               : I32EnumAttrCase<"NoContraction", 42>;
+def SPV_D_InputAttachmentIndex        : I32EnumAttrCase<"InputAttachmentIndex", 43>;
+def SPV_D_Alignment                   : I32EnumAttrCase<"Alignment", 44>;
+def SPV_D_MaxByteOffset               : I32EnumAttrCase<"MaxByteOffset", 45>;
+def SPV_D_AlignmentId                 : I32EnumAttrCase<"AlignmentId", 46>;
+def SPV_D_MaxByteOffsetId             : I32EnumAttrCase<"MaxByteOffsetId", 47>;
+def SPV_D_NoSignedWrap                : I32EnumAttrCase<"NoSignedWrap", 4469>;
+def SPV_D_NoUnsignedWrap              : I32EnumAttrCase<"NoUnsignedWrap", 4470>;
+def SPV_D_ExplicitInterpAMD           : I32EnumAttrCase<"ExplicitInterpAMD", 4999>;
+def SPV_D_OverrideCoverageNV          : I32EnumAttrCase<"OverrideCoverageNV", 5248>;
+def SPV_D_PassthroughNV               : I32EnumAttrCase<"PassthroughNV", 5250>;
+def SPV_D_ViewportRelativeNV          : I32EnumAttrCase<"ViewportRelativeNV", 5252>;
+def SPV_D_SecondaryViewportRelativeNV : I32EnumAttrCase<"SecondaryViewportRelativeNV", 5256>;
+def SPV_D_PerPrimitiveNV              : I32EnumAttrCase<"PerPrimitiveNV", 5271>;
+def SPV_D_PerViewNV                   : I32EnumAttrCase<"PerViewNV", 5272>;
+def SPV_D_PerTaskNV                   : I32EnumAttrCase<"PerTaskNV", 5273>;
+def SPV_D_PerVertexNV                 : I32EnumAttrCase<"PerVertexNV", 5285>;
+def SPV_D_NonUniformEXT               : I32EnumAttrCase<"NonUniformEXT", 5300>;
+def SPV_D_RestrictPointerEXT          : I32EnumAttrCase<"RestrictPointerEXT", 5355>;
+def SPV_D_AliasedPointerEXT           : I32EnumAttrCase<"AliasedPointerEXT", 5356>;
+def SPV_D_CounterBuffer               : I32EnumAttrCase<"CounterBuffer", 5634>;
+def SPV_D_UserSemantic                : I32EnumAttrCase<"UserSemantic", 5635>;
+def SPV_D_UserTypeGOOGLE              : I32EnumAttrCase<"UserTypeGOOGLE", 5636>;
+
+def SPV_DecorationAttr :
+    I32EnumAttr<"Decoration", "valid SPIR-V Decoration", [
+      SPV_D_RelaxedPrecision, SPV_D_SpecId, SPV_D_Block, SPV_D_BufferBlock,
+      SPV_D_RowMajor, SPV_D_ColMajor, SPV_D_ArrayStride, SPV_D_MatrixStride,
+      SPV_D_GLSLShared, SPV_D_GLSLPacked, SPV_D_CPacked, SPV_D_BuiltIn,
+      SPV_D_NoPerspective, SPV_D_Flat, SPV_D_Patch, SPV_D_Centroid, SPV_D_Sample,
+      SPV_D_Invariant, SPV_D_Restrict, SPV_D_Aliased, SPV_D_Volatile, SPV_D_Constant,
+      SPV_D_Coherent, SPV_D_NonWritable, SPV_D_NonReadable, SPV_D_Uniform,
+      SPV_D_UniformId, SPV_D_SaturatedConversion, SPV_D_Stream, SPV_D_Location,
+      SPV_D_Component, SPV_D_Index, SPV_D_Binding, SPV_D_DescriptorSet, SPV_D_Offset,
+      SPV_D_XfbBuffer, SPV_D_XfbStride, SPV_D_FuncParamAttr, SPV_D_FPRoundingMode,
+      SPV_D_FPFastMathMode, SPV_D_LinkageAttributes, SPV_D_NoContraction,
+      SPV_D_InputAttachmentIndex, SPV_D_Alignment, SPV_D_MaxByteOffset,
+      SPV_D_AlignmentId, SPV_D_MaxByteOffsetId, SPV_D_NoSignedWrap,
+      SPV_D_NoUnsignedWrap, SPV_D_ExplicitInterpAMD, SPV_D_OverrideCoverageNV,
+      SPV_D_PassthroughNV, SPV_D_ViewportRelativeNV,
+      SPV_D_SecondaryViewportRelativeNV, SPV_D_PerPrimitiveNV, SPV_D_PerViewNV,
+      SPV_D_PerTaskNV, SPV_D_PerVertexNV, SPV_D_NonUniformEXT,
+      SPV_D_RestrictPointerEXT, SPV_D_AliasedPointerEXT, SPV_D_CounterBuffer,
+      SPV_D_UserSemantic, SPV_D_UserTypeGOOGLE
+    ]> {
+  let returnType = "::mlir::spirv::Decoration";
+  let convertFromStorage = "static_cast<::mlir::spirv::Decoration>($_self.getInt())";
+  let cppNamespace = "::mlir::spirv";
+}
+
 def SPV_D_1D          : I32EnumAttrCase<"1D", 0>;
 def SPV_D_2D          : I32EnumAttrCase<"2D", 1>;
 def SPV_D_3D          : I32EnumAttrCase<"3D", 2>;
diff --git a/mlir/include/mlir/Support/StringExtras.h b/mlir/include/mlir/Support/StringExtras.h
new file mode 100644 (file)
index 0000000..a5ec732
--- /dev/null
@@ -0,0 +1,81 @@
+//===- StringExtras.h - String utilities used by MLIR -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains string utility functions used within MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_STRINGEXTRAS_H
+#define MLIR_SUPPORT_STRINGEXTRAS_H
+
+#include "llvm/ADT/StringExtras.h"
+
+namespace mlir {
+/// Converts a string to snake-case from camel-case by replacing all uppercase
+/// letters with '_' followed by the letter in lowercase, except if the
+/// uppercase letter is the first character of the string.
+inline std::string convertToSnakeCase(llvm::StringRef input) {
+  std::string snakeCase;
+  snakeCase.reserve(input.size());
+  for (auto c : input) {
+    if (std::isupper(c)) {
+      if (!snakeCase.empty() && snakeCase.back() != '_') {
+        snakeCase.push_back('_');
+      }
+      snakeCase.push_back(llvm::toLower(c));
+    } else {
+      snakeCase.push_back(c);
+    }
+  }
+  return snakeCase;
+}
+
+/// Converts a string from camel-case to snake_case by replacing all occurences
+/// of '_' followed by a lowercase letter with the letter in
+/// uppercase. Optionally allow capitalization of the first letter (if it is a
+/// lowercase letter)
+inline std::string convertToCamelCase(llvm::StringRef input,
+                                      bool capitalizeFirst = false) {
+  if (input.empty()) {
+    return "";
+  }
+  std::string output;
+  output.reserve(input.size());
+  size_t pos = 0;
+  if (capitalizeFirst && std::islower(input[pos])) {
+    output.push_back(llvm::toUpper(input[pos]));
+    pos++;
+  }
+  while (pos < input.size()) {
+    auto cur = input[pos];
+    if (cur == '_') {
+      if (pos && (pos + 1 < input.size())) {
+        if (std::islower(input[pos + 1])) {
+          output.push_back(llvm::toUpper(input[pos + 1]));
+          pos += 2;
+          continue;
+        }
+      }
+    }
+    output.push_back(cur);
+    pos++;
+  }
+  return output;
+}
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_STRINGEXTRAS_H
index ae5752a..05a1746 100644 (file)
 #include "mlir/IR/Function.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/StringExtras.h"
 
 using namespace mlir;
 
 // TODO(antiagainst): generate these strings using ODS.
 static constexpr const char kAlignmentAttrName[] = "alignment";
-static constexpr const char kBindingAttrName[] = "binding";
-static constexpr const char kDescriptorSetAttrName[] = "descriptor_set";
 static constexpr const char kIndicesAttrName[] = "indices";
 static constexpr const char kValueAttrName[] = "value";
 static constexpr const char kValuesAttrName[] = "values";
@@ -67,8 +66,7 @@ static LogicalResult extractValueFromConstOp(Operation *op,
 }
 
 template <typename EnumClass>
-static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
-                                      OperationState *state) {
+static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser) {
   Attribute attrVal;
   SmallVector<NamedAttribute, 1> attr;
   auto loc = parser->getCurrentLocation();
@@ -89,6 +87,15 @@ static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
            << " attribute specification: " << attrVal;
   }
   value = attrOptional.getValue();
+  return success();
+}
+
+template <typename EnumClass>
+static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
+                                      OperationState *state) {
+  if (parseEnumAttribute(value, parser)) {
+    return failure();
+  }
   state->addAttribute(
       spirv::attributeName<EnumClass>(),
       parser->getBuilder().getI32IntegerAttr(bitwiseCast<int32_t>(value)));
@@ -601,7 +608,7 @@ static ParseResult parseLoadOp(OpAsmParser *parser, OperationState *state) {
   spirv::StorageClass storageClass;
   OpAsmParser::OperandType ptrInfo;
   Type elementType;
-  if (parseEnumAttribute(storageClass, parser, state) ||
+  if (parseEnumAttribute(storageClass, parser) ||
       parser->parseOperand(ptrInfo) ||
       parseMemoryAccessAttributes(parser, state) ||
       parser->parseOptionalAttributeDict(state->attributes) ||
@@ -813,7 +820,7 @@ static ParseResult parseStoreOp(OpAsmParser *parser, OperationState *state) {
   SmallVector<OpAsmParser::OperandType, 2> operandInfo;
   auto loc = parser->getCurrentLocation();
   Type elementType;
-  if (parseEnumAttribute(storageClass, parser, state) ||
+  if (parseEnumAttribute(storageClass, parser) ||
       parser->parseOperandList(operandInfo, 2) ||
       parseMemoryAccessAttributes(parser, state) || parser->parseColon() ||
       parser->parseType(elementType)) {
@@ -873,13 +880,17 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) {
 
   // Parse optional descriptor binding
   Attribute set, binding;
+  auto descriptorSetName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
+  auto bindingName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
   if (succeeded(parser->parseOptionalKeyword("bind"))) {
     Type i32Type = parser->getBuilder().getIntegerType(32);
     if (parser->parseLParen() ||
-        parser->parseAttribute(set, i32Type, kDescriptorSetAttrName,
+        parser->parseAttribute(set, i32Type, descriptorSetName,
                                state->attributes) ||
         parser->parseComma() ||
-        parser->parseAttribute(binding, i32Type, kBindingAttrName,
+        parser->parseAttribute(binding, i32Type, bindingName,
                                state->attributes) ||
         parser->parseRParen())
       return failure();
@@ -931,12 +942,17 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) {
   }
 
   // Print optional descriptor binding
-  auto set = varOp.getAttrOfType<IntegerAttr>(kDescriptorSetAttrName);
-  auto binding = varOp.getAttrOfType<IntegerAttr>(kBindingAttrName);
-  if (set && binding) {
-    elidedAttrs.push_back(kDescriptorSetAttrName);
-    elidedAttrs.push_back(kBindingAttrName);
-    *printer << " bind(" << set.getInt() << ", " << binding.getInt() << ")";
+  auto descriptorSetName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
+  auto bindingName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+  auto descriptorSet = varOp.getAttrOfType<IntegerAttr>(descriptorSetName);
+  auto binding = varOp.getAttrOfType<IntegerAttr>(bindingName);
+  if (descriptorSet && binding) {
+    elidedAttrs.push_back(descriptorSetName);
+    elidedAttrs.push_back(bindingName);
+    *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
+             << ")";
   }
 
   printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
index 4c35f6a..2ca8f45 100644 (file)
@@ -27,6 +27,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Location.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/StringExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/bit.h"
 
@@ -80,6 +81,9 @@ private:
   /// Process SPIR-V OpName with `operands`
   LogicalResult processName(ArrayRef<uint32_t> operands);
 
+  /// Method to process an OpDecorate instruction.
+  LogicalResult processDecoration(ArrayRef<uint32_t> words);
+
   /// Processes the SPIR-V function at the current `offset` into `binary`.
   /// The operands to the OpFunction instruction is passed in as ``operands`.
   /// This method processes each instruction inside the function and dispatches
@@ -196,6 +200,9 @@ private:
   // Result <id> to name mapping.
   DenseMap<uint32_t, StringRef> nameMap;
 
+  // Result <id> to decorations mapping.
+  DenseMap<uint32_t, NamedAttributeList> decorations;
+
   // List of instructions that are processed in a defered fashion (after an
   // initial processing of the entire binary). Some operations like
   // OpEntryPoint, and OpExecutionMode use forward references to function
@@ -285,6 +292,37 @@ LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
+  // TODO : This function should also be auto-generated. For now, since only a
+  // few decorations are processed/handled in a meaningful manner, going with a
+  // manual implementation.
+  if (words.size() < 2) {
+    return emitError(
+        unknownLoc, "OpDecorate must have at least result <id> and Decoration");
+  }
+  auto decorationName =
+      stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
+  if (decorationName.empty()) {
+    return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
+  }
+  auto attrName = convertToSnakeCase(decorationName);
+  switch (static_cast<spirv::Decoration>(words[1])) {
+  case spirv::Decoration::DescriptorSet:
+  case spirv::Decoration::Binding:
+    if (words.size() != 3) {
+      return emitError(unknownLoc, "OpDecorate with ")
+             << decorationName << " needs a single integer literal";
+    }
+    decorations[words[0]].set(
+        opBuilder.getIdentifier(attrName),
+        opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
+    break;
+  default:
+    return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
+  }
+  return success();
+}
+
 LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
   // Get the result type
   if (operands.size() != 4) {
@@ -830,6 +868,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
     return processConstantBool(false, operands);
   case spirv::Opcode::OpConstantNull:
     return processConstantNull(operands);
+  case spirv::Opcode::OpDecorate:
+    return processDecoration(operands);
   case spirv::Opcode::OpFunction:
     return processFunction(operands);
   default:
@@ -839,6 +879,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
 }
 
 namespace {
+
 template <>
 LogicalResult
 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
index 7030bd9..35c4088 100644 (file)
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/StringExtras.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/bit.h"
 #include "llvm/Support/raw_ostream.h"
@@ -127,6 +128,10 @@ private:
   /// Processes a SPIR-V function op.
   LogicalResult processFuncOp(FuncOp op);
 
+  /// Process attributes that translate to decorations on the result <id>
+  LogicalResult processDecoration(Location loc, uint32_t resultID,
+                                  NamedAttribute attr);
+
   //===--------------------------------------------------------------------===//
   // Types
   //===--------------------------------------------------------------------===//
@@ -319,6 +324,34 @@ LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
   return failure();
 }
 
+LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
+                                            NamedAttribute attr) {
+  auto attrName = attr.first.strref();
+  auto decorationName = mlir::convertToCamelCase(attrName, true);
+  auto decoration = spirv::symbolizeDecoration(decorationName);
+  if (!decoration) {
+    return emitError(
+               loc, "non-argument attributes expected to have snake-case-ified "
+                    "decoration name, unhandled attribute with name : ")
+           << attrName;
+  }
+  SmallVector<uint32_t, 1> args;
+  args.push_back(resultID);
+  args.push_back(static_cast<uint32_t>(decoration.getValue()));
+  switch (decoration.getValue()) {
+  case spirv::Decoration::DescriptorSet:
+  case spirv::Decoration::Binding:
+    if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) {
+      args.push_back(intAttr.getValue().getZExtValue());
+      break;
+    }
+    return emitError(loc, "expected integer attribute for ") << attrName;
+  default:
+    return emitError(loc, "unhandled decoration ") << decorationName;
+  }
+  return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
+}
+
 LogicalResult Serializer::processFuncOp(FuncOp op) {
   uint32_t fnTypeID = 0;
   // Generate type of the function.
index dbb1f7f..e0620f1 100644 (file)
@@ -1,11 +1,11 @@
 // RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
 
-// CHECK:           {{%.*}} = spv.Variable : !spv.ptr<f32, Input>
-// CHECK-NEXT:      {{%.*}} = spv.Variable : !spv.ptr<f32, Output>
+// CHECK:           {{%.*}} = spv.Variable bind(1, 0) : !spv.ptr<f32, Input>
+// CHECK-NEXT:      {{%.*}} = spv.Variable bind(0, 1) : !spv.ptr<f32, Output>
 func @spirv_variables() -> () {
   spv.module "Logical" "VulkanKHR" {
-    %2 = spv.Variable : !spv.ptr<f32, Input>
-    %3 = spv.Variable : !spv.ptr<f32, Output>
+    %2 = spv.Variable bind(1, 0) : !spv.ptr<f32, Input>
+    %3 = spv.Variable bind(0, 1): !spv.ptr<f32, Output>
   }
   return
 }
\ No newline at end of file
index 0c17720..75da5e7 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Support/StringExtras.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
@@ -39,6 +41,8 @@ using llvm::raw_string_ostream;
 using llvm::Record;
 using llvm::RecordKeeper;
 using llvm::SMLoc;
+using llvm::StringRef;
+using llvm::Twine;
 using mlir::tblgen::Attribute;
 using mlir::tblgen::EnumAttr;
 using mlir::tblgen::NamedAttribute;
@@ -90,7 +94,8 @@ static void emitAttributeSerialization(const Attribute &attr,
   os << "    }\n";
 }
 
-static void emitSerializationFunction(const Record *record, const Operator &op,
+static void emitSerializationFunction(const Record *attrClass,
+                                      const Record *record, const Operator &op,
                                       raw_ostream &os) {
   // If the record has 'autogenSerialization' set to 0, nothing to do
   if (!record->getValueAsBit("autogenSerialization")) {
@@ -101,21 +106,20 @@ static void emitSerializationFunction(const Record *record, const Operator &op,
                 op.getQualCppClassName())
      << " {\n";
   os << "  SmallVector<uint32_t, 4> operands;\n";
+  os << "  SmallVector<StringRef, 2> elidedAttrs;\n";
 
   // Serialize result information
   if (op.getNumResults() == 1) {
-    os << "  {\n";
-    os << "    uint32_t typeID = 0;\n";
-    os << "    if (failed(processType(op.getLoc(), "
-          "op.getResult()->getType(), typeID))) {\n";
-    os << "      return failure();\n";
-    os << "    }\n";
-    os << "    operands.push_back(typeID);\n";
-    /// Create an SSA result <id> for the op
-    os << "    auto resultID = getNextID();\n";
-    os << "    valueIDMap[op.getResult()] = resultID;\n";
-    os << "    operands.push_back(resultID);\n";
+    os << "  uint32_t resultTypeID = 0;\n";
+    os << "  if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) "
+          "{\n";
+    os << "    return failure();\n";
     os << "  }\n";
+    os << "  operands.push_back(resultTypeID);\n";
+    // Create an SSA result <id> for the op
+    os << "  auto resultID = getNextID();\n";
+    os << "  valueIDMap[op.getResult()] = resultID;\n";
+    os << "  operands.push_back(resultID);\n";
   } else if (op.getNumResults() != 0) {
     PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result");
   }
@@ -140,6 +144,7 @@ static void emitSerializationFunction(const Record *record, const Operator &op,
       emitAttributeSerialization(
           (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
           record->getLoc(), "op", "operands", attr->name, os);
+      os << "    elidedAttrs.push_back(\"" << attr->name << "\");\n";
     }
     os << "  }\n";
   }
@@ -147,6 +152,20 @@ static void emitSerializationFunction(const Record *record, const Operator &op,
   os << formatv("  encodeInstructionInto("
                 "functions, spirv::getOpcode<{0}>(), operands);\n",
                 op.getQualCppClassName());
+
+  if (op.getNumResults() == 1) {
+    // All non-argument attributes translated into OpDecorate instruction
+    os << "  for (auto attr : op.getAttrs()) {\n";
+    os << "    if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return "
+          "attr.first.is(elided); })) {\n";
+    os << "      continue;\n";
+    os << "    }\n";
+    os << "    if (failed(processDecoration(op.getLoc(), resultID, attr))) {\n";
+    os << "      return failure();";
+    os << "    }\n";
+    os << "  }\n";
+  }
+
   os << "  return success();\n";
   os << "}\n\n";
 }
@@ -196,7 +215,8 @@ static void emitAttributeDeserialization(
   }
 }
 
-static void emitDeserializationFunction(const Record *record,
+static void emitDeserializationFunction(const Record *attrClass,
+                                        const Record *record,
                                         const Operator &op, raw_ostream &os) {
   // If the record has 'autogenSerialization' set to 0, nothing to do
   if (!record->getValueAsBit("autogenSerialization")) {
@@ -292,8 +312,19 @@ static void emitDeserializationFunction(const Record *record,
                 "operands, attributes); (void)op;\n",
                 op.getQualCppClassName());
   if (hasResult) {
-    os << "  valueMap[valueID] = op.getResult();\n";
+    os << "  valueMap[valueID] = op.getResult();\n\n";
   }
+
+  // Import decorations parsed
+  if (op.getNumResults() == 1) {
+    os << "  if (decorations.count(valueID)) {\n";
+    os << "    auto decorationAttrs = decorations[valueID];\n";
+    os << "    for (auto attr : decorationAttrs.getAttrs()) {\n";
+    os << "      op.setAttr(attr.first, attr.second);\n";
+    os << "    }\n";
+    os << "  }\n";
+  }
+
   os << "  return success();\n";
   os << "}\n\n";
 }
@@ -330,6 +361,7 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper,
       utilsString;
   raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
       serFn(serFnString), deserFn(deserFnString), utils(utilsString);
+  auto attrClass = recordKeeper.getClass("Attr");
 
   declareOpcodeFn(utils);
   initDispatchSerializationFn(dSerFn);
@@ -341,9 +373,9 @@ static bool emitSerializationFns(const RecordKeeper &recordKeeper,
     }
     Operator op(def);
     emitGetOpcodeFunction(def, op, utils);
-    emitSerializationFunction(def, op, serFn);
+    emitSerializationFunction(attrClass, def, op, serFn);
     emitSerializationDispatch(op, dSerFn);
-    emitDeserializationFunction(def, op, deserFn);
+    emitDeserializationFunction(attrClass, def, op, deserFn);
     emitDeserializationDispatch(op, def, dDesFn);
   }
   finalizeDispatchSerializationFn(dSerFn);
@@ -378,21 +410,6 @@ static void emitEnumGetSymbolizeFnDecl(raw_ostream &os) {
         "SymbolizeFnTy<EnumClass> symbolizeEnum();\n";
 }
 
-std::string convertSnakeCase(llvm::StringRef inputString) {
-  std::string snakeCase;
-  for (auto c : inputString) {
-    if (c >= 'A' && c <= 'Z') {
-      if (!snakeCase.empty()) {
-        snakeCase.push_back('_');
-      }
-      snakeCase.push_back((c - 'A') + 'a');
-    } else {
-      snakeCase.push_back(c);
-    }
-  }
-  return snakeCase;
-}
-
 static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
                                       raw_ostream &os) {
   auto enumName = enumAttr.getEnumClassName();
@@ -400,7 +417,7 @@ static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
      << " {\n";
   os << "  "
      << formatv("static constexpr const char attrName[] = \"{0}\";\n",
-                convertSnakeCase(enumName));
+                mlir::convertToSnakeCase(enumName));
   os << "  return attrName;\n";
   os << "}\n";
 }
index 0b80f11..1539358 100644 (file)
@@ -2,6 +2,7 @@ add_mlir_unittest(MLIRIRTests
   AttributeTest.cpp
   DialectTest.cpp
   OperationSupportTest.cpp
+  StringExtrasTest.cpp
 )
 target_link_libraries(MLIRIRTests
   PRIVATE
diff --git a/mlir/unittests/IR/StringExtrasTest.cpp b/mlir/unittests/IR/StringExtrasTest.cpp
new file mode 100644 (file)
index 0000000..6d18633
--- /dev/null
@@ -0,0 +1,74 @@
+//===- StringExtras.cpp - Tests for utility methods in StringExtras.h -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Support/StringExtras.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+static void testConvertToSnakeCase(llvm::StringRef input,
+                                   llvm::StringRef expected) {
+  EXPECT_EQ(convertToSnakeCase(input), expected.str());
+}
+
+TEST(StringExtras, ConvertToSnakeCase) {
+  testConvertToSnakeCase("OpName", "op_name");
+  testConvertToSnakeCase("opName", "op_name");
+  testConvertToSnakeCase("_OpName", "_op_name");
+  testConvertToSnakeCase("Op_Name", "op_name");
+  testConvertToSnakeCase("", "");
+  testConvertToSnakeCase("A", "a");
+  testConvertToSnakeCase("_", "_");
+  testConvertToSnakeCase("a", "a");
+  testConvertToSnakeCase("op_name", "op_name");
+  testConvertToSnakeCase("_op_name", "_op_name");
+  testConvertToSnakeCase("__op_name", "__op_name");
+  testConvertToSnakeCase("op__name", "op__name");
+}
+
+template <bool capitalizeFirst>
+static void testConvertToCamelCase(llvm::StringRef input,
+                                   llvm::StringRef expected) {
+  EXPECT_EQ(convertToCamelCase(input, capitalizeFirst), expected.str());
+}
+
+TEST(StringExtras, ConvertToCamelCase) {
+  testConvertToCamelCase<false>("op_name", "opName");
+  testConvertToCamelCase<false>("_op_name", "_opName");
+  testConvertToCamelCase<false>("__op_name", "_OpName");
+  testConvertToCamelCase<false>("op__name", "op_Name");
+  testConvertToCamelCase<false>("", "");
+  testConvertToCamelCase<false>("A", "A");
+  testConvertToCamelCase<false>("_", "_");
+  testConvertToCamelCase<false>("a", "a");
+  testConvertToCamelCase<false>("OpName", "OpName");
+  testConvertToCamelCase<false>("opName", "opName");
+  testConvertToCamelCase<false>("_OpName", "_OpName");
+  testConvertToCamelCase<false>("Op_Name", "Op_Name");
+  testConvertToCamelCase<true>("op_name", "OpName");
+  testConvertToCamelCase<true>("_op_name", "_opName");
+  testConvertToCamelCase<true>("__op_name", "_OpName");
+  testConvertToCamelCase<true>("op__name", "Op_Name");
+  testConvertToCamelCase<true>("", "");
+  testConvertToCamelCase<true>("A", "A");
+  testConvertToCamelCase<true>("_", "_");
+  testConvertToCamelCase<true>("a", "A");
+  testConvertToCamelCase<true>("OpName", "OpName");
+  testConvertToCamelCase<true>("_OpName", "_OpName");
+  testConvertToCamelCase<true>("Op_Name", "Op_Name");
+  testConvertToCamelCase<true>("opName", "OpName");
+}
index de17756..ac00179 100755 (executable)
@@ -109,6 +109,28 @@ def split_list_into_sublists(items, offset):
   return chuncks
 
 
+def uniquify(lst, equality_fn):
+  """Returns a list after pruning duplicate elements.
+
+  Arguments:
+   - lst: List whose elements are to be uniqued.
+   - equality_fn: Function used to compare equality between elements of the
+     list.
+
+  Returns:
+   - A list with all duplicated removed. The order of elements is same as the
+     original list, with only the first occurence of duplicates retained.
+  """
+  keys = set()
+  unique_lst = []
+  for elem in lst:
+    key = equality_fn(elem)
+    if equality_fn(key) not in keys:
+      unique_lst.append(elem)
+      keys.add(key)
+  return unique_lst
+
+
 def gen_operand_kind_enum_attr(operand_kind):
   """Generates the TableGen I32EnumAttr definition for the given operand kind.
 
@@ -123,6 +145,7 @@ def gen_operand_kind_enum_attr(operand_kind):
   kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z'])
   kind_cases = [(case['enumerant'], case['value'])
                 for case in operand_kind['enumerants']]
+  kind_cases = uniquify(kind_cases, lambda x: x[1])
   max_len = max([len(symbol) for (symbol, _) in kind_cases])
 
   # Generate the definition for each enum case