let cppNamespace = "::mlir::spirv";
}
+def SPV_D_RelaxedPrecision : I32EnumAttrCase<"RelaxedPrecision", 0>;
+def SPV_D_SpecId : I32EnumAttrCase<"SpecId", 1>;
+def SPV_D_Block : I32EnumAttrCase<"Block", 2>;
+def SPV_D_BufferBlock : I32EnumAttrCase<"BufferBlock", 3>;
+def SPV_D_RowMajor : I32EnumAttrCase<"RowMajor", 4>;
+def SPV_D_ColMajor : I32EnumAttrCase<"ColMajor", 5>;
+def SPV_D_ArrayStride : I32EnumAttrCase<"ArrayStride", 6>;
+def SPV_D_MatrixStride : I32EnumAttrCase<"MatrixStride", 7>;
+def SPV_D_GLSLShared : I32EnumAttrCase<"GLSLShared", 8>;
+def SPV_D_GLSLPacked : I32EnumAttrCase<"GLSLPacked", 9>;
+def SPV_D_CPacked : I32EnumAttrCase<"CPacked", 10>;
+def SPV_D_BuiltIn : I32EnumAttrCase<"BuiltIn", 11>;
+def SPV_D_NoPerspective : I32EnumAttrCase<"NoPerspective", 13>;
+def SPV_D_Flat : I32EnumAttrCase<"Flat", 14>;
+def SPV_D_Patch : I32EnumAttrCase<"Patch", 15>;
+def SPV_D_Centroid : I32EnumAttrCase<"Centroid", 16>;
+def SPV_D_Sample : I32EnumAttrCase<"Sample", 17>;
+def SPV_D_Invariant : I32EnumAttrCase<"Invariant", 18>;
+def SPV_D_Restrict : I32EnumAttrCase<"Restrict", 19>;
+def SPV_D_Aliased : I32EnumAttrCase<"Aliased", 20>;
+def SPV_D_Volatile : I32EnumAttrCase<"Volatile", 21>;
+def SPV_D_Constant : I32EnumAttrCase<"Constant", 22>;
+def SPV_D_Coherent : I32EnumAttrCase<"Coherent", 23>;
+def SPV_D_NonWritable : I32EnumAttrCase<"NonWritable", 24>;
+def SPV_D_NonReadable : I32EnumAttrCase<"NonReadable", 25>;
+def SPV_D_Uniform : I32EnumAttrCase<"Uniform", 26>;
+def SPV_D_UniformId : I32EnumAttrCase<"UniformId", 27>;
+def SPV_D_SaturatedConversion : I32EnumAttrCase<"SaturatedConversion", 28>;
+def SPV_D_Stream : I32EnumAttrCase<"Stream", 29>;
+def SPV_D_Location : I32EnumAttrCase<"Location", 30>;
+def SPV_D_Component : I32EnumAttrCase<"Component", 31>;
+def SPV_D_Index : I32EnumAttrCase<"Index", 32>;
+def SPV_D_Binding : I32EnumAttrCase<"Binding", 33>;
+def SPV_D_DescriptorSet : I32EnumAttrCase<"DescriptorSet", 34>;
+def SPV_D_Offset : I32EnumAttrCase<"Offset", 35>;
+def SPV_D_XfbBuffer : I32EnumAttrCase<"XfbBuffer", 36>;
+def SPV_D_XfbStride : I32EnumAttrCase<"XfbStride", 37>;
+def SPV_D_FuncParamAttr : I32EnumAttrCase<"FuncParamAttr", 38>;
+def SPV_D_FPRoundingMode : I32EnumAttrCase<"FPRoundingMode", 39>;
+def SPV_D_FPFastMathMode : I32EnumAttrCase<"FPFastMathMode", 40>;
+def SPV_D_LinkageAttributes : I32EnumAttrCase<"LinkageAttributes", 41>;
+def SPV_D_NoContraction : I32EnumAttrCase<"NoContraction", 42>;
+def SPV_D_InputAttachmentIndex : I32EnumAttrCase<"InputAttachmentIndex", 43>;
+def SPV_D_Alignment : I32EnumAttrCase<"Alignment", 44>;
+def SPV_D_MaxByteOffset : I32EnumAttrCase<"MaxByteOffset", 45>;
+def SPV_D_AlignmentId : I32EnumAttrCase<"AlignmentId", 46>;
+def SPV_D_MaxByteOffsetId : I32EnumAttrCase<"MaxByteOffsetId", 47>;
+def SPV_D_NoSignedWrap : I32EnumAttrCase<"NoSignedWrap", 4469>;
+def SPV_D_NoUnsignedWrap : I32EnumAttrCase<"NoUnsignedWrap", 4470>;
+def SPV_D_ExplicitInterpAMD : I32EnumAttrCase<"ExplicitInterpAMD", 4999>;
+def SPV_D_OverrideCoverageNV : I32EnumAttrCase<"OverrideCoverageNV", 5248>;
+def SPV_D_PassthroughNV : I32EnumAttrCase<"PassthroughNV", 5250>;
+def SPV_D_ViewportRelativeNV : I32EnumAttrCase<"ViewportRelativeNV", 5252>;
+def SPV_D_SecondaryViewportRelativeNV : I32EnumAttrCase<"SecondaryViewportRelativeNV", 5256>;
+def SPV_D_PerPrimitiveNV : I32EnumAttrCase<"PerPrimitiveNV", 5271>;
+def SPV_D_PerViewNV : I32EnumAttrCase<"PerViewNV", 5272>;
+def SPV_D_PerTaskNV : I32EnumAttrCase<"PerTaskNV", 5273>;
+def SPV_D_PerVertexNV : I32EnumAttrCase<"PerVertexNV", 5285>;
+def SPV_D_NonUniformEXT : I32EnumAttrCase<"NonUniformEXT", 5300>;
+def SPV_D_RestrictPointerEXT : I32EnumAttrCase<"RestrictPointerEXT", 5355>;
+def SPV_D_AliasedPointerEXT : I32EnumAttrCase<"AliasedPointerEXT", 5356>;
+def SPV_D_CounterBuffer : I32EnumAttrCase<"CounterBuffer", 5634>;
+def SPV_D_UserSemantic : I32EnumAttrCase<"UserSemantic", 5635>;
+def SPV_D_UserTypeGOOGLE : I32EnumAttrCase<"UserTypeGOOGLE", 5636>;
+
+def SPV_DecorationAttr :
+ I32EnumAttr<"Decoration", "valid SPIR-V Decoration", [
+ SPV_D_RelaxedPrecision, SPV_D_SpecId, SPV_D_Block, SPV_D_BufferBlock,
+ SPV_D_RowMajor, SPV_D_ColMajor, SPV_D_ArrayStride, SPV_D_MatrixStride,
+ SPV_D_GLSLShared, SPV_D_GLSLPacked, SPV_D_CPacked, SPV_D_BuiltIn,
+ SPV_D_NoPerspective, SPV_D_Flat, SPV_D_Patch, SPV_D_Centroid, SPV_D_Sample,
+ SPV_D_Invariant, SPV_D_Restrict, SPV_D_Aliased, SPV_D_Volatile, SPV_D_Constant,
+ SPV_D_Coherent, SPV_D_NonWritable, SPV_D_NonReadable, SPV_D_Uniform,
+ SPV_D_UniformId, SPV_D_SaturatedConversion, SPV_D_Stream, SPV_D_Location,
+ SPV_D_Component, SPV_D_Index, SPV_D_Binding, SPV_D_DescriptorSet, SPV_D_Offset,
+ SPV_D_XfbBuffer, SPV_D_XfbStride, SPV_D_FuncParamAttr, SPV_D_FPRoundingMode,
+ SPV_D_FPFastMathMode, SPV_D_LinkageAttributes, SPV_D_NoContraction,
+ SPV_D_InputAttachmentIndex, SPV_D_Alignment, SPV_D_MaxByteOffset,
+ SPV_D_AlignmentId, SPV_D_MaxByteOffsetId, SPV_D_NoSignedWrap,
+ SPV_D_NoUnsignedWrap, SPV_D_ExplicitInterpAMD, SPV_D_OverrideCoverageNV,
+ SPV_D_PassthroughNV, SPV_D_ViewportRelativeNV,
+ SPV_D_SecondaryViewportRelativeNV, SPV_D_PerPrimitiveNV, SPV_D_PerViewNV,
+ SPV_D_PerTaskNV, SPV_D_PerVertexNV, SPV_D_NonUniformEXT,
+ SPV_D_RestrictPointerEXT, SPV_D_AliasedPointerEXT, SPV_D_CounterBuffer,
+ SPV_D_UserSemantic, SPV_D_UserTypeGOOGLE
+ ]> {
+ let returnType = "::mlir::spirv::Decoration";
+ let convertFromStorage = "static_cast<::mlir::spirv::Decoration>($_self.getInt())";
+ let cppNamespace = "::mlir::spirv";
+}
+
def SPV_D_1D : I32EnumAttrCase<"1D", 0>;
def SPV_D_2D : I32EnumAttrCase<"2D", 1>;
def SPV_D_3D : I32EnumAttrCase<"3D", 2>;
--- /dev/null
+//===- StringExtras.h - String utilities used by MLIR -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file contains string utility functions used within MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SUPPORT_STRINGEXTRAS_H
+#define MLIR_SUPPORT_STRINGEXTRAS_H
+
+#include "llvm/ADT/StringExtras.h"
+
+namespace mlir {
+/// Converts a string to snake-case from camel-case by replacing all uppercase
+/// letters with '_' followed by the letter in lowercase, except if the
+/// uppercase letter is the first character of the string.
+inline std::string convertToSnakeCase(llvm::StringRef input) {
+ std::string snakeCase;
+ snakeCase.reserve(input.size());
+ for (auto c : input) {
+ if (std::isupper(c)) {
+ if (!snakeCase.empty() && snakeCase.back() != '_') {
+ snakeCase.push_back('_');
+ }
+ snakeCase.push_back(llvm::toLower(c));
+ } else {
+ snakeCase.push_back(c);
+ }
+ }
+ return snakeCase;
+}
+
+/// Converts a string from camel-case to snake_case by replacing all occurences
+/// of '_' followed by a lowercase letter with the letter in
+/// uppercase. Optionally allow capitalization of the first letter (if it is a
+/// lowercase letter)
+inline std::string convertToCamelCase(llvm::StringRef input,
+ bool capitalizeFirst = false) {
+ if (input.empty()) {
+ return "";
+ }
+ std::string output;
+ output.reserve(input.size());
+ size_t pos = 0;
+ if (capitalizeFirst && std::islower(input[pos])) {
+ output.push_back(llvm::toUpper(input[pos]));
+ pos++;
+ }
+ while (pos < input.size()) {
+ auto cur = input[pos];
+ if (cur == '_') {
+ if (pos && (pos + 1 < input.size())) {
+ if (std::islower(input[pos + 1])) {
+ output.push_back(llvm::toUpper(input[pos + 1]));
+ pos += 2;
+ continue;
+ }
+ }
+ }
+ output.push_back(cur);
+ pos++;
+ }
+ return output;
+}
+} // namespace mlir
+
+#endif // MLIR_SUPPORT_STRINGEXTRAS_H
#include "mlir/IR/Function.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
+#include "mlir/Support/StringExtras.h"
using namespace mlir;
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
-static constexpr const char kBindingAttrName[] = "binding";
-static constexpr const char kDescriptorSetAttrName[] = "descriptor_set";
static constexpr const char kIndicesAttrName[] = "indices";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
}
template <typename EnumClass>
-static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
- OperationState *state) {
+static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser) {
Attribute attrVal;
SmallVector<NamedAttribute, 1> attr;
auto loc = parser->getCurrentLocation();
<< " attribute specification: " << attrVal;
}
value = attrOptional.getValue();
+ return success();
+}
+
+template <typename EnumClass>
+static ParseResult parseEnumAttribute(EnumClass &value, OpAsmParser *parser,
+ OperationState *state) {
+ if (parseEnumAttribute(value, parser)) {
+ return failure();
+ }
state->addAttribute(
spirv::attributeName<EnumClass>(),
parser->getBuilder().getI32IntegerAttr(bitwiseCast<int32_t>(value)));
spirv::StorageClass storageClass;
OpAsmParser::OperandType ptrInfo;
Type elementType;
- if (parseEnumAttribute(storageClass, parser, state) ||
+ if (parseEnumAttribute(storageClass, parser) ||
parser->parseOperand(ptrInfo) ||
parseMemoryAccessAttributes(parser, state) ||
parser->parseOptionalAttributeDict(state->attributes) ||
SmallVector<OpAsmParser::OperandType, 2> operandInfo;
auto loc = parser->getCurrentLocation();
Type elementType;
- if (parseEnumAttribute(storageClass, parser, state) ||
+ if (parseEnumAttribute(storageClass, parser) ||
parser->parseOperandList(operandInfo, 2) ||
parseMemoryAccessAttributes(parser, state) || parser->parseColon() ||
parser->parseType(elementType)) {
// Parse optional descriptor binding
Attribute set, binding;
+ auto descriptorSetName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
+ auto bindingName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
if (succeeded(parser->parseOptionalKeyword("bind"))) {
Type i32Type = parser->getBuilder().getIntegerType(32);
if (parser->parseLParen() ||
- parser->parseAttribute(set, i32Type, kDescriptorSetAttrName,
+ parser->parseAttribute(set, i32Type, descriptorSetName,
state->attributes) ||
parser->parseComma() ||
- parser->parseAttribute(binding, i32Type, kBindingAttrName,
+ parser->parseAttribute(binding, i32Type, bindingName,
state->attributes) ||
parser->parseRParen())
return failure();
}
// Print optional descriptor binding
- auto set = varOp.getAttrOfType<IntegerAttr>(kDescriptorSetAttrName);
- auto binding = varOp.getAttrOfType<IntegerAttr>(kBindingAttrName);
- if (set && binding) {
- elidedAttrs.push_back(kDescriptorSetAttrName);
- elidedAttrs.push_back(kBindingAttrName);
- *printer << " bind(" << set.getInt() << ", " << binding.getInt() << ")";
+ auto descriptorSetName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
+ auto bindingName =
+ convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+ auto descriptorSet = varOp.getAttrOfType<IntegerAttr>(descriptorSetName);
+ auto binding = varOp.getAttrOfType<IntegerAttr>(bindingName);
+ if (descriptorSet && binding) {
+ elidedAttrs.push_back(descriptorSetName);
+ elidedAttrs.push_back(bindingName);
+ *printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
+ << ")";
}
printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs);
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/StringExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/bit.h"
/// Process SPIR-V OpName with `operands`
LogicalResult processName(ArrayRef<uint32_t> operands);
+ /// Method to process an OpDecorate instruction.
+ LogicalResult processDecoration(ArrayRef<uint32_t> words);
+
/// Processes the SPIR-V function at the current `offset` into `binary`.
/// The operands to the OpFunction instruction is passed in as ``operands`.
/// This method processes each instruction inside the function and dispatches
// Result <id> to name mapping.
DenseMap<uint32_t, StringRef> nameMap;
+ // Result <id> to decorations mapping.
+ DenseMap<uint32_t, NamedAttributeList> decorations;
+
// List of instructions that are processed in a defered fashion (after an
// initial processing of the entire binary). Some operations like
// OpEntryPoint, and OpExecutionMode use forward references to function
return success();
}
+LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
+ // TODO : This function should also be auto-generated. For now, since only a
+ // few decorations are processed/handled in a meaningful manner, going with a
+ // manual implementation.
+ if (words.size() < 2) {
+ return emitError(
+ unknownLoc, "OpDecorate must have at least result <id> and Decoration");
+ }
+ auto decorationName =
+ stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
+ if (decorationName.empty()) {
+ return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
+ }
+ auto attrName = convertToSnakeCase(decorationName);
+ switch (static_cast<spirv::Decoration>(words[1])) {
+ case spirv::Decoration::DescriptorSet:
+ case spirv::Decoration::Binding:
+ if (words.size() != 3) {
+ return emitError(unknownLoc, "OpDecorate with ")
+ << decorationName << " needs a single integer literal";
+ }
+ decorations[words[0]].set(
+ opBuilder.getIdentifier(attrName),
+ opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
+ break;
+ default:
+ return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
+ }
+ return success();
+}
+
LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
// Get the result type
if (operands.size() != 4) {
return processConstantBool(false, operands);
case spirv::Opcode::OpConstantNull:
return processConstantNull(operands);
+ case spirv::Opcode::OpDecorate:
+ return processDecoration(operands);
case spirv::Opcode::OpFunction:
return processFunction(operands);
default:
}
namespace {
+
template <>
LogicalResult
Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/StringExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/raw_ostream.h"
/// Processes a SPIR-V function op.
LogicalResult processFuncOp(FuncOp op);
+ /// Process attributes that translate to decorations on the result <id>
+ LogicalResult processDecoration(Location loc, uint32_t resultID,
+ NamedAttribute attr);
+
//===--------------------------------------------------------------------===//
// Types
//===--------------------------------------------------------------------===//
return failure();
}
+LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
+ NamedAttribute attr) {
+ auto attrName = attr.first.strref();
+ auto decorationName = mlir::convertToCamelCase(attrName, true);
+ auto decoration = spirv::symbolizeDecoration(decorationName);
+ if (!decoration) {
+ return emitError(
+ loc, "non-argument attributes expected to have snake-case-ified "
+ "decoration name, unhandled attribute with name : ")
+ << attrName;
+ }
+ SmallVector<uint32_t, 1> args;
+ args.push_back(resultID);
+ args.push_back(static_cast<uint32_t>(decoration.getValue()));
+ switch (decoration.getValue()) {
+ case spirv::Decoration::DescriptorSet:
+ case spirv::Decoration::Binding:
+ if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) {
+ args.push_back(intAttr.getValue().getZExtValue());
+ break;
+ }
+ return emitError(loc, "expected integer attribute for ") << attrName;
+ default:
+ return emitError(loc, "unhandled decoration ") << decorationName;
+ }
+ return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args);
+}
+
LogicalResult Serializer::processFuncOp(FuncOp op) {
uint32_t fnTypeID = 0;
// Generate type of the function.
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
-// CHECK: {{%.*}} = spv.Variable : !spv.ptr<f32, Input>
-// CHECK-NEXT: {{%.*}} = spv.Variable : !spv.ptr<f32, Output>
+// CHECK: {{%.*}} = spv.Variable bind(1, 0) : !spv.ptr<f32, Input>
+// CHECK-NEXT: {{%.*}} = spv.Variable bind(0, 1) : !spv.ptr<f32, Output>
func @spirv_variables() -> () {
spv.module "Logical" "VulkanKHR" {
- %2 = spv.Variable : !spv.ptr<f32, Input>
- %3 = spv.Variable : !spv.ptr<f32, Output>
+ %2 = spv.Variable bind(1, 0) : !spv.ptr<f32, Input>
+ %3 = spv.Variable bind(0, 1): !spv.ptr<f32, Output>
}
return
}
\ No newline at end of file
//
//===----------------------------------------------------------------------===//
+#include "mlir/Support/StringExtras.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
using llvm::Record;
using llvm::RecordKeeper;
using llvm::SMLoc;
+using llvm::StringRef;
+using llvm::Twine;
using mlir::tblgen::Attribute;
using mlir::tblgen::EnumAttr;
using mlir::tblgen::NamedAttribute;
os << " }\n";
}
-static void emitSerializationFunction(const Record *record, const Operator &op,
+static void emitSerializationFunction(const Record *attrClass,
+ const Record *record, const Operator &op,
raw_ostream &os) {
// If the record has 'autogenSerialization' set to 0, nothing to do
if (!record->getValueAsBit("autogenSerialization")) {
op.getQualCppClassName())
<< " {\n";
os << " SmallVector<uint32_t, 4> operands;\n";
+ os << " SmallVector<StringRef, 2> elidedAttrs;\n";
// Serialize result information
if (op.getNumResults() == 1) {
- os << " {\n";
- os << " uint32_t typeID = 0;\n";
- os << " if (failed(processType(op.getLoc(), "
- "op.getResult()->getType(), typeID))) {\n";
- os << " return failure();\n";
- os << " }\n";
- os << " operands.push_back(typeID);\n";
- /// Create an SSA result <id> for the op
- os << " auto resultID = getNextID();\n";
- os << " valueIDMap[op.getResult()] = resultID;\n";
- os << " operands.push_back(resultID);\n";
+ os << " uint32_t resultTypeID = 0;\n";
+ os << " if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) "
+ "{\n";
+ os << " return failure();\n";
os << " }\n";
+ os << " operands.push_back(resultTypeID);\n";
+ // Create an SSA result <id> for the op
+ os << " auto resultID = getNextID();\n";
+ os << " valueIDMap[op.getResult()] = resultID;\n";
+ os << " operands.push_back(resultID);\n";
} else if (op.getNumResults() != 0) {
PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result");
}
emitAttributeSerialization(
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
record->getLoc(), "op", "operands", attr->name, os);
+ os << " elidedAttrs.push_back(\"" << attr->name << "\");\n";
}
os << " }\n";
}
os << formatv(" encodeInstructionInto("
"functions, spirv::getOpcode<{0}>(), operands);\n",
op.getQualCppClassName());
+
+ if (op.getNumResults() == 1) {
+ // All non-argument attributes translated into OpDecorate instruction
+ os << " for (auto attr : op.getAttrs()) {\n";
+ os << " if (llvm::any_of(elidedAttrs, [&](StringRef elided) { return "
+ "attr.first.is(elided); })) {\n";
+ os << " continue;\n";
+ os << " }\n";
+ os << " if (failed(processDecoration(op.getLoc(), resultID, attr))) {\n";
+ os << " return failure();";
+ os << " }\n";
+ os << " }\n";
+ }
+
os << " return success();\n";
os << "}\n\n";
}
}
}
-static void emitDeserializationFunction(const Record *record,
+static void emitDeserializationFunction(const Record *attrClass,
+ const Record *record,
const Operator &op, raw_ostream &os) {
// If the record has 'autogenSerialization' set to 0, nothing to do
if (!record->getValueAsBit("autogenSerialization")) {
"operands, attributes); (void)op;\n",
op.getQualCppClassName());
if (hasResult) {
- os << " valueMap[valueID] = op.getResult();\n";
+ os << " valueMap[valueID] = op.getResult();\n\n";
}
+
+ // Import decorations parsed
+ if (op.getNumResults() == 1) {
+ os << " if (decorations.count(valueID)) {\n";
+ os << " auto decorationAttrs = decorations[valueID];\n";
+ os << " for (auto attr : decorationAttrs.getAttrs()) {\n";
+ os << " op.setAttr(attr.first, attr.second);\n";
+ os << " }\n";
+ os << " }\n";
+ }
+
os << " return success();\n";
os << "}\n\n";
}
utilsString;
raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
serFn(serFnString), deserFn(deserFnString), utils(utilsString);
+ auto attrClass = recordKeeper.getClass("Attr");
declareOpcodeFn(utils);
initDispatchSerializationFn(dSerFn);
}
Operator op(def);
emitGetOpcodeFunction(def, op, utils);
- emitSerializationFunction(def, op, serFn);
+ emitSerializationFunction(attrClass, def, op, serFn);
emitSerializationDispatch(op, dSerFn);
- emitDeserializationFunction(def, op, deserFn);
+ emitDeserializationFunction(attrClass, def, op, deserFn);
emitDeserializationDispatch(op, def, dDesFn);
}
finalizeDispatchSerializationFn(dSerFn);
"SymbolizeFnTy<EnumClass> symbolizeEnum();\n";
}
-std::string convertSnakeCase(llvm::StringRef inputString) {
- std::string snakeCase;
- for (auto c : inputString) {
- if (c >= 'A' && c <= 'Z') {
- if (!snakeCase.empty()) {
- snakeCase.push_back('_');
- }
- snakeCase.push_back((c - 'A') + 'a');
- } else {
- snakeCase.push_back(c);
- }
- }
- return snakeCase;
-}
-
static void emitEnumGetAttrNameFnDefn(const EnumAttr &enumAttr,
raw_ostream &os) {
auto enumName = enumAttr.getEnumClassName();
<< " {\n";
os << " "
<< formatv("static constexpr const char attrName[] = \"{0}\";\n",
- convertSnakeCase(enumName));
+ mlir::convertToSnakeCase(enumName));
os << " return attrName;\n";
os << "}\n";
}
AttributeTest.cpp
DialectTest.cpp
OperationSupportTest.cpp
+ StringExtrasTest.cpp
)
target_link_libraries(MLIRIRTests
PRIVATE
--- /dev/null
+//===- StringExtras.cpp - Tests for utility methods in StringExtras.h -----===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Support/StringExtras.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+static void testConvertToSnakeCase(llvm::StringRef input,
+ llvm::StringRef expected) {
+ EXPECT_EQ(convertToSnakeCase(input), expected.str());
+}
+
+TEST(StringExtras, ConvertToSnakeCase) {
+ testConvertToSnakeCase("OpName", "op_name");
+ testConvertToSnakeCase("opName", "op_name");
+ testConvertToSnakeCase("_OpName", "_op_name");
+ testConvertToSnakeCase("Op_Name", "op_name");
+ testConvertToSnakeCase("", "");
+ testConvertToSnakeCase("A", "a");
+ testConvertToSnakeCase("_", "_");
+ testConvertToSnakeCase("a", "a");
+ testConvertToSnakeCase("op_name", "op_name");
+ testConvertToSnakeCase("_op_name", "_op_name");
+ testConvertToSnakeCase("__op_name", "__op_name");
+ testConvertToSnakeCase("op__name", "op__name");
+}
+
+template <bool capitalizeFirst>
+static void testConvertToCamelCase(llvm::StringRef input,
+ llvm::StringRef expected) {
+ EXPECT_EQ(convertToCamelCase(input, capitalizeFirst), expected.str());
+}
+
+TEST(StringExtras, ConvertToCamelCase) {
+ testConvertToCamelCase<false>("op_name", "opName");
+ testConvertToCamelCase<false>("_op_name", "_opName");
+ testConvertToCamelCase<false>("__op_name", "_OpName");
+ testConvertToCamelCase<false>("op__name", "op_Name");
+ testConvertToCamelCase<false>("", "");
+ testConvertToCamelCase<false>("A", "A");
+ testConvertToCamelCase<false>("_", "_");
+ testConvertToCamelCase<false>("a", "a");
+ testConvertToCamelCase<false>("OpName", "OpName");
+ testConvertToCamelCase<false>("opName", "opName");
+ testConvertToCamelCase<false>("_OpName", "_OpName");
+ testConvertToCamelCase<false>("Op_Name", "Op_Name");
+ testConvertToCamelCase<true>("op_name", "OpName");
+ testConvertToCamelCase<true>("_op_name", "_opName");
+ testConvertToCamelCase<true>("__op_name", "_OpName");
+ testConvertToCamelCase<true>("op__name", "Op_Name");
+ testConvertToCamelCase<true>("", "");
+ testConvertToCamelCase<true>("A", "A");
+ testConvertToCamelCase<true>("_", "_");
+ testConvertToCamelCase<true>("a", "A");
+ testConvertToCamelCase<true>("OpName", "OpName");
+ testConvertToCamelCase<true>("_OpName", "_OpName");
+ testConvertToCamelCase<true>("Op_Name", "Op_Name");
+ testConvertToCamelCase<true>("opName", "OpName");
+}
return chuncks
+def uniquify(lst, equality_fn):
+ """Returns a list after pruning duplicate elements.
+
+ Arguments:
+ - lst: List whose elements are to be uniqued.
+ - equality_fn: Function used to compare equality between elements of the
+ list.
+
+ Returns:
+ - A list with all duplicated removed. The order of elements is same as the
+ original list, with only the first occurence of duplicates retained.
+ """
+ keys = set()
+ unique_lst = []
+ for elem in lst:
+ key = equality_fn(elem)
+ if equality_fn(key) not in keys:
+ unique_lst.append(elem)
+ keys.add(key)
+ return unique_lst
+
+
def gen_operand_kind_enum_attr(operand_kind):
"""Generates the TableGen I32EnumAttr definition for the given operand kind.
kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z'])
kind_cases = [(case['enumerant'], case['value'])
for case in operand_kind['enumerants']]
+ kind_cases = uniquify(kind_cases, lambda x: x[1])
max_len = max([len(symbol) for (symbol, _) in kind_cases])
# Generate the definition for each enum case