From 55cf53fd0f5594eb701b5760729fdc2bd4a70584 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Thu, 2 Mar 2023 16:06:44 -0500 Subject: [PATCH] [mlir][Parser] Make parse{Attribute,Type} null-terminate input `parseAttribute` and `parseType` require null-terminated strings as input, but this isn't great considering the argument type is `StringRef`. This changes them to copy to a null-terminated buffer by default, with a `isKnownNullTerminated` flag added to disable the copying. closes #58964 Reviewed By: rriddle, kuhar, lattner Differential Revision: https://reviews.llvm.org/D145182 --- mlir/include/mlir/AsmParser/AsmParser.h | 9 +++++++-- mlir/lib/AsmParser/DialectSymbolParser.cpp | 23 +++++++++++++--------- mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 6 ++++-- .../Linalg/TransformOps/LinalgTransformOps.cpp | 3 ++- .../mlir-linalg-ods-yaml-gen.cpp | 6 ++---- mlir/unittests/Parser/ParserTest.cpp | 5 +++++ 6 files changed, 34 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/AsmParser/AsmParser.h b/mlir/include/mlir/AsmParser/AsmParser.h index d60df41..3c1bff1 100644 --- a/mlir/include/mlir/AsmParser/AsmParser.h +++ b/mlir/include/mlir/AsmParser/AsmParser.h @@ -49,16 +49,21 @@ parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, /// If `numRead` is provided, it is set to the number of consumed characters on /// succesful parse. Otherwise, parsing fails if the entire string is not /// consumed. +/// Some internal copying can be skipped if the source string is known to be +/// null terminated. Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, - Type type = {}, size_t *numRead = nullptr); + Type type = {}, size_t *numRead = nullptr, + bool isKnownNullTerminated = false); /// This parses a single MLIR type to an MLIR context if it was valid. If not, /// an error diagnostic is emitted to the context. /// If `numRead` is provided, it is set to the number of consumed characters on /// succesful parse. Otherwise, parsing fails if the entire string is not /// consumed. +/// Some internal copying can be skipped if the source string is known to be +/// null terminated. Type parseType(llvm::StringRef typeStr, MLIRContext *context, - size_t *numRead = nullptr); + size_t *numRead = nullptr, bool isKnownNullTerminated = false); /// This parses a single IntegerSet/AffineMap to an MLIR context if it was /// valid. If not, an error message is emitted through a new diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp index a3198e0..c98b368 100644 --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -306,15 +306,18 @@ Type Parser::parseExtendedType() { //===----------------------------------------------------------------------===// /// Parses a symbol, of type 'T', and returns it if parsing was successful. If -/// parsing failed, nullptr is returned. The number of bytes read from the input -/// string is returned in 'numRead'. +/// parsing failed, nullptr is returned. template static T parseSymbol(StringRef inputStr, MLIRContext *context, - size_t *numReadOut, ParserFn &&parserFn) { + size_t *numReadOut, bool isKnownNullTerminated, + ParserFn &&parserFn) { // Set the buffer name to the string being parsed, so that it appears in error // diagnostics. - auto memBuffer = MemoryBuffer::getMemBuffer(inputStr, /*BufferName=*/inputStr, - /*RequiresNullTerminator=*/true); + auto memBuffer = + isKnownNullTerminated + ? MemoryBuffer::getMemBuffer(inputStr, + /*BufferName=*/inputStr) + : MemoryBuffer::getMemBufferCopy(inputStr, /*BufferName=*/inputStr); SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); SymbolState aliasState; @@ -343,12 +346,14 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context, } Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, - Type type, size_t *numRead) { + Type type, size_t *numRead, + bool isKnownNullTerminated) { return parseSymbol( - attrStr, context, numRead, + attrStr, context, numRead, isKnownNullTerminated, [type](Parser &parser) { return parser.parseAttribute(type); }); } -Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead) { - return parseSymbol(typeStr, context, numRead, +Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead, + bool isKnownNullTerminated) { + return parseSymbol(typeStr, context, numRead, isKnownNullTerminated, [](Parser &parser) { return parser.parseType(); }); } diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index f009621..5e71c3a 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -1031,9 +1031,11 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, size_t numRead = 0; MLIRContext *context = fileLoc->getContext(); if constexpr (std::is_same_v) - result = ::parseType(asmStr, context, &numRead); + result = + ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); else - result = ::parseAttribute(asmStr, context, Type(), &numRead); + result = ::parseAttribute(asmStr, context, Type(), &numRead, + /*isKnownNullTerminated=*/true); if (!result) return failure(); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index dfcc2bc..600cdde 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1693,7 +1693,8 @@ transform::PadOp::applyToOne(LinalgOp target, // Try to parse string attributes to obtain an attribute of element type. if (auto stringAttr = attr.dyn_cast()) { auto parsedAttr = dyn_cast_if_present( - parseAttribute(stringAttr, getContext(), elementType)); + parseAttribute(stringAttr, getContext(), elementType, + /*numRead=*/nullptr, /*isKnownNullTerminated=*/true)); if (!parsedAttr || parsedAttr.getType() != elementType) { auto diag = this->emitOpError("expects a padding that parses to ") << elementType << ", got " << std::get<0>(it); diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp index f88e0fa..a51bdb5 100644 --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -317,10 +317,8 @@ struct ScalarTraits { SerializedAffineMap &value) { assert(rawYamlContext); auto *yamlContext = static_cast(rawYamlContext); - std::string nullTerminatedScalar(scalar); - if (auto attr = - mlir::parseAttribute(nullTerminatedScalar, yamlContext->mlirContext) - .dyn_cast_or_null()) + if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext) + .dyn_cast_or_null()) value.affineMapAttr = attr; else if (!value.affineMapAttr || !value.affineMapAttr.isa()) return "could not parse as an affine map attribute"; diff --git a/mlir/unittests/Parser/ParserTest.cpp b/mlir/unittests/Parser/ParserTest.cpp index 6b3ac5c..62f609e 100644 --- a/mlir/unittests/Parser/ParserTest.cpp +++ b/mlir/unittests/Parser/ParserTest.cpp @@ -95,5 +95,10 @@ TEST(MLIRParser, ParseAttr) { b.getI64IntegerAttr(10)); EXPECT_EQ(numRead, size_t(4)); // includes trailing whitespace } + { // Parse without null-terminator + StringRef attrAsm("999", 1); + Attribute attr = parseAttribute(attrAsm, &context); + EXPECT_EQ(attr, b.getI64IntegerAttr(9)); + } } } // namespace -- 2.7.4