[mlir][Parser] Make parse{Attribute,Type} null-terminate input
authorRahul Kayaith <rkayaith@gmail.com>
Thu, 2 Mar 2023 21:06:44 +0000 (16:06 -0500)
committerRahul Kayaith <rkayaith@gmail.com>
Fri, 3 Mar 2023 22:03:27 +0000 (17:03 -0500)
`parseAttribute` and `parseType` require null-terminated strings as
input, but this isn't great considering the argument type is
`StringRef`. This changes them to copy to a null-terminated buffer by
default, with a `isKnownNullTerminated` flag added to disable the
copying.

closes #58964

Reviewed By: rriddle, kuhar, lattner

Differential Revision: https://reviews.llvm.org/D145182

mlir/include/mlir/AsmParser/AsmParser.h
mlir/lib/AsmParser/DialectSymbolParser.cpp
mlir/lib/Bytecode/Reader/BytecodeReader.cpp
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
mlir/unittests/Parser/ParserTest.cpp

index d60df41..3c1bff1 100644 (file)
@@ -49,16 +49,21 @@ parseAsmSourceFile(const llvm::SourceMgr &sourceMgr, Block *block,
 /// If `numRead` is provided, it is set to the number of consumed characters on
 /// succesful parse. Otherwise, parsing fails if the entire string is not
 /// consumed.
+/// Some internal copying can be skipped if the source string is known to be
+/// null terminated.
 Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context,
-                         Type type = {}, size_t *numRead = nullptr);
+                         Type type = {}, size_t *numRead = nullptr,
+                         bool isKnownNullTerminated = false);
 
 /// This parses a single MLIR type to an MLIR context if it was valid. If not,
 /// an error diagnostic is emitted to the context.
 /// If `numRead` is provided, it is set to the number of consumed characters on
 /// succesful parse. Otherwise, parsing fails if the entire string is not
 /// consumed.
+/// Some internal copying can be skipped if the source string is known to be
+/// null terminated.
 Type parseType(llvm::StringRef typeStr, MLIRContext *context,
-               size_t *numRead = nullptr);
+               size_t *numRead = nullptr, bool isKnownNullTerminated = false);
 
 /// This parses a single IntegerSet/AffineMap to an MLIR context if it was
 /// valid. If not, an error message is emitted through a new
index a3198e0..c98b368 100644 (file)
@@ -306,15 +306,18 @@ Type Parser::parseExtendedType() {
 //===----------------------------------------------------------------------===//
 
 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
-/// parsing failed, nullptr is returned. The number of bytes read from the input
-/// string is returned in 'numRead'.
+/// parsing failed, nullptr is returned.
 template <typename T, typename ParserFn>
 static T parseSymbol(StringRef inputStr, MLIRContext *context,
-                     size_t *numReadOut, ParserFn &&parserFn) {
+                     size_t *numReadOut, bool isKnownNullTerminated,
+                     ParserFn &&parserFn) {
   // Set the buffer name to the string being parsed, so that it appears in error
   // diagnostics.
-  auto memBuffer = MemoryBuffer::getMemBuffer(inputStr, /*BufferName=*/inputStr,
-                                              /*RequiresNullTerminator=*/true);
+  auto memBuffer =
+      isKnownNullTerminated
+          ? MemoryBuffer::getMemBuffer(inputStr,
+                                       /*BufferName=*/inputStr)
+          : MemoryBuffer::getMemBufferCopy(inputStr, /*BufferName=*/inputStr);
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
   SymbolState aliasState;
@@ -343,12 +346,14 @@ static T parseSymbol(StringRef inputStr, MLIRContext *context,
 }
 
 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
-                               Type type, size_t *numRead) {
+                               Type type, size_t *numRead,
+                               bool isKnownNullTerminated) {
   return parseSymbol<Attribute>(
-      attrStr, context, numRead,
+      attrStr, context, numRead, isKnownNullTerminated,
       [type](Parser &parser) { return parser.parseAttribute(type); });
 }
-Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead) {
-  return parseSymbol<Type>(typeStr, context, numRead,
+Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
+                     bool isKnownNullTerminated) {
+  return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
                            [](Parser &parser) { return parser.parseType(); });
 }
index f009621..5e71c3a 100644 (file)
@@ -1031,9 +1031,11 @@ LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
   size_t numRead = 0;
   MLIRContext *context = fileLoc->getContext();
   if constexpr (std::is_same_v<T, Type>)
-    result = ::parseType(asmStr, context, &numRead);
+    result =
+        ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
   else
-    result = ::parseAttribute(asmStr, context, Type(), &numRead);
+    result = ::parseAttribute(asmStr, context, Type(), &numRead,
+                              /*isKnownNullTerminated=*/true);
   if (!result)
     return failure();
 
index dfcc2bc..600cdde 100644 (file)
@@ -1693,7 +1693,8 @@ transform::PadOp::applyToOne(LinalgOp target,
     // Try to parse string attributes to obtain an attribute of element type.
     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
       auto parsedAttr = dyn_cast_if_present<TypedAttr>(
-          parseAttribute(stringAttr, getContext(), elementType));
+          parseAttribute(stringAttr, getContext(), elementType,
+                         /*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
       if (!parsedAttr || parsedAttr.getType() != elementType) {
         auto diag = this->emitOpError("expects a padding that parses to ")
                     << elementType << ", got " << std::get<0>(it);
index f88e0fa..a51bdb5 100644 (file)
@@ -317,10 +317,8 @@ struct ScalarTraits<SerializedAffineMap> {
                          SerializedAffineMap &value) {
     assert(rawYamlContext);
     auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
-    std::string nullTerminatedScalar(scalar);
-    if (auto attr =
-            mlir::parseAttribute(nullTerminatedScalar, yamlContext->mlirContext)
-                .dyn_cast_or_null<AffineMapAttr>())
+    if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
+                        .dyn_cast_or_null<AffineMapAttr>())
       value.affineMapAttr = attr;
     else if (!value.affineMapAttr || !value.affineMapAttr.isa<AffineMapAttr>())
       return "could not parse as an affine map attribute";
index 6b3ac5c..62f609e 100644 (file)
@@ -95,5 +95,10 @@ TEST(MLIRParser, ParseAttr) {
               b.getI64IntegerAttr(10));
     EXPECT_EQ(numRead, size_t(4)); // includes trailing whitespace
   }
+  { // Parse without null-terminator
+    StringRef attrAsm("999", 1);
+    Attribute attr = parseAttribute(attrAsm, &context);
+    EXPECT_EQ(attr, b.getI64IntegerAttr(9));
+  }
 }
 } // namespace