From 20b93abca6516bbb23689c3777536fea04e46e14 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 20 Jul 2023 22:51:35 -0700 Subject: [PATCH] Update ODS variadic segments "magic" attributes to use native Properties The operand_segment_sizes and result_segment_sizes Attributes are now inlined in the operation as native propertie. We continue to support building an Attribute on the fly for `getAttr("operand_segment_sizes")` and setting the property from an attribute with `setAttr("operand_segment_sizes", attr)`. A new bytecode version is introduced to support backward compatibility and backdeployments. Differential Revision: https://reviews.llvm.org/D155919 --- .../include/mlir/Bytecode/BytecodeImplementation.h | 123 +++++++++ mlir/include/mlir/Bytecode/Encoding.h | 6 +- mlir/include/mlir/IR/ODSSupport.h | 7 + mlir/include/mlir/IR/OpBase.td | 1 + mlir/include/mlir/IR/OpDefinition.h | 1 + mlir/include/mlir/IR/OperationSupport.h | 6 +- mlir/include/mlir/TableGen/Property.h | 49 +++- mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 30 ++- mlir/lib/Bytecode/Writer/IRNumbering.cpp | 6 +- mlir/lib/Bytecode/Writer/IRNumbering.h | 3 + mlir/lib/IR/ODSSupport.cpp | 28 +- mlir/lib/TableGen/Property.cpp | 87 +++---- mlir/test/Dialect/LLVMIR/invalid.mlir | 2 +- mlir/test/IR/traits.mlir | 16 +- mlir/test/lib/Dialect/Test/TestOps.td | 3 +- mlir/test/python/dialects/linalg/ops.py | 2 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 286 +++++++++++++++++---- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 51 ++-- mlir/unittests/IR/AdaptorTest.cpp | 2 +- mlir/unittests/IR/OpPropertiesTest.cpp | 6 +- 20 files changed, 542 insertions(+), 173 deletions(-) diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h index 4e74c12..9c9aa7a 100644 --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -20,6 +20,7 @@ #include "mlir/IR/DialectInterface.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" namespace mlir { @@ -39,6 +40,9 @@ public: /// Emit an error to the reader. virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0; + /// Return the bytecode version being read. + virtual uint64_t getBytecodeVersion() const = 0; + /// Read out a list of elements, invoking the provided callback for each /// element. The callback function may be in any of the following forms: /// * LogicalResult(T &) @@ -148,6 +152,76 @@ public: [this](int64_t &value) { return readSignedVarInt(value); }); } + /// Parse a variable length encoded integer whose low bit is used to encode an + /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. + LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag) { + if (failed(readVarInt(result))) + return failure(); + flag = result & 1; + result >>= 1; + return success(); + } + + /// Read a "small" sparse array of integer <= 32 bits elements, where + /// index/value pairs can be compressed when the array is small. + /// Note that only some position of the array will be read and the ones + /// not stored in the bytecode are gonne be left untouched. + /// If the provided array is too small for the stored indices, an error + /// will be returned. + template + LogicalResult readSparseArray(MutableArrayRef array) { + static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits"); + static_assert(std::is_integral::value, "expects integer"); + uint64_t nonZeroesCount; + bool useSparseEncoding; + if (failed(readVarIntWithFlag(nonZeroesCount, useSparseEncoding))) + return failure(); + if (nonZeroesCount == 0) + return success(); + if (!useSparseEncoding) { + // This is a simple dense array. + if (nonZeroesCount > array.size()) { + emitError("trying to read an array of ") + << nonZeroesCount << " but only " << array.size() + << " storage available."; + return failure(); + } + for (int64_t index : llvm::seq(0, nonZeroesCount)) { + uint64_t value; + if (failed(readVarInt(value))) + return failure(); + array[index] = value; + } + return success(); + } + // Read sparse encoding + // This is the number of bits used for packing the index with the value. + uint64_t indexBitSize; + if (failed(readVarInt(indexBitSize))) + return failure(); + constexpr uint64_t maxIndexBitSize = 8; + if (indexBitSize > maxIndexBitSize) { + emitError("reading sparse array with indexing above 8 bits: ") + << indexBitSize; + return failure(); + } + for (uint32_t count : llvm::seq(0, nonZeroesCount)) { + (void)count; + uint64_t indexValuePair; + if (failed(readVarInt(indexValuePair))) + return failure(); + uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize)); + uint64_t value = indexValuePair >> indexBitSize; + if (index >= array.size()) { + emitError("reading a sparse array found index ") + << index << " but only " << array.size() << " storage available."; + return failure(); + } + array[index] = value; + } + return success(); + } + /// Read an APInt that is known to have been encoded with the given width. virtual FailureOr readAPIntWithKnownWidth(unsigned bitWidth) = 0; @@ -230,6 +304,55 @@ public: writeList(value, [this](int64_t value) { writeSignedVarInt(value); }); } + /// Write a VarInt and a flag packed together. + void writeVarIntWithFlag(uint64_t value, bool flag) { + writeVarInt((value << 1) | (flag ? 1 : 0)); + } + + /// Write out a "small" sparse array of integer <= 32 bits elements, where + /// index/value pairs can be compressed when the array is small. This method + /// will scan the array multiple times and should not be used for large + /// arrays. The optional provided "zero" can be used to adjust for the + /// expected repeated value. We assume here that the array size fits in a 32 + /// bits integer. + template + void writeSparseArray(ArrayRef array) { + static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits"); + static_assert(std::is_integral::value, "expects integer"); + uint32_t size = array.size(); + uint32_t nonZeroesCount = 0, lastIndex = 0; + for (uint32_t index : llvm::seq(0, size)) { + if (!array[index]) + continue; + nonZeroesCount++; + lastIndex = index; + } + // If the last position is too large, or the array isn't at least 50% + // sparse, emit it with a dense encoding. + if (lastIndex > 256 || nonZeroesCount > size / 2) { + // Emit the array size and a flag which indicates whether it is sparse. + writeVarIntWithFlag(size, false); + for (const T &elt : array) + writeVarInt(elt); + return; + } + // Emit sparse: first the number of elements we'll write and a flag + // indicating it is a sparse encoding. + writeVarIntWithFlag(nonZeroesCount, true); + if (nonZeroesCount == 0) + return; + // This is the number of bits used for packing the index with the value. + int indexBitSize = llvm::Log2_32_Ceil(lastIndex + 1); + writeVarInt(indexBitSize); + for (uint32_t index : llvm::seq(0, lastIndex + 1)) { + T value = array[index]; + if (!value) + continue; + uint64_t indexValuePair = (value << indexBitSize) | (index); + writeVarInt(indexValuePair); + } + } + /// Write an APInt to the bytecode stream whose bitwidth will be known /// externally at read time. This method is useful for encoding APInt values /// when the width is known via external means, such as via a type. This diff --git a/mlir/include/mlir/Bytecode/Encoding.h b/mlir/include/mlir/Bytecode/Encoding.h index 21edc1e..ac38269 100644 --- a/mlir/include/mlir/Bytecode/Encoding.h +++ b/mlir/include/mlir/Bytecode/Encoding.h @@ -45,8 +45,12 @@ enum BytecodeVersion { /// with the discardable attributes. kNativePropertiesEncoding = 5, + /// ODS emits operand/result segment_size as native properties instead of + /// an attribute. + kNativePropertiesODSSegmentSize = 6, + /// The current bytecode version. - kVersion = 5, + kVersion = 6, /// An arbitrary value used to fill alignment padding. kAlignmentByte = 0xCB, diff --git a/mlir/include/mlir/IR/ODSSupport.h b/mlir/include/mlir/IR/ODSSupport.h index 1d3cbbd..687f764 100644 --- a/mlir/include/mlir/IR/ODSSupport.h +++ b/mlir/include/mlir/IR/ODSSupport.h @@ -37,6 +37,13 @@ Attribute convertToAttribute(MLIRContext *ctx, int64_t storage); LogicalResult convertFromAttribute(MutableArrayRef storage, Attribute attr, InFlightDiagnostic *diag); +/// Convert a DenseI32ArrayAttr to the provided storage. It is expected that the +/// storage has the same size as the array. An error is returned if the +/// attribute isn't a DenseI32ArrayAttr or it does not have the same size. If +/// the optional diagnostic is provided an error message is also emitted. +LogicalResult convertFromAttribute(MutableArrayRef storage, + Attribute attr, InFlightDiagnostic *diag); + /// Convert the provided ArrayRef to a DenseI64ArrayAttr attribute. Attribute convertToAttribute(MLIRContext *ctx, ArrayRef storage); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 940588b..274a531 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1241,6 +1241,7 @@ class ArrayProperty : let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">"; let convertFromStorage = "$_storage"; let assignToStorage = "::llvm::copy($_value, $_storage)"; + let hashProperty = "llvm::hash_combine_range(std::begin($_storage), std::end($_storage));"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 221c607..d42bffa 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -20,6 +20,7 @@ #define MLIR_IR_OPDEFINITION_H #include "mlir/IR/Dialect.h" +#include "mlir/IR/ODSSupport.h" #include "mlir/IR/Operation.h" #include "llvm/Support/PointerLikeTypeTraits.h" diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 303e030..f3a79eb 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -555,7 +555,8 @@ public: StringRef name) final { if constexpr (hasProperties) { auto concreteOp = cast(op); - return ConcreteOp::getInherentAttr(concreteOp.getProperties(), name); + return ConcreteOp::getInherentAttr(concreteOp.getContext(), + concreteOp.getProperties(), name); } // If the op does not have support for properties, we dispatch back to the // dictionnary of discardable attributes for now. @@ -575,7 +576,8 @@ public: void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final { if constexpr (hasProperties) { auto concreteOp = cast(op); - ConcreteOp::populateInherentAttrs(concreteOp.getProperties(), attrs); + ConcreteOp::populateInherentAttrs(concreteOp.getContext(), + concreteOp.getProperties(), attrs); } } LogicalResult diff --git a/mlir/include/mlir/TableGen/Property.h b/mlir/include/mlir/TableGen/Property.h index 597543d..d0d6f49 100644 --- a/mlir/include/mlir/TableGen/Property.h +++ b/mlir/include/mlir/TableGen/Property.h @@ -35,51 +35,76 @@ class Property { public: explicit Property(const llvm::Record *record); explicit Property(const llvm::DefInit *init); + Property(StringRef storageType, StringRef interfaceType, + StringRef convertFromStorageCall, StringRef assignToStorageCall, + StringRef convertToAttributeCall, StringRef convertFromAttributeCall, + StringRef readFromMlirBytecodeCall, + StringRef writeToMlirBytecodeCall, StringRef hashPropertyCall, + StringRef defaultValue); // Returns the storage type. - StringRef getStorageType() const; + StringRef getStorageType() const { return storageType; } // Returns the interface type for this property. - StringRef getInterfaceType() const; + StringRef getInterfaceType() const { return interfaceType; } // Returns the template getter method call which reads this property's // storage and returns the value as of the desired return type. - StringRef getConvertFromStorageCall() const; + StringRef getConvertFromStorageCall() const { return convertFromStorageCall; } // Returns the template setter method call which reads this property's // in the provided interface type and assign it to the storage. - StringRef getAssignToStorageCall() const; + StringRef getAssignToStorageCall() const { return assignToStorageCall; } // Returns the conversion method call which reads this property's // in the storage type and builds an attribute. - StringRef getConvertToAttributeCall() const; + StringRef getConvertToAttributeCall() const { return convertToAttributeCall; } // Returns the setter method call which reads this property's // in the provided interface type and assign it to the storage. - StringRef getConvertFromAttributeCall() const; + StringRef getConvertFromAttributeCall() const { + return convertFromAttributeCall; + } // Returns the method call which reads this property from // bytecode and assign it to the storage. - StringRef getReadFromMlirBytecodeCall() const; + StringRef getReadFromMlirBytecodeCall() const { + return readFromMlirBytecodeCall; + } // Returns the method call which write this property's // to the the bytecode. - StringRef getWriteToMlirBytecodeCall() const; + StringRef getWriteToMlirBytecodeCall() const { + return writeToMlirBytecodeCall; + } // Returns the code to compute the hash for this property. - StringRef getHashPropertyCall() const; + StringRef getHashPropertyCall() const { return hashPropertyCall; } // Returns whether this Property has a default value. - bool hasDefaultValue() const; + bool hasDefaultValue() const { return !defaultValue.empty(); } + // Returns the default value for this Property. - StringRef getDefaultValue() const; + StringRef getDefaultValue() const { return defaultValue; } // Returns the TableGen definition this Property was constructed from. - const llvm::Record &getDef() const; + const llvm::Record &getDef() const { return *def; } private: // The TableGen definition of this constraint. const llvm::Record *def; + + // Elements describing a Property, in general fetched from the record. + StringRef storageType; + StringRef interfaceType; + StringRef convertFromStorageCall; + StringRef assignToStorageCall; + StringRef convertToAttributeCall; + StringRef convertFromAttributeCall; + StringRef readFromMlirBytecodeCall; + StringRef writeToMlirBytecodeCall; + StringRef hashPropertyCall; + StringRef defaultValue; }; // A struct wrapping an op property and its name together diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 8269546..0639baf 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -796,9 +796,10 @@ class AttrTypeReader { public: AttrTypeReader(StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, Location fileLoc) + ResourceSectionReader &resourceReader, Location fileLoc, + uint64_t &bytecodeVersion) : stringReader(stringReader), resourceReader(resourceReader), - fileLoc(fileLoc) {} + fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {} /// Initialize the attribute and type information within the reader. LogicalResult initialize(MutableArrayRef dialects, @@ -883,23 +884,30 @@ private: /// A location used for error emission. Location fileLoc; + + /// Current bytecode version being used. + uint64_t &bytecodeVersion; }; class DialectReader : public DialectBytecodeReader { public: DialectReader(AttrTypeReader &attrTypeReader, StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, EncodingReader &reader) + ResourceSectionReader &resourceReader, EncodingReader &reader, + uint64_t &bytecodeVersion) : attrTypeReader(attrTypeReader), stringReader(stringReader), - resourceReader(resourceReader), reader(reader) {} + resourceReader(resourceReader), reader(reader), + bytecodeVersion(bytecodeVersion) {} InFlightDiagnostic emitError(const Twine &msg) override { return reader.emitError(msg); } + uint64_t getBytecodeVersion() const override { return bytecodeVersion; } + DialectReader withEncodingReader(EncodingReader &encReader) { return DialectReader(attrTypeReader, stringReader, resourceReader, - encReader); + encReader, bytecodeVersion); } Location getLoc() const { return reader.getLoc(); } @@ -1003,6 +1011,7 @@ private: StringSectionReader &stringReader; ResourceSectionReader &resourceReader; EncodingReader &reader; + uint64_t &bytecodeVersion; }; /// Wraps the properties section and handles reading properties out of it. @@ -1207,7 +1216,8 @@ template LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, EncodingReader &reader, StringRef entryType) { - DialectReader dialectReader(*this, stringReader, resourceReader, reader); + DialectReader dialectReader(*this, stringReader, resourceReader, reader, + bytecodeVersion); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); // Ensure that the dialect implements the bytecode interface. @@ -1252,7 +1262,7 @@ public: llvm::MemoryBufferRef buffer, const std::shared_ptr &bufferOwnerRef) : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), - attrTypeReader(stringReader, resourceReader, fileLoc), + attrTypeReader(stringReader, resourceReader, fileLoc, version), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), @@ -1782,7 +1792,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader, if (!opName->opName) { // Load the dialect and its version. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + reader, version); if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); // If the opName is empty, this is because we use to accept names such as @@ -1825,7 +1835,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection( // Initialize the resource reader with the resource sections. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + reader, version); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, dialectReader, bufferOwnerRef); @@ -2186,7 +2196,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, // interface and control the serialization. if (wasRegistered) { DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + reader, version); if (failed( propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) return failure(); diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp index 2547d81..284b3c0 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -48,7 +48,7 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter { void writeOwnedBool(bool value) override {} int64_t getBytecodeVersion() const override { - llvm_unreachable("unexpected querying of version in IRNumbering"); + return state.getDesiredBytecodeVersion(); } /// The parent numbering state that is populated by this writer. @@ -391,6 +391,10 @@ void IRNumberingState::number(Dialect *dialect, } } +int64_t IRNumberingState::getDesiredBytecodeVersion() const { + return config.getDesiredBytecodeVersion(); +} + namespace { /// A dummy resource builder used to number dialect resources. struct NumberingResourceBuilder : public AsmResourceBuilder { diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h index c10e09a..ca30078 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.h +++ b/mlir/lib/Bytecode/Writer/IRNumbering.h @@ -186,6 +186,9 @@ public: return blockOperationCounts[block]; } + /// Get the set desired bytecode version to emit. + int64_t getDesiredBytecodeVersion() const; + private: /// This class is used to provide a fake dialect writer for numbering nested /// attributes and types. diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp index ffb84f0..f67e7db 100644 --- a/mlir/lib/IR/ODSSupport.cpp +++ b/mlir/lib/IR/ODSSupport.cpp @@ -33,24 +33,40 @@ LogicalResult mlir::convertFromAttribute(int64_t &storage, Attribute mlir::convertToAttribute(MLIRContext *ctx, int64_t storage) { return IntegerAttr::get(IntegerType::get(ctx, 64), storage); } -LogicalResult mlir::convertFromAttribute(MutableArrayRef storage, - ::mlir::Attribute attr, - ::mlir::InFlightDiagnostic *diag) { - auto valueAttr = dyn_cast(attr); + +template +LogicalResult convertDenseArrayFromAttr(MutableArrayRef storage, + ::mlir::Attribute attr, + ::mlir::InFlightDiagnostic *diag, + StringRef denseArrayTyStr) { + auto valueAttr = dyn_cast(attr); if (!valueAttr) { if (diag) - *diag << "expected DenseI64ArrayAttr for key `value`"; + *diag << "expected " << denseArrayTyStr << " for key `value`"; return failure(); } if (valueAttr.size() != static_cast(storage.size())) { if (diag) - *diag << "Size mismatch in attribute conversion: " << valueAttr.size() + *diag << "size mismatch in attribute conversion: " << valueAttr.size() << " vs " << storage.size(); return failure(); } llvm::copy(valueAttr.asArrayRef(), storage.begin()); return success(); } +LogicalResult mlir::convertFromAttribute(MutableArrayRef storage, + ::mlir::Attribute attr, + ::mlir::InFlightDiagnostic *diag) { + return convertDenseArrayFromAttr(storage, attr, diag, + "DenseI64ArrayAttr"); +} +LogicalResult mlir::convertFromAttribute(MutableArrayRef storage, + Attribute attr, + InFlightDiagnostic *diag) { + return convertDenseArrayFromAttr(storage, attr, diag, + "DenseI32ArrayAttr"); +} + Attribute mlir::convertToAttribute(MLIRContext *ctx, ArrayRef storage) { return DenseI64ArrayAttr::get(ctx, storage); diff --git a/mlir/lib/TableGen/Property.cpp b/mlir/lib/TableGen/Property.cpp index b0bea43..e61d2fd 100644 --- a/mlir/lib/TableGen/Property.cpp +++ b/mlir/lib/TableGen/Property.cpp @@ -32,65 +32,40 @@ static StringRef getValueAsString(const Init *init) { return {}; } -Property::Property(const Record *record) : def(record) { - assert((record->isSubClassOf("Property") || record->isSubClassOf("Attr")) && +Property::Property(const Record *def) + : Property(getValueAsString(def->getValueInit("storageType")), + getValueAsString(def->getValueInit("interfaceType")), + getValueAsString(def->getValueInit("convertFromStorage")), + getValueAsString(def->getValueInit("assignToStorage")), + getValueAsString(def->getValueInit("convertToAttribute")), + getValueAsString(def->getValueInit("convertFromAttribute")), + getValueAsString(def->getValueInit("readFromMlirBytecode")), + getValueAsString(def->getValueInit("writeToMlirBytecode")), + getValueAsString(def->getValueInit("hashProperty")), + getValueAsString(def->getValueInit("defaultValue"))) { + this->def = def; + assert((def->isSubClassOf("Property") || def->isSubClassOf("Attr")) && "must be subclass of TableGen 'Property' class"); } Property::Property(const DefInit *init) : Property(init->getDef()) {} -StringRef Property::getStorageType() const { - const auto *init = def->getValueInit("storageType"); - auto type = getValueAsString(init); - if (type.empty()) - return "Property"; - return type; +Property::Property(StringRef storageType, StringRef interfaceType, + StringRef convertFromStorageCall, + StringRef assignToStorageCall, + StringRef convertToAttributeCall, + StringRef convertFromAttributeCall, + StringRef readFromMlirBytecodeCall, + StringRef writeToMlirBytecodeCall, + StringRef hashPropertyCall, StringRef defaultValue) + : storageType(storageType), interfaceType(interfaceType), + convertFromStorageCall(convertFromStorageCall), + assignToStorageCall(assignToStorageCall), + convertToAttributeCall(convertToAttributeCall), + convertFromAttributeCall(convertFromAttributeCall), + readFromMlirBytecodeCall(readFromMlirBytecodeCall), + writeToMlirBytecodeCall(writeToMlirBytecodeCall), + hashPropertyCall(hashPropertyCall), defaultValue(defaultValue) { + if (storageType.empty()) + storageType = "Property"; } - -StringRef Property::getInterfaceType() const { - const auto *init = def->getValueInit("interfaceType"); - return getValueAsString(init); -} - -StringRef Property::getConvertFromStorageCall() const { - const auto *init = def->getValueInit("convertFromStorage"); - return getValueAsString(init); -} - -StringRef Property::getAssignToStorageCall() const { - const auto *init = def->getValueInit("assignToStorage"); - return getValueAsString(init); -} - -StringRef Property::getConvertToAttributeCall() const { - const auto *init = def->getValueInit("convertToAttribute"); - return getValueAsString(init); -} - -StringRef Property::getConvertFromAttributeCall() const { - const auto *init = def->getValueInit("convertFromAttribute"); - return getValueAsString(init); -} - -StringRef Property::getReadFromMlirBytecodeCall() const { - const auto *init = def->getValueInit("readFromMlirBytecode"); - return getValueAsString(init); -} - -StringRef Property::getWriteToMlirBytecodeCall() const { - const auto *init = def->getValueInit("writeToMlirBytecode"); - return getValueAsString(init); -} - -StringRef Property::getHashPropertyCall() const { - return getValueAsString(def->getValueInit("hashProperty")); -} - -bool Property::hasDefaultValue() const { return !getDefaultValue().empty(); } - -StringRef Property::getDefaultValue() const { - const auto *init = def->getValueInit("defaultValue"); - return getValueAsString(init); -} - -const llvm::Record &Property::getDef() const { return *def; } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 09bbc5a..14141c4 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -887,7 +887,7 @@ func.func @switch_wrong_number_of_weights(%arg0 : i32) { func.func @switch_case_type_mismatch(%arg0 : i64) { // expected-error@below {{expects case value type to match condition value type}} - "llvm.switch"(%arg0)[^bb1, ^bb2] <{case_operand_segments = array, case_values = dense<42> : vector<1xi32>, operand_segment_sizes = array}> : (i64) -> () + "llvm.switch"(%arg0)[^bb1, ^bb2] <{case_operand_segments = array, case_values = dense<42> : vector<1xi32>, odsOperandSegmentSizes = array}> : (i64) -> () ^bb1: // pred: ^bb0 llvm.return ^bb2: // pred: ^bb0 diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index 307918d..7d922ec 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -383,21 +383,21 @@ func.func private @foo() // ----- func.func @failedMissingOperandSizeAttr(%arg: i32) { - // expected-error @+1 {{requires dense i32 array attribute 'operand_segment_sizes'}} + // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) : (i32, i32, i32, i32) -> () } // ----- func.func @failedOperandSizeAttrWrongType(%arg: i32) { - // expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}} + // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = 10} : (i32, i32, i32, i32) -> () } // ----- func.func @failedOperandSizeAttrWrongElementType(%arg: i32) { - // expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}} + // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = array} : (i32, i32, i32, i32) -> () } @@ -418,7 +418,7 @@ func.func @failedOperandSizeAttrWrongTotalSize(%arg: i32) { // ----- func.func @failedOperandSizeAttrWrongCount(%arg: i32) { - // expected-error @+1 {{'operand_segment_sizes' attribute for specifying operand segments must have 4 elements}} + // expected-error @+1 {{test.attr_sized_operands' op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = array} : (i32, i32, i32, i32) -> () } @@ -433,14 +433,14 @@ func.func @succeededOperandSizeAttr(%arg: i32) { // ----- func.func @failedMissingResultSizeAttr() { - // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}} + // expected-error @+1 {{op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() : () -> (i32, i32, i32, i32) } // ----- func.func @failedResultSizeAttrWrongType() { - // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}} + // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = 10} : () -> (i32, i32, i32, i32) } @@ -448,7 +448,7 @@ func.func @failedResultSizeAttrWrongType() { // ----- func.func @failedResultSizeAttrWrongElementType() { - // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}} + // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = array} : () -> (i32, i32, i32, i32) } @@ -469,7 +469,7 @@ func.func @failedResultSizeAttrWrongTotalSize() { // ----- func.func @failedResultSizeAttrWrongCount() { - // expected-error @+1 {{'result_segment_sizes' attribute for specifying result segments must have 4 elements, but got 3}} + // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = array} : () -> (i32, i32, i32, i32) } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 8056f6d..966896b 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -839,8 +839,7 @@ def AttrSizedOperandOp : TEST_Op<"attr_sized_operands", Variadic:$a, Variadic:$b, I32:$c, - Variadic:$d, - DenseI32ArrayAttr:$operand_segment_sizes + Variadic:$d ); } diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 5e8414a..88f48d0 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -100,7 +100,7 @@ def testNamedStructuredOpGenericForm(): init_result = tensor.EmptyOp([4, 8], f32) # CHECK: "linalg.matmul"(%{{.*}}) # CHECK-SAME: cast = #linalg.type_fn - # CHECK-SAME: operand_segment_sizes = array + # CHECK-SAME: odsOperandSegmentSizes = array # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32): # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32 diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index c069e7c..135308b 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -115,6 +115,10 @@ static const char *const adapterSegmentSizeAttrInitCode = R"( assert({0} && "missing segment size attribute for op"); auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0}); )"; +static const char *const adapterSegmentSizeAttrInitCodeProperties = R"( + ::llvm::ArrayRef sizeAttr = {0}; +)"; + /// The code snippet to initialize the sizes for the value range calculation. /// /// {0}: The code to get the attribute. @@ -150,6 +154,29 @@ static const char *const valueRangeReturnCode = R"( std::next({0}, valueRange.first + valueRange.second)}; )"; +/// Read operand/result segment_size from bytecode. +static const char *const readBytecodeSegmentSize = R"( +if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) { + DenseI32ArrayAttr attr; + if (::mlir::failed($_reader.readAttribute(attr))) return failure(); + if (attr.size() > static_cast(sizeof($_storage) / sizeof(int32_t))) { + $_reader.emitError("size mismatch for operand/result_segment_size"); + return failure(); + } + llvm::copy(ArrayRef(attr), $_storage); +} else { + return $_reader.readSparseArray(MutableArrayRef($_storage)); +} +)"; + +/// Write operand/result segment_size to bytecode. +static const char *const writeBytecodeSegmentSize = R"( +if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) + $_writer.writeAttribute(DenseI32ArrayAttr::get(getContext(), $_storage)); +else + $_writer.writeSparseArray(ArrayRef($_storage)); +)"; + /// A header for indicating code sections. /// /// {0}: Some text, or a class name. @@ -343,6 +370,9 @@ public: return true; if (!op.getDialect().usePropertiesForAttributes()) return false; + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") || + op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) + return true; return llvm::any_of(getAttrMetadata(), [](const std::pair &it) { return !it.second.constraint || @@ -350,6 +380,14 @@ public: }); } + std::optional &getOperandSegmentsSize() { + return operandSegmentsSize; + } + + std::optional &getResultSegmentsSize() { + return resultSegmentsSize; + } + private: // Compute the attribute metadata. void computeAttrMetadata(); @@ -361,6 +399,13 @@ private: // The attribute metadata, mapped by name. llvm::MapVector attrMetadata; + + // Property + std::optional operandSegmentsSize; + std::string operandSegmentsSizeStorage; + std::optional resultSegmentsSize; + std::string resultSegmentsSizeStorage; + // The number of required attributes. unsigned numRequired; }; @@ -377,18 +422,50 @@ void OpOrAdaptorHelper::computeAttrMetadata() { attrMetadata.insert( {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}}); } + + auto makeProperty = [&](StringRef storageType) { + return Property( + /*storageType=*/storageType, + /*interfaceType=*/"::llvm::ArrayRef", + /*convertFromStorageCall=*/"$_storage", + /*assignToStorageCall=*/"::llvm::copy($_value, $_storage)", + /*convertToAttributeCall=*/ + "DenseI32ArrayAttr::get($_ctxt, $_storage)", + /*convertFromAttributeCall=*/ + "return convertFromAttribute($_storage, $_attr, $_diag);", + /*readFromMlirBytecodeCall=*/readBytecodeSegmentSize, + /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSize, + /*hashPropertyCall=*/ + "llvm::hash_combine_range(std::begin($_storage), " + "std::end($_storage));", + /*StringRef defaultValue=*/""); + }; // Include key attributes from several traits as implicitly registered. if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - attrMetadata.insert( - {operandSegmentAttrName, - AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true, - /*attr=*/std::nullopt}}); + if (op.getDialect().usePropertiesForAttributes()) { + operandSegmentsSizeStorage = + llvm::formatv("int32_t[{0}]", op.getNumOperands()); + operandSegmentsSize = {"odsOperandSegmentSizes", + makeProperty(operandSegmentsSizeStorage)}; + } else { + attrMetadata.insert( + {operandSegmentAttrName, AttributeMetadata{operandSegmentAttrName, + /*isRequired=*/true, + /*attr=*/std::nullopt}}); + } } if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { - attrMetadata.insert( - {resultSegmentAttrName, - AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true, - /*attr=*/std::nullopt}}); + if (op.getDialect().usePropertiesForAttributes()) { + resultSegmentsSizeStorage = + llvm::formatv("int32_t[{0}]", op.getNumResults()); + resultSegmentsSize = {"odsResultSegmentSizes", + makeProperty(resultSegmentsSizeStorage)}; + } else { + attrMetadata.insert( + {resultSegmentAttrName, + AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true, + /*attr=*/std::nullopt}}); + } } // Store the metadata in sorted order. @@ -660,14 +737,17 @@ static void genNativeTraitAttrVerifier(MethodBody &body, // Verify a few traits first so that we can use getODSOperands() and // getODSResults() in the rest of the verifier. auto &op = emitHelper.getOp(); - if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName, - op.getNumOperands(), "operand", - emitHelper.emitErrorPrefix()); - } - if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { - body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName, - op.getNumResults(), "result", emitHelper.emitErrorPrefix()); + if (!op.getDialect().usePropertiesForAttributes()) { + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName, + op.getNumOperands(), "operand", + emitHelper.emitErrorPrefix()); + } + if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { + body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName, + op.getNumResults(), "result", + emitHelper.emitErrorPrefix()); + } } } @@ -964,14 +1044,16 @@ static void errorIfPruned(size_t line, Method *m, const Twine &methodName, void OpEmitter::genAttrNameGetters() { const llvm::MapVector &attributes = emitHelper.getAttrMetadata(); - + bool hasOperandSegmentsSize = + op.getDialect().usePropertiesForAttributes() && + op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); // Emit the getAttributeNames method. { auto *method = opClass.addStaticInlineMethod( "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames"); ERROR_IF_PRUNED(method, "getAttributeNames", op); auto &body = method->body(); - if (attributes.empty()) { + if (!hasOperandSegmentsSize && attributes.empty()) { body << " return {};"; // Nothing else to do if there are no registered attributes. Exit early. return; @@ -981,6 +1063,11 @@ void OpEmitter::genAttrNameGetters() { [&](StringRef attrName) { body << "::llvm::StringRef(\"" << attrName << "\")"; }); + if (hasOperandSegmentsSize) { + if (!attributes.empty()) + body << ", "; + body << "::llvm::StringRef(\"" << operandSegmentAttrName << "\")"; + } body << "};\n return ::llvm::ArrayRef(attrNames);"; } @@ -1033,6 +1120,26 @@ void OpEmitter::genAttrNameGetters() { "name, " + Twine(index)); } } + if (hasOperandSegmentsSize) { + std::string name = op.getGetterName(operandSegmentAttrName); + std::string methodName = name + "AttrName"; + // Generate the non-static variant. + { + auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName); + ERROR_IF_PRUNED(method, methodName, op); + method->body() + << " return (*this)->getName().getAttributeNames().back();"; + } + + // Generate the static variant. + { + auto *method = opClass.addStaticInlineMethod( + "::mlir::StringAttr", methodName, + MethodParameter("::mlir::OperationName", "name")); + ERROR_IF_PRUNED(method, methodName, op); + method->body() << " return name.getAttributeNames().back();"; + } + } } // Emit the getter for an attribute with the return type specified. @@ -1080,6 +1187,10 @@ void OpEmitter::genPropertiesSupport() { } for (const NamedProperty &prop : op.getProperties()) attrOrProperties.push_back(&prop); + if (emitHelper.getOperandSegmentsSize()) + attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value()); + if (emitHelper.getResultSegmentsSize()) + attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value()); if (attrOrProperties.empty()) return; auto &setPropMethod = @@ -1104,6 +1215,7 @@ void OpEmitter::genPropertiesSupport() { auto &getInherentAttrMethod = opClass .addStaticMethod("std::optional", "getInherentAttr", + MethodParameter("::mlir::MLIRContext *", "ctx"), MethodParameter("const Properties &", "prop"), MethodParameter("llvm::StringRef", "name")) ->body(); @@ -1117,6 +1229,7 @@ void OpEmitter::genPropertiesSupport() { auto &populateInherentAttrsMethod = opClass .addStaticMethod("void", "populateInherentAttrs", + MethodParameter("::mlir::MLIRContext *", "ctx"), MethodParameter("const Properties &", "prop"), MethodParameter("::mlir::NamedAttrList &", "attrs")) ->body(); @@ -1318,6 +1431,60 @@ void OpEmitter::genPropertiesSupport() { << formatv(populateInherentAttrsMethodFmt, name); continue; } + // The ODS segment size property is "special": we expose it as an attribute + // even though it is a native property. + const auto *namedProperty = cast(attrOrProp); + StringRef name = namedProperty->name; + if (name != "odsOperandSegmentSizes" && name != "odsResultSegmentSizes") + continue; + auto &prop = namedProperty->prop; + FmtContext fctx; + fctx.addSubst("_ctxt", "ctx"); + fctx.addSubst("_storage", Twine("prop.") + name); + if (name == "odsOperandSegmentSizes") { + getInherentAttrMethod + << formatv(" if (name == \"odsOperandSegmentSizes\" || name == " + "\"{0}\") return ", + operandSegmentAttrName); + } else { + getInherentAttrMethod + << formatv(" if (name == \"odsResultSegmentSizes\" || name == " + "\"{0}\") return ", + resultSegmentAttrName); + } + getInherentAttrMethod << tgfmt(prop.getConvertToAttributeCall(), &fctx) + << ";\n"; + + if (name == "odsOperandSegmentSizes") { + setInherentAttrMethod << formatv( + " if (name == \"odsOperandSegmentSizes\" || name == " + "\"{0}\") {{", + operandSegmentAttrName); + } else { + setInherentAttrMethod + << formatv(" if (name == \"odsResultSegmentSizes\" || name == " + "\"{0}\") {{", + resultSegmentAttrName); + } + setInherentAttrMethod << formatv(R"decl( + auto arrAttr = dyn_cast_or_null(value); + if (!arrAttr) return; + if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t)) + return; + llvm::copy(arrAttr.asArrayRef(), prop.{0}); + return; + } +)decl", + name); + if (name == "odsOperandSegmentSizes") { + populateInherentAttrsMethod + << formatv(" attrs.append(\"{0}\", {1});\n", operandSegmentAttrName, + tgfmt(prop.getConvertToAttributeCall(), &fctx)); + } else { + populateInherentAttrsMethod + << formatv(" attrs.append(\"{0}\", {1});\n", resultSegmentAttrName, + tgfmt(prop.getConvertToAttributeCall(), &fctx)); + } } getInherentAttrMethod << " return std::nullopt;\n"; @@ -1815,8 +1982,13 @@ void OpEmitter::genNamedOperandGetters() { // array. std::string attrSizeInitCode; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, - emitHelper.getAttr(operandSegmentAttrName)); + if (op.getDialect().usePropertiesForAttributes()) + attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties, + "getProperties().odsOperandSegmentSizes"); + + else + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + emitHelper.getAttr(operandSegmentAttrName)); } generateNamedOperandGetters( @@ -1851,10 +2023,11 @@ void OpEmitter::genNamedOperandSetters() { "range.first, range.second"; if (attrSizedOperands) { if (emitHelper.hasProperties()) - body << formatv( - ", ::mlir::MutableOperandRange::OperandSegment({0}u, " - "{getOperandSegmentSizesAttrName(), getProperties().{1}})", - i, operandSegmentAttrName); + body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, " + "{{getOperandSegmentSizesAttrName(), " + "DenseI32ArrayAttr::get(getContext(), " + "getProperties().odsOperandSegmentSizes)})", + i); else body << formatv( ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i, @@ -1910,8 +2083,13 @@ void OpEmitter::genNamedResultGetters() { // Build the initializer string for the result segment size attribute. std::string attrSizeInitCode; if (attrSizedResults) { - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, - emitHelper.getAttr(resultSegmentAttrName)); + if (op.getDialect().usePropertiesForAttributes()) + attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties, + "getProperties().odsResultSegmentSizes"); + + else + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + emitHelper.getAttr(resultSegmentAttrName)); } generateValueRangeStartAndEnd( @@ -2086,10 +2264,7 @@ void OpEmitter::genSeparateArgParamBuilder() { // the length of the type ranges. if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { if (op.getDialect().usePropertiesForAttributes()) { - body << " (" << builderOpState - << ".getOrAddProperties()." << resultSegmentAttrName - << " = \n" - " odsBuilder.getDenseI32ArrayAttr({"; + body << " llvm::copy(ArrayRef({"; } else { std::string getterName = op.getGetterName(resultSegmentAttrName); body << " " << builderOpState << ".addAttribute(" << getterName @@ -2112,7 +2287,12 @@ void OpEmitter::genSeparateArgParamBuilder() { body << "static_cast(" << resultNames[i] << ".size())"; } }); - body << "}));\n"; + if (op.getDialect().usePropertiesForAttributes()) { + body << "}), " << builderOpState + << ".getOrAddProperties().odsResultSegmentSizes);\n"; + } else { + body << "}));\n"; + } } return; @@ -2706,17 +2886,7 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder( } // If the operation has the operand segment size attribute, add it here. - if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - std::string sizes = op.getGetterName(operandSegmentAttrName); - if (op.getDialect().usePropertiesForAttributes()) { - body << " (" << builderOpState << ".getOrAddProperties()." - << operandSegmentAttrName << "= " - << "odsBuilder.getDenseI32ArrayAttr({"; - } else { - body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" - << builderOpState << ".name), " - << "odsBuilder.getDenseI32ArrayAttr({"; - } + auto emitSegment = [&]() { interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { const NamedTypeConstraint &operand = op.getOperand(i); if (!operand.isVariableLength()) { @@ -2737,7 +2907,21 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder( body << "static_cast(" << getArgumentName(op, i) << ".size())"; } }); - body << "}));\n"; + }; + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + std::string sizes = op.getGetterName(operandSegmentAttrName); + if (op.getDialect().usePropertiesForAttributes()) { + body << " llvm::copy(ArrayRef({"; + emitSegment(); + body << "}), " << builderOpState + << ".getOrAddProperties().odsOperandSegmentSizes);\n"; + } else { + body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" + << builderOpState << ".name), " + << "odsBuilder.getDenseI32ArrayAttr({"; + emitSegment(); + body << "}));\n"; + } } // Push all attributes to the result. @@ -3541,6 +3725,10 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( } for (const NamedProperty &prop : op.getProperties()) attrOrProperties.push_back(&prop); + if (emitHelper.getOperandSegmentsSize()) + attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value()); + if (emitHelper.getResultSegmentsSize()) + attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value()); assert(!attrOrProperties.empty()); std::string declarations = " struct Properties {\n"; llvm::raw_string_ostream os(declarations); @@ -3598,7 +3786,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( if (attr) { storageType = attr->getStorageType(); } else { - if (name != operandSegmentAttrName && name != resultSegmentAttrName) { + if (name != "odsOperandSegmentSizes" && + name != "odsResultSegmentSizes") { report_fatal_error("unexpected AttributeMetadata"); } // TODO: update to use native integers. @@ -3710,8 +3899,13 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( std::string sizeAttrInit; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, - emitHelper.getAttr(operandSegmentAttrName)); + if (op.getDialect().usePropertiesForAttributes()) + sizeAttrInit = + formatv(adapterSegmentSizeAttrInitCodeProperties, + llvm::formatv("getProperties().odsOperandSegmentSizes")); + else + sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, + emitHelper.getAttr(operandSegmentAttrName)); } generateNamedOperandGetters(op, genericAdaptor, /*genericAdaptorBase=*/&genericAdaptorBase, diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 9a5e4d5..1e13179 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1654,16 +1654,6 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op, MethodBody &body) { if (!allOperands) { if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - if (op.getDialect().usePropertiesForAttributes()) { - body << formatv(" " - "result.getOrAddProperties<{0}::Properties>().operand_" - "segment_sizes = " - "(parser.getBuilder().getDenseI32ArrayAttr({{", - op.getCppClassName()); - } else { - body << " result.addAttribute(\"operand_segment_sizes\", " - << "parser.getBuilder().getDenseI32ArrayAttr({"; - } auto interleaveFn = [&](const NamedTypeConstraint &operand) { // If the operand is variadic emit the parsed size. if (operand.isVariableLength()) @@ -1671,8 +1661,19 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op, else body << "1"; }; - llvm::interleaveComma(op.getOperands(), body, interleaveFn); - body << "}));\n"; + if (op.getDialect().usePropertiesForAttributes()) { + body << "llvm::copy(ArrayRef({"; + llvm::interleaveComma(op.getOperands(), body, interleaveFn); + body << formatv("}), " + "result.getOrAddProperties<{0}::Properties>()." + "odsOperandSegmentSizes);\n", + op.getCppClassName()); + } else { + body << " result.addAttribute(\"operand_segment_sizes\", " + << "parser.getBuilder().getDenseI32ArrayAttr({"; + llvm::interleaveComma(op.getOperands(), body, interleaveFn); + body << "}));\n"; + } } for (const NamedTypeConstraint &operand : op.getOperands()) { if (!operand.isVariadicOfVariadic()) @@ -1697,16 +1698,6 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op, if (!allResultTypes && op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { - if (op.getDialect().usePropertiesForAttributes()) { - body << formatv( - " " - "result.getOrAddProperties<{0}::Properties>().result_segment_sizes = " - "(parser.getBuilder().getDenseI32ArrayAttr({{", - op.getCppClassName()); - } else { - body << " result.addAttribute(\"result_segment_sizes\", " - << "parser.getBuilder().getDenseI32ArrayAttr({"; - } auto interleaveFn = [&](const NamedTypeConstraint &result) { // If the result is variadic emit the parsed size. if (result.isVariableLength()) @@ -1714,8 +1705,20 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op, else body << "1"; }; - llvm::interleaveComma(op.getResults(), body, interleaveFn); - body << "}));\n"; + if (op.getDialect().usePropertiesForAttributes()) { + body << "llvm::copy(ArrayRef({"; + llvm::interleaveComma(op.getResults(), body, interleaveFn); + body << formatv( + "}), " + "result.getOrAddProperties<{0}::Properties>().odsResultSegmentSizes" + ");\n", + op.getCppClassName()); + } else { + body << " result.addAttribute(\"odsResultSegmentSizes\", " + << "parser.getBuilder().getDenseI32ArrayAttr({"; + llvm::interleaveComma(op.getResults(), body, interleaveFn); + body << "}));\n"; + } } } diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp index fcd19ad..e4dc4cd 100644 --- a/mlir/unittests/IR/AdaptorTest.cpp +++ b/mlir/unittests/IR/AdaptorTest.cpp @@ -39,7 +39,7 @@ TEST(Adaptor, GenericAdaptorsOperandAccess) { // value from the value 0. SmallVector> v = {0, 4}; OIListSimple::Properties prop; - prop.operand_segment_sizes = builder.getDenseI32ArrayAttr({1, 0, 1}); + llvm::copy(ArrayRef{1, 0, 1}, prop.odsOperandSegmentSizes); OIListSimple::GenericAdaptor>> d(v, {}, prop, {}); EXPECT_EQ(d.getArg0(), 0); diff --git a/mlir/unittests/IR/OpPropertiesTest.cpp b/mlir/unittests/IR/OpPropertiesTest.cpp index eda84a0..21ea448 100644 --- a/mlir/unittests/IR/OpPropertiesTest.cpp +++ b/mlir/unittests/IR/OpPropertiesTest.cpp @@ -115,13 +115,15 @@ public: // This alias is the only definition needed for enabling "properties" for this // operation. using Properties = TestProperties; - static std::optional getInherentAttr(const Properties &prop, + static std::optional getInherentAttr(MLIRContext *context, + const Properties &prop, StringRef name) { return std::nullopt; } static void setInherentAttr(Properties &prop, StringRef name, mlir::Attribute value) {} - static void populateInherentAttrs(const Properties &prop, + static void populateInherentAttrs(MLIRContext *context, + const Properties &prop, NamedAttrList &attrs) {} static LogicalResult verifyInherentAttrs(OperationName opName, NamedAttrList &attrs, -- 2.7.4