From a7cd64c9f12b25da7adb36a7e3be88a7ad89d649 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 24 Jul 2023 12:27:01 -0700 Subject: [PATCH] Revert "Update ODS variadic segments "magic" attributes to use native Properties" This reverts commit 20b93abca6516bbb23689c3777536fea04e46e14. One python test is broken, WIP. --- .../include/mlir/Bytecode/BytecodeImplementation.h | 123 --------- mlir/include/mlir/Bytecode/Encoding.h | 6 +- mlir/include/mlir/IR/ODSSupport.h | 7 - mlir/include/mlir/IR/OpBase.td | 1 - mlir/include/mlir/IR/OpDefinition.h | 1 - mlir/include/mlir/IR/OperationSupport.h | 6 +- mlir/include/mlir/TableGen/Property.h | 49 +--- mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 30 +-- mlir/lib/Bytecode/Writer/IRNumbering.cpp | 6 +- mlir/lib/Bytecode/Writer/IRNumbering.h | 3 - mlir/lib/IR/ODSSupport.cpp | 28 +- mlir/lib/TableGen/Property.cpp | 87 ++++--- mlir/test/Dialect/LLVMIR/invalid.mlir | 2 +- mlir/test/IR/traits.mlir | 16 +- mlir/test/lib/Dialect/Test/TestOps.td | 3 +- mlir/test/python/dialects/linalg/ops.py | 2 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 286 ++++----------------- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 51 ++-- mlir/unittests/IR/AdaptorTest.cpp | 2 +- mlir/unittests/IR/OpPropertiesTest.cpp | 6 +- 20 files changed, 173 insertions(+), 542 deletions(-) diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h index 9c9aa7a..4e74c12 100644 --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -20,7 +20,6 @@ #include "mlir/IR/DialectInterface.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Support/LogicalResult.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" namespace mlir { @@ -40,9 +39,6 @@ public: /// Emit an error to the reader. virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0; - /// Return the bytecode version being read. - virtual uint64_t getBytecodeVersion() const = 0; - /// Read out a list of elements, invoking the provided callback for each /// element. The callback function may be in any of the following forms: /// * LogicalResult(T &) @@ -152,76 +148,6 @@ public: [this](int64_t &value) { return readSignedVarInt(value); }); } - /// Parse a variable length encoded integer whose low bit is used to encode an - /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. - LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag) { - if (failed(readVarInt(result))) - return failure(); - flag = result & 1; - result >>= 1; - return success(); - } - - /// Read a "small" sparse array of integer <= 32 bits elements, where - /// index/value pairs can be compressed when the array is small. - /// Note that only some position of the array will be read and the ones - /// not stored in the bytecode are gonne be left untouched. - /// If the provided array is too small for the stored indices, an error - /// will be returned. - template - LogicalResult readSparseArray(MutableArrayRef array) { - static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits"); - static_assert(std::is_integral::value, "expects integer"); - uint64_t nonZeroesCount; - bool useSparseEncoding; - if (failed(readVarIntWithFlag(nonZeroesCount, useSparseEncoding))) - return failure(); - if (nonZeroesCount == 0) - return success(); - if (!useSparseEncoding) { - // This is a simple dense array. - if (nonZeroesCount > array.size()) { - emitError("trying to read an array of ") - << nonZeroesCount << " but only " << array.size() - << " storage available."; - return failure(); - } - for (int64_t index : llvm::seq(0, nonZeroesCount)) { - uint64_t value; - if (failed(readVarInt(value))) - return failure(); - array[index] = value; - } - return success(); - } - // Read sparse encoding - // This is the number of bits used for packing the index with the value. - uint64_t indexBitSize; - if (failed(readVarInt(indexBitSize))) - return failure(); - constexpr uint64_t maxIndexBitSize = 8; - if (indexBitSize > maxIndexBitSize) { - emitError("reading sparse array with indexing above 8 bits: ") - << indexBitSize; - return failure(); - } - for (uint32_t count : llvm::seq(0, nonZeroesCount)) { - (void)count; - uint64_t indexValuePair; - if (failed(readVarInt(indexValuePair))) - return failure(); - uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize)); - uint64_t value = indexValuePair >> indexBitSize; - if (index >= array.size()) { - emitError("reading a sparse array found index ") - << index << " but only " << array.size() << " storage available."; - return failure(); - } - array[index] = value; - } - return success(); - } - /// Read an APInt that is known to have been encoded with the given width. virtual FailureOr readAPIntWithKnownWidth(unsigned bitWidth) = 0; @@ -304,55 +230,6 @@ public: writeList(value, [this](int64_t value) { writeSignedVarInt(value); }); } - /// Write a VarInt and a flag packed together. - void writeVarIntWithFlag(uint64_t value, bool flag) { - writeVarInt((value << 1) | (flag ? 1 : 0)); - } - - /// Write out a "small" sparse array of integer <= 32 bits elements, where - /// index/value pairs can be compressed when the array is small. This method - /// will scan the array multiple times and should not be used for large - /// arrays. The optional provided "zero" can be used to adjust for the - /// expected repeated value. We assume here that the array size fits in a 32 - /// bits integer. - template - void writeSparseArray(ArrayRef array) { - static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits"); - static_assert(std::is_integral::value, "expects integer"); - uint32_t size = array.size(); - uint32_t nonZeroesCount = 0, lastIndex = 0; - for (uint32_t index : llvm::seq(0, size)) { - if (!array[index]) - continue; - nonZeroesCount++; - lastIndex = index; - } - // If the last position is too large, or the array isn't at least 50% - // sparse, emit it with a dense encoding. - if (lastIndex > 256 || nonZeroesCount > size / 2) { - // Emit the array size and a flag which indicates whether it is sparse. - writeVarIntWithFlag(size, false); - for (const T &elt : array) - writeVarInt(elt); - return; - } - // Emit sparse: first the number of elements we'll write and a flag - // indicating it is a sparse encoding. - writeVarIntWithFlag(nonZeroesCount, true); - if (nonZeroesCount == 0) - return; - // This is the number of bits used for packing the index with the value. - int indexBitSize = llvm::Log2_32_Ceil(lastIndex + 1); - writeVarInt(indexBitSize); - for (uint32_t index : llvm::seq(0, lastIndex + 1)) { - T value = array[index]; - if (!value) - continue; - uint64_t indexValuePair = (value << indexBitSize) | (index); - writeVarInt(indexValuePair); - } - } - /// Write an APInt to the bytecode stream whose bitwidth will be known /// externally at read time. This method is useful for encoding APInt values /// when the width is known via external means, such as via a type. This diff --git a/mlir/include/mlir/Bytecode/Encoding.h b/mlir/include/mlir/Bytecode/Encoding.h index ac38269..21edc1e 100644 --- a/mlir/include/mlir/Bytecode/Encoding.h +++ b/mlir/include/mlir/Bytecode/Encoding.h @@ -45,12 +45,8 @@ enum BytecodeVersion { /// with the discardable attributes. kNativePropertiesEncoding = 5, - /// ODS emits operand/result segment_size as native properties instead of - /// an attribute. - kNativePropertiesODSSegmentSize = 6, - /// The current bytecode version. - kVersion = 6, + kVersion = 5, /// An arbitrary value used to fill alignment padding. kAlignmentByte = 0xCB, diff --git a/mlir/include/mlir/IR/ODSSupport.h b/mlir/include/mlir/IR/ODSSupport.h index 687f764..1d3cbbd 100644 --- a/mlir/include/mlir/IR/ODSSupport.h +++ b/mlir/include/mlir/IR/ODSSupport.h @@ -37,13 +37,6 @@ Attribute convertToAttribute(MLIRContext *ctx, int64_t storage); LogicalResult convertFromAttribute(MutableArrayRef storage, Attribute attr, InFlightDiagnostic *diag); -/// Convert a DenseI32ArrayAttr to the provided storage. It is expected that the -/// storage has the same size as the array. An error is returned if the -/// attribute isn't a DenseI32ArrayAttr or it does not have the same size. If -/// the optional diagnostic is provided an error message is also emitted. -LogicalResult convertFromAttribute(MutableArrayRef storage, - Attribute attr, InFlightDiagnostic *diag); - /// Convert the provided ArrayRef to a DenseI64ArrayAttr attribute. Attribute convertToAttribute(MLIRContext *ctx, ArrayRef storage); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 274a531..940588b 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1241,7 +1241,6 @@ class ArrayProperty : let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">"; let convertFromStorage = "$_storage"; let assignToStorage = "::llvm::copy($_value, $_storage)"; - let hashProperty = "llvm::hash_combine_range(std::begin($_storage), std::end($_storage));"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index d42bffa..221c607 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -20,7 +20,6 @@ #define MLIR_IR_OPDEFINITION_H #include "mlir/IR/Dialect.h" -#include "mlir/IR/ODSSupport.h" #include "mlir/IR/Operation.h" #include "llvm/Support/PointerLikeTypeTraits.h" diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index f3a79eb..303e030 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -555,8 +555,7 @@ public: StringRef name) final { if constexpr (hasProperties) { auto concreteOp = cast(op); - return ConcreteOp::getInherentAttr(concreteOp.getContext(), - concreteOp.getProperties(), name); + return ConcreteOp::getInherentAttr(concreteOp.getProperties(), name); } // If the op does not have support for properties, we dispatch back to the // dictionnary of discardable attributes for now. @@ -576,8 +575,7 @@ public: void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final { if constexpr (hasProperties) { auto concreteOp = cast(op); - ConcreteOp::populateInherentAttrs(concreteOp.getContext(), - concreteOp.getProperties(), attrs); + ConcreteOp::populateInherentAttrs(concreteOp.getProperties(), attrs); } } LogicalResult diff --git a/mlir/include/mlir/TableGen/Property.h b/mlir/include/mlir/TableGen/Property.h index d0d6f49..597543d 100644 --- a/mlir/include/mlir/TableGen/Property.h +++ b/mlir/include/mlir/TableGen/Property.h @@ -35,76 +35,51 @@ class Property { public: explicit Property(const llvm::Record *record); explicit Property(const llvm::DefInit *init); - Property(StringRef storageType, StringRef interfaceType, - StringRef convertFromStorageCall, StringRef assignToStorageCall, - StringRef convertToAttributeCall, StringRef convertFromAttributeCall, - StringRef readFromMlirBytecodeCall, - StringRef writeToMlirBytecodeCall, StringRef hashPropertyCall, - StringRef defaultValue); // Returns the storage type. - StringRef getStorageType() const { return storageType; } + StringRef getStorageType() const; // Returns the interface type for this property. - StringRef getInterfaceType() const { return interfaceType; } + StringRef getInterfaceType() const; // Returns the template getter method call which reads this property's // storage and returns the value as of the desired return type. - StringRef getConvertFromStorageCall() const { return convertFromStorageCall; } + StringRef getConvertFromStorageCall() const; // Returns the template setter method call which reads this property's // in the provided interface type and assign it to the storage. - StringRef getAssignToStorageCall() const { return assignToStorageCall; } + StringRef getAssignToStorageCall() const; // Returns the conversion method call which reads this property's // in the storage type and builds an attribute. - StringRef getConvertToAttributeCall() const { return convertToAttributeCall; } + StringRef getConvertToAttributeCall() const; // Returns the setter method call which reads this property's // in the provided interface type and assign it to the storage. - StringRef getConvertFromAttributeCall() const { - return convertFromAttributeCall; - } + StringRef getConvertFromAttributeCall() const; // Returns the method call which reads this property from // bytecode and assign it to the storage. - StringRef getReadFromMlirBytecodeCall() const { - return readFromMlirBytecodeCall; - } + StringRef getReadFromMlirBytecodeCall() const; // Returns the method call which write this property's // to the the bytecode. - StringRef getWriteToMlirBytecodeCall() const { - return writeToMlirBytecodeCall; - } + StringRef getWriteToMlirBytecodeCall() const; // Returns the code to compute the hash for this property. - StringRef getHashPropertyCall() const { return hashPropertyCall; } + StringRef getHashPropertyCall() const; // Returns whether this Property has a default value. - bool hasDefaultValue() const { return !defaultValue.empty(); } - + bool hasDefaultValue() const; // Returns the default value for this Property. - StringRef getDefaultValue() const { return defaultValue; } + StringRef getDefaultValue() const; // Returns the TableGen definition this Property was constructed from. - const llvm::Record &getDef() const { return *def; } + const llvm::Record &getDef() const; private: // The TableGen definition of this constraint. const llvm::Record *def; - - // Elements describing a Property, in general fetched from the record. - StringRef storageType; - StringRef interfaceType; - StringRef convertFromStorageCall; - StringRef assignToStorageCall; - StringRef convertToAttributeCall; - StringRef convertFromAttributeCall; - StringRef readFromMlirBytecodeCall; - StringRef writeToMlirBytecodeCall; - StringRef hashPropertyCall; - StringRef defaultValue; }; // A struct wrapping an op property and its name together diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 0639baf..8269546 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -796,10 +796,9 @@ class AttrTypeReader { public: AttrTypeReader(StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, Location fileLoc, - uint64_t &bytecodeVersion) + ResourceSectionReader &resourceReader, Location fileLoc) : stringReader(stringReader), resourceReader(resourceReader), - fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {} + fileLoc(fileLoc) {} /// Initialize the attribute and type information within the reader. LogicalResult initialize(MutableArrayRef dialects, @@ -884,30 +883,23 @@ private: /// A location used for error emission. Location fileLoc; - - /// Current bytecode version being used. - uint64_t &bytecodeVersion; }; class DialectReader : public DialectBytecodeReader { public: DialectReader(AttrTypeReader &attrTypeReader, StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, EncodingReader &reader, - uint64_t &bytecodeVersion) + ResourceSectionReader &resourceReader, EncodingReader &reader) : attrTypeReader(attrTypeReader), stringReader(stringReader), - resourceReader(resourceReader), reader(reader), - bytecodeVersion(bytecodeVersion) {} + resourceReader(resourceReader), reader(reader) {} InFlightDiagnostic emitError(const Twine &msg) override { return reader.emitError(msg); } - uint64_t getBytecodeVersion() const override { return bytecodeVersion; } - DialectReader withEncodingReader(EncodingReader &encReader) { return DialectReader(attrTypeReader, stringReader, resourceReader, - encReader, bytecodeVersion); + encReader); } Location getLoc() const { return reader.getLoc(); } @@ -1011,7 +1003,6 @@ private: StringSectionReader &stringReader; ResourceSectionReader &resourceReader; EncodingReader &reader; - uint64_t &bytecodeVersion; }; /// Wraps the properties section and handles reading properties out of it. @@ -1216,8 +1207,7 @@ template LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, EncodingReader &reader, StringRef entryType) { - DialectReader dialectReader(*this, stringReader, resourceReader, reader, - bytecodeVersion); + DialectReader dialectReader(*this, stringReader, resourceReader, reader); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); // Ensure that the dialect implements the bytecode interface. @@ -1262,7 +1252,7 @@ public: llvm::MemoryBufferRef buffer, const std::shared_ptr &bufferOwnerRef) : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), - attrTypeReader(stringReader, resourceReader, fileLoc, version), + attrTypeReader(stringReader, resourceReader, fileLoc), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), @@ -1792,7 +1782,7 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader, if (!opName->opName) { // Load the dialect and its version. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader, version); + reader); if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); // If the opName is empty, this is because we use to accept names such as @@ -1835,7 +1825,7 @@ LogicalResult BytecodeReader::Impl::parseResourceSection( // Initialize the resource reader with the resource sections. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader, version); + reader); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, dialectReader, bufferOwnerRef); @@ -2196,7 +2186,7 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, // interface and control the serialization. if (wasRegistered) { DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader, version); + reader); if (failed( propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) return failure(); diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp index 284b3c0..2547d81 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -48,7 +48,7 @@ struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter { void writeOwnedBool(bool value) override {} int64_t getBytecodeVersion() const override { - return state.getDesiredBytecodeVersion(); + llvm_unreachable("unexpected querying of version in IRNumbering"); } /// The parent numbering state that is populated by this writer. @@ -391,10 +391,6 @@ void IRNumberingState::number(Dialect *dialect, } } -int64_t IRNumberingState::getDesiredBytecodeVersion() const { - return config.getDesiredBytecodeVersion(); -} - namespace { /// A dummy resource builder used to number dialect resources. struct NumberingResourceBuilder : public AsmResourceBuilder { diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h index ca30078..c10e09a 100644 --- a/mlir/lib/Bytecode/Writer/IRNumbering.h +++ b/mlir/lib/Bytecode/Writer/IRNumbering.h @@ -186,9 +186,6 @@ public: return blockOperationCounts[block]; } - /// Get the set desired bytecode version to emit. - int64_t getDesiredBytecodeVersion() const; - private: /// This class is used to provide a fake dialect writer for numbering nested /// attributes and types. diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp index f67e7db..ffb84f0 100644 --- a/mlir/lib/IR/ODSSupport.cpp +++ b/mlir/lib/IR/ODSSupport.cpp @@ -33,40 +33,24 @@ LogicalResult mlir::convertFromAttribute(int64_t &storage, Attribute mlir::convertToAttribute(MLIRContext *ctx, int64_t storage) { return IntegerAttr::get(IntegerType::get(ctx, 64), storage); } - -template -LogicalResult convertDenseArrayFromAttr(MutableArrayRef storage, - ::mlir::Attribute attr, - ::mlir::InFlightDiagnostic *diag, - StringRef denseArrayTyStr) { - auto valueAttr = dyn_cast(attr); +LogicalResult mlir::convertFromAttribute(MutableArrayRef storage, + ::mlir::Attribute attr, + ::mlir::InFlightDiagnostic *diag) { + auto valueAttr = dyn_cast(attr); if (!valueAttr) { if (diag) - *diag << "expected " << denseArrayTyStr << " for key `value`"; + *diag << "expected DenseI64ArrayAttr for key `value`"; return failure(); } if (valueAttr.size() != static_cast(storage.size())) { if (diag) - *diag << "size mismatch in attribute conversion: " << valueAttr.size() + *diag << "Size mismatch in attribute conversion: " << valueAttr.size() << " vs " << storage.size(); return failure(); } llvm::copy(valueAttr.asArrayRef(), storage.begin()); return success(); } -LogicalResult mlir::convertFromAttribute(MutableArrayRef storage, - ::mlir::Attribute attr, - ::mlir::InFlightDiagnostic *diag) { - return convertDenseArrayFromAttr(storage, attr, diag, - "DenseI64ArrayAttr"); -} -LogicalResult mlir::convertFromAttribute(MutableArrayRef storage, - Attribute attr, - InFlightDiagnostic *diag) { - return convertDenseArrayFromAttr(storage, attr, diag, - "DenseI32ArrayAttr"); -} - Attribute mlir::convertToAttribute(MLIRContext *ctx, ArrayRef storage) { return DenseI64ArrayAttr::get(ctx, storage); diff --git a/mlir/lib/TableGen/Property.cpp b/mlir/lib/TableGen/Property.cpp index e61d2fd..b0bea43 100644 --- a/mlir/lib/TableGen/Property.cpp +++ b/mlir/lib/TableGen/Property.cpp @@ -32,40 +32,65 @@ static StringRef getValueAsString(const Init *init) { return {}; } -Property::Property(const Record *def) - : Property(getValueAsString(def->getValueInit("storageType")), - getValueAsString(def->getValueInit("interfaceType")), - getValueAsString(def->getValueInit("convertFromStorage")), - getValueAsString(def->getValueInit("assignToStorage")), - getValueAsString(def->getValueInit("convertToAttribute")), - getValueAsString(def->getValueInit("convertFromAttribute")), - getValueAsString(def->getValueInit("readFromMlirBytecode")), - getValueAsString(def->getValueInit("writeToMlirBytecode")), - getValueAsString(def->getValueInit("hashProperty")), - getValueAsString(def->getValueInit("defaultValue"))) { - this->def = def; - assert((def->isSubClassOf("Property") || def->isSubClassOf("Attr")) && +Property::Property(const Record *record) : def(record) { + assert((record->isSubClassOf("Property") || record->isSubClassOf("Attr")) && "must be subclass of TableGen 'Property' class"); } Property::Property(const DefInit *init) : Property(init->getDef()) {} -Property::Property(StringRef storageType, StringRef interfaceType, - StringRef convertFromStorageCall, - StringRef assignToStorageCall, - StringRef convertToAttributeCall, - StringRef convertFromAttributeCall, - StringRef readFromMlirBytecodeCall, - StringRef writeToMlirBytecodeCall, - StringRef hashPropertyCall, StringRef defaultValue) - : storageType(storageType), interfaceType(interfaceType), - convertFromStorageCall(convertFromStorageCall), - assignToStorageCall(assignToStorageCall), - convertToAttributeCall(convertToAttributeCall), - convertFromAttributeCall(convertFromAttributeCall), - readFromMlirBytecodeCall(readFromMlirBytecodeCall), - writeToMlirBytecodeCall(writeToMlirBytecodeCall), - hashPropertyCall(hashPropertyCall), defaultValue(defaultValue) { - if (storageType.empty()) - storageType = "Property"; +StringRef Property::getStorageType() const { + const auto *init = def->getValueInit("storageType"); + auto type = getValueAsString(init); + if (type.empty()) + return "Property"; + return type; } + +StringRef Property::getInterfaceType() const { + const auto *init = def->getValueInit("interfaceType"); + return getValueAsString(init); +} + +StringRef Property::getConvertFromStorageCall() const { + const auto *init = def->getValueInit("convertFromStorage"); + return getValueAsString(init); +} + +StringRef Property::getAssignToStorageCall() const { + const auto *init = def->getValueInit("assignToStorage"); + return getValueAsString(init); +} + +StringRef Property::getConvertToAttributeCall() const { + const auto *init = def->getValueInit("convertToAttribute"); + return getValueAsString(init); +} + +StringRef Property::getConvertFromAttributeCall() const { + const auto *init = def->getValueInit("convertFromAttribute"); + return getValueAsString(init); +} + +StringRef Property::getReadFromMlirBytecodeCall() const { + const auto *init = def->getValueInit("readFromMlirBytecode"); + return getValueAsString(init); +} + +StringRef Property::getWriteToMlirBytecodeCall() const { + const auto *init = def->getValueInit("writeToMlirBytecode"); + return getValueAsString(init); +} + +StringRef Property::getHashPropertyCall() const { + return getValueAsString(def->getValueInit("hashProperty")); +} + +bool Property::hasDefaultValue() const { return !getDefaultValue().empty(); } + +StringRef Property::getDefaultValue() const { + const auto *init = def->getValueInit("defaultValue"); + return getValueAsString(init); +} + +const llvm::Record &Property::getDef() const { return *def; } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 14141c4..09bbc5a 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -887,7 +887,7 @@ func.func @switch_wrong_number_of_weights(%arg0 : i32) { func.func @switch_case_type_mismatch(%arg0 : i64) { // expected-error@below {{expects case value type to match condition value type}} - "llvm.switch"(%arg0)[^bb1, ^bb2] <{case_operand_segments = array, case_values = dense<42> : vector<1xi32>, odsOperandSegmentSizes = array}> : (i64) -> () + "llvm.switch"(%arg0)[^bb1, ^bb2] <{case_operand_segments = array, case_values = dense<42> : vector<1xi32>, operand_segment_sizes = array}> : (i64) -> () ^bb1: // pred: ^bb0 llvm.return ^bb2: // pred: ^bb0 diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index 7d922ec..307918d 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -383,21 +383,21 @@ func.func private @foo() // ----- func.func @failedMissingOperandSizeAttr(%arg: i32) { - // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}} + // expected-error @+1 {{requires dense i32 array attribute 'operand_segment_sizes'}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) : (i32, i32, i32, i32) -> () } // ----- func.func @failedOperandSizeAttrWrongType(%arg: i32) { - // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}} + // expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = 10} : (i32, i32, i32, i32) -> () } // ----- func.func @failedOperandSizeAttrWrongElementType(%arg: i32) { - // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}} + // expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = array} : (i32, i32, i32, i32) -> () } @@ -418,7 +418,7 @@ func.func @failedOperandSizeAttrWrongTotalSize(%arg: i32) { // ----- func.func @failedOperandSizeAttrWrongCount(%arg: i32) { - // expected-error @+1 {{test.attr_sized_operands' op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes}} + // expected-error @+1 {{'operand_segment_sizes' attribute for specifying operand segments must have 4 elements}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = array} : (i32, i32, i32, i32) -> () } @@ -433,14 +433,14 @@ func.func @succeededOperandSizeAttr(%arg: i32) { // ----- func.func @failedMissingResultSizeAttr() { - // expected-error @+1 {{op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} + // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() : () -> (i32, i32, i32, i32) } // ----- func.func @failedResultSizeAttrWrongType() { - // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} + // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = 10} : () -> (i32, i32, i32, i32) } @@ -448,7 +448,7 @@ func.func @failedResultSizeAttrWrongType() { // ----- func.func @failedResultSizeAttrWrongElementType() { - // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} + // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = array} : () -> (i32, i32, i32, i32) } @@ -469,7 +469,7 @@ func.func @failedResultSizeAttrWrongTotalSize() { // ----- func.func @failedResultSizeAttrWrongCount() { - // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} + // expected-error @+1 {{'result_segment_sizes' attribute for specifying result segments must have 4 elements, but got 3}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = array} : () -> (i32, i32, i32, i32) } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 966896b..8056f6d 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -839,7 +839,8 @@ def AttrSizedOperandOp : TEST_Op<"attr_sized_operands", Variadic:$a, Variadic:$b, I32:$c, - Variadic:$d + Variadic:$d, + DenseI32ArrayAttr:$operand_segment_sizes ); } diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py index 88f48d0..5e8414a 100644 --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -100,7 +100,7 @@ def testNamedStructuredOpGenericForm(): init_result = tensor.EmptyOp([4, 8], f32) # CHECK: "linalg.matmul"(%{{.*}}) # CHECK-SAME: cast = #linalg.type_fn - # CHECK-SAME: odsOperandSegmentSizes = array + # CHECK-SAME: operand_segment_sizes = array # CHECK-NEXT: ^bb0(%{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32): # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32 diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 135308b..c069e7c 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -115,10 +115,6 @@ static const char *const adapterSegmentSizeAttrInitCode = R"( assert({0} && "missing segment size attribute for op"); auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0}); )"; -static const char *const adapterSegmentSizeAttrInitCodeProperties = R"( - ::llvm::ArrayRef sizeAttr = {0}; -)"; - /// The code snippet to initialize the sizes for the value range calculation. /// /// {0}: The code to get the attribute. @@ -154,29 +150,6 @@ static const char *const valueRangeReturnCode = R"( std::next({0}, valueRange.first + valueRange.second)}; )"; -/// Read operand/result segment_size from bytecode. -static const char *const readBytecodeSegmentSize = R"( -if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) { - DenseI32ArrayAttr attr; - if (::mlir::failed($_reader.readAttribute(attr))) return failure(); - if (attr.size() > static_cast(sizeof($_storage) / sizeof(int32_t))) { - $_reader.emitError("size mismatch for operand/result_segment_size"); - return failure(); - } - llvm::copy(ArrayRef(attr), $_storage); -} else { - return $_reader.readSparseArray(MutableArrayRef($_storage)); -} -)"; - -/// Write operand/result segment_size to bytecode. -static const char *const writeBytecodeSegmentSize = R"( -if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) - $_writer.writeAttribute(DenseI32ArrayAttr::get(getContext(), $_storage)); -else - $_writer.writeSparseArray(ArrayRef($_storage)); -)"; - /// A header for indicating code sections. /// /// {0}: Some text, or a class name. @@ -370,9 +343,6 @@ public: return true; if (!op.getDialect().usePropertiesForAttributes()) return false; - if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") || - op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) - return true; return llvm::any_of(getAttrMetadata(), [](const std::pair &it) { return !it.second.constraint || @@ -380,14 +350,6 @@ public: }); } - std::optional &getOperandSegmentsSize() { - return operandSegmentsSize; - } - - std::optional &getResultSegmentsSize() { - return resultSegmentsSize; - } - private: // Compute the attribute metadata. void computeAttrMetadata(); @@ -399,13 +361,6 @@ private: // The attribute metadata, mapped by name. llvm::MapVector attrMetadata; - - // Property - std::optional operandSegmentsSize; - std::string operandSegmentsSizeStorage; - std::optional resultSegmentsSize; - std::string resultSegmentsSizeStorage; - // The number of required attributes. unsigned numRequired; }; @@ -422,50 +377,18 @@ void OpOrAdaptorHelper::computeAttrMetadata() { attrMetadata.insert( {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}}); } - - auto makeProperty = [&](StringRef storageType) { - return Property( - /*storageType=*/storageType, - /*interfaceType=*/"::llvm::ArrayRef", - /*convertFromStorageCall=*/"$_storage", - /*assignToStorageCall=*/"::llvm::copy($_value, $_storage)", - /*convertToAttributeCall=*/ - "DenseI32ArrayAttr::get($_ctxt, $_storage)", - /*convertFromAttributeCall=*/ - "return convertFromAttribute($_storage, $_attr, $_diag);", - /*readFromMlirBytecodeCall=*/readBytecodeSegmentSize, - /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSize, - /*hashPropertyCall=*/ - "llvm::hash_combine_range(std::begin($_storage), " - "std::end($_storage));", - /*StringRef defaultValue=*/""); - }; // Include key attributes from several traits as implicitly registered. if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - if (op.getDialect().usePropertiesForAttributes()) { - operandSegmentsSizeStorage = - llvm::formatv("int32_t[{0}]", op.getNumOperands()); - operandSegmentsSize = {"odsOperandSegmentSizes", - makeProperty(operandSegmentsSizeStorage)}; - } else { - attrMetadata.insert( - {operandSegmentAttrName, AttributeMetadata{operandSegmentAttrName, - /*isRequired=*/true, - /*attr=*/std::nullopt}}); - } + attrMetadata.insert( + {operandSegmentAttrName, + AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true, + /*attr=*/std::nullopt}}); } if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { - if (op.getDialect().usePropertiesForAttributes()) { - resultSegmentsSizeStorage = - llvm::formatv("int32_t[{0}]", op.getNumResults()); - resultSegmentsSize = {"odsResultSegmentSizes", - makeProperty(resultSegmentsSizeStorage)}; - } else { - attrMetadata.insert( - {resultSegmentAttrName, - AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true, - /*attr=*/std::nullopt}}); - } + attrMetadata.insert( + {resultSegmentAttrName, + AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true, + /*attr=*/std::nullopt}}); } // Store the metadata in sorted order. @@ -737,17 +660,14 @@ static void genNativeTraitAttrVerifier(MethodBody &body, // Verify a few traits first so that we can use getODSOperands() and // getODSResults() in the rest of the verifier. auto &op = emitHelper.getOp(); - if (!op.getDialect().usePropertiesForAttributes()) { - if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName, - op.getNumOperands(), "operand", - emitHelper.emitErrorPrefix()); - } - if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { - body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName, - op.getNumResults(), "result", - emitHelper.emitErrorPrefix()); - } + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName, + op.getNumOperands(), "operand", + emitHelper.emitErrorPrefix()); + } + if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { + body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName, + op.getNumResults(), "result", emitHelper.emitErrorPrefix()); } } @@ -1044,16 +964,14 @@ static void errorIfPruned(size_t line, Method *m, const Twine &methodName, void OpEmitter::genAttrNameGetters() { const llvm::MapVector &attributes = emitHelper.getAttrMetadata(); - bool hasOperandSegmentsSize = - op.getDialect().usePropertiesForAttributes() && - op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); + // Emit the getAttributeNames method. { auto *method = opClass.addStaticInlineMethod( "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames"); ERROR_IF_PRUNED(method, "getAttributeNames", op); auto &body = method->body(); - if (!hasOperandSegmentsSize && attributes.empty()) { + if (attributes.empty()) { body << " return {};"; // Nothing else to do if there are no registered attributes. Exit early. return; @@ -1063,11 +981,6 @@ void OpEmitter::genAttrNameGetters() { [&](StringRef attrName) { body << "::llvm::StringRef(\"" << attrName << "\")"; }); - if (hasOperandSegmentsSize) { - if (!attributes.empty()) - body << ", "; - body << "::llvm::StringRef(\"" << operandSegmentAttrName << "\")"; - } body << "};\n return ::llvm::ArrayRef(attrNames);"; } @@ -1120,26 +1033,6 @@ void OpEmitter::genAttrNameGetters() { "name, " + Twine(index)); } } - if (hasOperandSegmentsSize) { - std::string name = op.getGetterName(operandSegmentAttrName); - std::string methodName = name + "AttrName"; - // Generate the non-static variant. - { - auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName); - ERROR_IF_PRUNED(method, methodName, op); - method->body() - << " return (*this)->getName().getAttributeNames().back();"; - } - - // Generate the static variant. - { - auto *method = opClass.addStaticInlineMethod( - "::mlir::StringAttr", methodName, - MethodParameter("::mlir::OperationName", "name")); - ERROR_IF_PRUNED(method, methodName, op); - method->body() << " return name.getAttributeNames().back();"; - } - } } // Emit the getter for an attribute with the return type specified. @@ -1187,10 +1080,6 @@ void OpEmitter::genPropertiesSupport() { } for (const NamedProperty &prop : op.getProperties()) attrOrProperties.push_back(&prop); - if (emitHelper.getOperandSegmentsSize()) - attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value()); - if (emitHelper.getResultSegmentsSize()) - attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value()); if (attrOrProperties.empty()) return; auto &setPropMethod = @@ -1215,7 +1104,6 @@ void OpEmitter::genPropertiesSupport() { auto &getInherentAttrMethod = opClass .addStaticMethod("std::optional", "getInherentAttr", - MethodParameter("::mlir::MLIRContext *", "ctx"), MethodParameter("const Properties &", "prop"), MethodParameter("llvm::StringRef", "name")) ->body(); @@ -1229,7 +1117,6 @@ void OpEmitter::genPropertiesSupport() { auto &populateInherentAttrsMethod = opClass .addStaticMethod("void", "populateInherentAttrs", - MethodParameter("::mlir::MLIRContext *", "ctx"), MethodParameter("const Properties &", "prop"), MethodParameter("::mlir::NamedAttrList &", "attrs")) ->body(); @@ -1431,60 +1318,6 @@ void OpEmitter::genPropertiesSupport() { << formatv(populateInherentAttrsMethodFmt, name); continue; } - // The ODS segment size property is "special": we expose it as an attribute - // even though it is a native property. - const auto *namedProperty = cast(attrOrProp); - StringRef name = namedProperty->name; - if (name != "odsOperandSegmentSizes" && name != "odsResultSegmentSizes") - continue; - auto &prop = namedProperty->prop; - FmtContext fctx; - fctx.addSubst("_ctxt", "ctx"); - fctx.addSubst("_storage", Twine("prop.") + name); - if (name == "odsOperandSegmentSizes") { - getInherentAttrMethod - << formatv(" if (name == \"odsOperandSegmentSizes\" || name == " - "\"{0}\") return ", - operandSegmentAttrName); - } else { - getInherentAttrMethod - << formatv(" if (name == \"odsResultSegmentSizes\" || name == " - "\"{0}\") return ", - resultSegmentAttrName); - } - getInherentAttrMethod << tgfmt(prop.getConvertToAttributeCall(), &fctx) - << ";\n"; - - if (name == "odsOperandSegmentSizes") { - setInherentAttrMethod << formatv( - " if (name == \"odsOperandSegmentSizes\" || name == " - "\"{0}\") {{", - operandSegmentAttrName); - } else { - setInherentAttrMethod - << formatv(" if (name == \"odsResultSegmentSizes\" || name == " - "\"{0}\") {{", - resultSegmentAttrName); - } - setInherentAttrMethod << formatv(R"decl( - auto arrAttr = dyn_cast_or_null(value); - if (!arrAttr) return; - if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t)) - return; - llvm::copy(arrAttr.asArrayRef(), prop.{0}); - return; - } -)decl", - name); - if (name == "odsOperandSegmentSizes") { - populateInherentAttrsMethod - << formatv(" attrs.append(\"{0}\", {1});\n", operandSegmentAttrName, - tgfmt(prop.getConvertToAttributeCall(), &fctx)); - } else { - populateInherentAttrsMethod - << formatv(" attrs.append(\"{0}\", {1});\n", resultSegmentAttrName, - tgfmt(prop.getConvertToAttributeCall(), &fctx)); - } } getInherentAttrMethod << " return std::nullopt;\n"; @@ -1982,13 +1815,8 @@ void OpEmitter::genNamedOperandGetters() { // array. std::string attrSizeInitCode; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - if (op.getDialect().usePropertiesForAttributes()) - attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties, - "getProperties().odsOperandSegmentSizes"); - - else - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, - emitHelper.getAttr(operandSegmentAttrName)); + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + emitHelper.getAttr(operandSegmentAttrName)); } generateNamedOperandGetters( @@ -2023,11 +1851,10 @@ void OpEmitter::genNamedOperandSetters() { "range.first, range.second"; if (attrSizedOperands) { if (emitHelper.hasProperties()) - body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, " - "{{getOperandSegmentSizesAttrName(), " - "DenseI32ArrayAttr::get(getContext(), " - "getProperties().odsOperandSegmentSizes)})", - i); + body << formatv( + ", ::mlir::MutableOperandRange::OperandSegment({0}u, " + "{getOperandSegmentSizesAttrName(), getProperties().{1}})", + i, operandSegmentAttrName); else body << formatv( ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i, @@ -2083,13 +1910,8 @@ void OpEmitter::genNamedResultGetters() { // Build the initializer string for the result segment size attribute. std::string attrSizeInitCode; if (attrSizedResults) { - if (op.getDialect().usePropertiesForAttributes()) - attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties, - "getProperties().odsResultSegmentSizes"); - - else - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, - emitHelper.getAttr(resultSegmentAttrName)); + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + emitHelper.getAttr(resultSegmentAttrName)); } generateValueRangeStartAndEnd( @@ -2264,7 +2086,10 @@ void OpEmitter::genSeparateArgParamBuilder() { // the length of the type ranges. if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { if (op.getDialect().usePropertiesForAttributes()) { - body << " llvm::copy(ArrayRef({"; + body << " (" << builderOpState + << ".getOrAddProperties()." << resultSegmentAttrName + << " = \n" + " odsBuilder.getDenseI32ArrayAttr({"; } else { std::string getterName = op.getGetterName(resultSegmentAttrName); body << " " << builderOpState << ".addAttribute(" << getterName @@ -2287,12 +2112,7 @@ void OpEmitter::genSeparateArgParamBuilder() { body << "static_cast(" << resultNames[i] << ".size())"; } }); - if (op.getDialect().usePropertiesForAttributes()) { - body << "}), " << builderOpState - << ".getOrAddProperties().odsResultSegmentSizes);\n"; - } else { - body << "}));\n"; - } + body << "}));\n"; } return; @@ -2886,7 +2706,17 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder( } // If the operation has the operand segment size attribute, add it here. - auto emitSegment = [&]() { + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + std::string sizes = op.getGetterName(operandSegmentAttrName); + if (op.getDialect().usePropertiesForAttributes()) { + body << " (" << builderOpState << ".getOrAddProperties()." + << operandSegmentAttrName << "= " + << "odsBuilder.getDenseI32ArrayAttr({"; + } else { + body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" + << builderOpState << ".name), " + << "odsBuilder.getDenseI32ArrayAttr({"; + } interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { const NamedTypeConstraint &operand = op.getOperand(i); if (!operand.isVariableLength()) { @@ -2907,21 +2737,7 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder( body << "static_cast(" << getArgumentName(op, i) << ".size())"; } }); - }; - if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - std::string sizes = op.getGetterName(operandSegmentAttrName); - if (op.getDialect().usePropertiesForAttributes()) { - body << " llvm::copy(ArrayRef({"; - emitSegment(); - body << "}), " << builderOpState - << ".getOrAddProperties().odsOperandSegmentSizes);\n"; - } else { - body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" - << builderOpState << ".name), " - << "odsBuilder.getDenseI32ArrayAttr({"; - emitSegment(); - body << "}));\n"; - } + body << "}));\n"; } // Push all attributes to the result. @@ -3725,10 +3541,6 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( } for (const NamedProperty &prop : op.getProperties()) attrOrProperties.push_back(&prop); - if (emitHelper.getOperandSegmentsSize()) - attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value()); - if (emitHelper.getResultSegmentsSize()) - attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value()); assert(!attrOrProperties.empty()); std::string declarations = " struct Properties {\n"; llvm::raw_string_ostream os(declarations); @@ -3786,8 +3598,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( if (attr) { storageType = attr->getStorageType(); } else { - if (name != "odsOperandSegmentSizes" && - name != "odsResultSegmentSizes") { + if (name != operandSegmentAttrName && name != resultSegmentAttrName) { report_fatal_error("unexpected AttributeMetadata"); } // TODO: update to use native integers. @@ -3899,13 +3710,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( std::string sizeAttrInit; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - if (op.getDialect().usePropertiesForAttributes()) - sizeAttrInit = - formatv(adapterSegmentSizeAttrInitCodeProperties, - llvm::formatv("getProperties().odsOperandSegmentSizes")); - else - sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, - emitHelper.getAttr(operandSegmentAttrName)); + sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, + emitHelper.getAttr(operandSegmentAttrName)); } generateNamedOperandGetters(op, genericAdaptor, /*genericAdaptorBase=*/&genericAdaptorBase, diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 1e13179..9a5e4d5 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1654,6 +1654,16 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op, MethodBody &body) { if (!allOperands) { if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + if (op.getDialect().usePropertiesForAttributes()) { + body << formatv(" " + "result.getOrAddProperties<{0}::Properties>().operand_" + "segment_sizes = " + "(parser.getBuilder().getDenseI32ArrayAttr({{", + op.getCppClassName()); + } else { + body << " result.addAttribute(\"operand_segment_sizes\", " + << "parser.getBuilder().getDenseI32ArrayAttr({"; + } auto interleaveFn = [&](const NamedTypeConstraint &operand) { // If the operand is variadic emit the parsed size. if (operand.isVariableLength()) @@ -1661,19 +1671,8 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op, else body << "1"; }; - if (op.getDialect().usePropertiesForAttributes()) { - body << "llvm::copy(ArrayRef({"; - llvm::interleaveComma(op.getOperands(), body, interleaveFn); - body << formatv("}), " - "result.getOrAddProperties<{0}::Properties>()." - "odsOperandSegmentSizes);\n", - op.getCppClassName()); - } else { - body << " result.addAttribute(\"operand_segment_sizes\", " - << "parser.getBuilder().getDenseI32ArrayAttr({"; - llvm::interleaveComma(op.getOperands(), body, interleaveFn); - body << "}));\n"; - } + llvm::interleaveComma(op.getOperands(), body, interleaveFn); + body << "}));\n"; } for (const NamedTypeConstraint &operand : op.getOperands()) { if (!operand.isVariadicOfVariadic()) @@ -1698,6 +1697,16 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op, if (!allResultTypes && op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { + if (op.getDialect().usePropertiesForAttributes()) { + body << formatv( + " " + "result.getOrAddProperties<{0}::Properties>().result_segment_sizes = " + "(parser.getBuilder().getDenseI32ArrayAttr({{", + op.getCppClassName()); + } else { + body << " result.addAttribute(\"result_segment_sizes\", " + << "parser.getBuilder().getDenseI32ArrayAttr({"; + } auto interleaveFn = [&](const NamedTypeConstraint &result) { // If the result is variadic emit the parsed size. if (result.isVariableLength()) @@ -1705,20 +1714,8 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op, else body << "1"; }; - if (op.getDialect().usePropertiesForAttributes()) { - body << "llvm::copy(ArrayRef({"; - llvm::interleaveComma(op.getResults(), body, interleaveFn); - body << formatv( - "}), " - "result.getOrAddProperties<{0}::Properties>().odsResultSegmentSizes" - ");\n", - op.getCppClassName()); - } else { - body << " result.addAttribute(\"odsResultSegmentSizes\", " - << "parser.getBuilder().getDenseI32ArrayAttr({"; - llvm::interleaveComma(op.getResults(), body, interleaveFn); - body << "}));\n"; - } + llvm::interleaveComma(op.getResults(), body, interleaveFn); + body << "}));\n"; } } diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp index e4dc4cd..fcd19ad 100644 --- a/mlir/unittests/IR/AdaptorTest.cpp +++ b/mlir/unittests/IR/AdaptorTest.cpp @@ -39,7 +39,7 @@ TEST(Adaptor, GenericAdaptorsOperandAccess) { // value from the value 0. SmallVector> v = {0, 4}; OIListSimple::Properties prop; - llvm::copy(ArrayRef{1, 0, 1}, prop.odsOperandSegmentSizes); + prop.operand_segment_sizes = builder.getDenseI32ArrayAttr({1, 0, 1}); OIListSimple::GenericAdaptor>> d(v, {}, prop, {}); EXPECT_EQ(d.getArg0(), 0); diff --git a/mlir/unittests/IR/OpPropertiesTest.cpp b/mlir/unittests/IR/OpPropertiesTest.cpp index 21ea448..eda84a0 100644 --- a/mlir/unittests/IR/OpPropertiesTest.cpp +++ b/mlir/unittests/IR/OpPropertiesTest.cpp @@ -115,15 +115,13 @@ public: // This alias is the only definition needed for enabling "properties" for this // operation. using Properties = TestProperties; - static std::optional getInherentAttr(MLIRContext *context, - const Properties &prop, + static std::optional getInherentAttr(const Properties &prop, StringRef name) { return std::nullopt; } static void setInherentAttr(Properties &prop, StringRef name, mlir::Attribute value) {} - static void populateInherentAttrs(MLIRContext *context, - const Properties &prop, + static void populateInherentAttrs(const Properties &prop, NamedAttrList &attrs) {} static LogicalResult verifyInherentAttrs(OperationName opName, NamedAttrList &attrs, -- 2.7.4