From 445cc3f6dd74e86575153a95ecfb8754d6d5b726 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 1 Nov 2019 14:47:42 -0700 Subject: [PATCH] Add DialectAsmParser/Printer classes to simplify dialect attribute and type parsing. These classes are functionally similar to the OpAsmParser/Printer classes and provide hooks for parsing attributes/tokens/types/etc. This change merely sets up the base infrastructure and updates the parser hooks, followups will add hooks as needed to simplify existing handrolled dialect parsers. This has various different benefits: *) Attribute/Type parsing is much simpler to define. *) Dialect attributes/types that contain other attributes/types can now use aliases. *) It provides a 'spec' with which we may use in the future to auto-generate parsers/printers. *) Error messages emitted by attribute/type parsers can provide character exact locations rather than "beginning of the string" PiperOrigin-RevId: 278005322 --- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h | 4 +- mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h | 4 +- mlir/include/mlir/Dialect/QuantOps/QuantOps.h | 4 +- mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h | 4 +- mlir/include/mlir/IR/Dialect.h | 10 +- mlir/include/mlir/IR/DialectImplementation.h | 139 +++++++++++++++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 9 +- mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp | 14 +- mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp | 17 +- mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp | 17 +- mlir/lib/IR/AsmPrinter.cpp | 96 ++++++++--- mlir/lib/IR/Dialect.cpp | 9 +- mlir/lib/Parser/Parser.cpp | 197 ++++++++++++++++------ 13 files changed, 413 insertions(+), 111 deletions(-) create mode 100644 mlir/include/mlir/IR/DialectImplementation.h diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h index 67fccec..d09e815 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -173,10 +173,10 @@ 
public: llvm::Module &getLLVMModule(); /// Parse a type registered to this dialect. - Type parseType(StringRef tyData, Location loc) const override; + Type parseType(DialectAsmParser &parser, Location loc) const override; /// Print a type registered to this dialect. - void printType(Type type, raw_ostream &os) const override; + void printType(Type type, DialectAsmPrinter &os) const override; /// Verify a region argument attribute registered to this dialect. /// Returns failure if the verification failed, success otherwise. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h index 1835073..8888953 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -37,10 +37,10 @@ public: static StringRef getDialectNamespace() { return "linalg"; } /// Parse a type registered to this dialect. - Type parseType(llvm::StringRef spec, Location loc) const override; + Type parseType(DialectAsmParser &parser, Location loc) const override; /// Print a type registered to this dialect. - void printType(Type type, llvm::raw_ostream &os) const override; + void printType(Type type, DialectAsmPrinter &os) const override; }; /// A BufferType represents a contiguous block of memory that can be allocated diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.h b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h index 8753cd2..f1ac383 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.h +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.h @@ -35,10 +35,10 @@ public: QuantizationDialect(MLIRContext *context); /// Parse a type registered to this dialect. - Type parseType(StringRef spec, Location loc) const override; + Type parseType(DialectAsmParser &parser, Location loc) const override; /// Print a type registered to this dialect. 
- void printType(Type type, raw_ostream &os) const override; + void printType(Type type, DialectAsmPrinter &os) const override; }; #define GET_OP_CLASSES diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h index 8e98270..6401eba 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVDialect.h @@ -46,10 +46,10 @@ public: static std::string getAttributeName(Decoration decoration); /// Parses a type registered to this dialect. - Type parseType(llvm::StringRef spec, Location loc) const override; + Type parseType(DialectAsmParser &parser, Location loc) const override; /// Prints a type registered to this dialect. - void printType(Type type, llvm::raw_ostream &os) const override; + void printType(Type type, DialectAsmPrinter &os) const override; /// Provides a hook for materializing a constant to this dialect. Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h index bf7db91..bd84bee 100644 --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -25,6 +25,8 @@ #include "mlir/IR/OperationSupport.h" namespace mlir { +class DialectAsmParser; +class DialectAsmPrinter; class DialectInterface; class OpBuilder; class Type; @@ -115,21 +117,21 @@ public: /// Parse an attribute registered to this dialect. If 'type' is nonnull, it /// refers to the expected type of the attribute. - virtual Attribute parseAttribute(StringRef attrData, Type type, + virtual Attribute parseAttribute(DialectAsmParser &parser, Type type, Location loc) const; /// Print an attribute registered to this dialect. Note: The type of the /// attribute need not be printed by this method as it is always printed by /// the caller. 
- virtual void printAttribute(Attribute, raw_ostream &) const { + virtual void printAttribute(Attribute, DialectAsmPrinter &) const { llvm_unreachable("dialect has no registered attribute printing hook"); } /// Parse a type registered to this dialect. - virtual Type parseType(StringRef tyData, Location loc) const; + virtual Type parseType(DialectAsmParser &parser, Location loc) const; /// Print a type registered to this dialect. - virtual void printType(Type, raw_ostream &) const { + virtual void printType(Type, DialectAsmPrinter &) const { llvm_unreachable("dialect has no registered type printing hook"); } diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h new file mode 100644 index 0000000..c662a4c --- /dev/null +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -0,0 +1,139 @@ +//===- DialectImplementation.h ----------------------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file contains utilities classes for implementing dialect attributes and +// types. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_DIALECTIMPLEMENTATION_H +#define MLIR_IR_DIALECTIMPLEMENTATION_H + +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/SMLoc.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { + +class Builder; + +//===----------------------------------------------------------------------===// +// DialectAsmPrinter +//===----------------------------------------------------------------------===// + +/// This is a pure-virtual base class that exposes the asmprinter hooks +/// necessary to implement a custom printAttribute/printType() method on a +/// dialect. +class DialectAsmPrinter { +public: + DialectAsmPrinter() {} + virtual ~DialectAsmPrinter(); + virtual raw_ostream &getStream() const = 0; + + /// Print the given attribute to the stream. + virtual void printAttribute(Attribute attr) = 0; + + /// Print the given floating point value in a stabilized form that can be + /// roundtripped through the IR. This is the companion to the 'parseFloat' + /// hook on the DialectAsmParser. + virtual void printFloat(const APFloat &value) = 0; + + /// Print the given type to the stream. + virtual void printType(Type type) = 0; + +private: + DialectAsmPrinter(const DialectAsmPrinter &) = delete; + void operator=(const DialectAsmPrinter &) = delete; +}; + +// Make the implementations convenient to use. 
+inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) { + p.printAttribute(attr); + return p; +} + +inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, + const APFloat &value) { + p.printFloat(value); + return p; +} +inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) { + return p << APFloat(value); +} +inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) { + return p << APFloat(value); +} + +inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) { + p.printType(type); + return p; +} + +// Support printing anything that isn't convertible to one of the above types, +// even if it isn't exactly one of them. For example, we want to print +// FunctionType with the Type version above, not have it match this. +template ::value && + !std::is_convertible::value && + !std::is_convertible::value && + !llvm::is_one_of::value, + T>::type * = nullptr> +inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) { + p.getStream() << other; + return p; +} + +//===----------------------------------------------------------------------===// +// DialectAsmParser +//===----------------------------------------------------------------------===// + +/// The DialectAsmParser has methods for interacting with the asm parser: +/// parsing things from it, emitting errors etc. It has an intentionally +/// high-level API that is designed to reduce/constrain syntax innovation in +/// individual attributes or types. +class DialectAsmParser { +public: + virtual ~DialectAsmParser(); + + /// Emit a diagnostic at the specified location and return failure. + virtual InFlightDiagnostic emitError(llvm::SMLoc loc, + const Twine &message = {}) = 0; + + /// Return a builder which provides useful access to MLIRContext, global + /// objects like types and attributes. + virtual Builder &getBuilder() const = 0; + + /// Get the location of the next token and store it into the argument. This + /// always succeeds. 
+ virtual llvm::SMLoc getCurrentLocation() = 0; + ParseResult getCurrentLocation(llvm::SMLoc *loc) { + *loc = getCurrentLocation(); + return success(); + } + + /// Return the location of the original name token. + virtual llvm::SMLoc getNameLoc() const = 0; + + /// Returns the full specification of the symbol being parsed. This allows for + /// using a separate parser if necessary. + virtual StringRef getFullSymbolSpec() const = 0; +}; + +} // end namespace mlir + +#endif diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index d7d3307..39decf9 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -21,6 +21,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" @@ -1249,7 +1250,9 @@ llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; } llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; } /// Parse a type registered to this dialect. -Type LLVMDialect::parseType(StringRef tyData, Location loc) const { +Type LLVMDialect::parseType(DialectAsmParser &parser, Location loc) const { + StringRef tyData = parser.getFullSymbolSpec(); + // LLVM is not thread-safe, so lock access to it. llvm::sys::SmartScopedLock lock(impl->mutex); @@ -1261,11 +1264,11 @@ Type LLVMDialect::parseType(StringRef tyData, Location loc) const { } /// Print a type registered to this dialect. 
-void LLVMDialect::printType(Type type, raw_ostream &os) const { +void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { auto llvmType = type.dyn_cast(); assert(llvmType && "printing wrong type"); assert(llvmType.getUnderlyingType() && "no underlying LLVM type"); - llvmType.getUnderlyingType()->print(os); + llvmType.getUnderlyingType()->print(os.getStream()); } /// Verify LLVMIR function argument attributes. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp index c09b75e..4a7bcd8 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp @@ -20,9 +20,10 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Parser.h" #include "mlir/Support/LLVM.h" @@ -107,8 +108,9 @@ Optional mlir::linalg::BufferType::getBufferSize() { return getImpl()->getBufferSize(); } -Type mlir::linalg::LinalgDialect::parseType(StringRef spec, +Type mlir::linalg::LinalgDialect::parseType(DialectAsmParser &parser, Location loc) const { + StringRef spec = parser.getFullSymbolSpec(); StringRef origSpec = spec; MLIRContext *context = getContext(); if (spec == "range") @@ -146,9 +148,8 @@ Type mlir::linalg::LinalgDialect::parseType(StringRef spec, return (emitError(loc, "unknown Linalg type: " + origSpec), Type()); } - /// BufferType prints as "buffer". -static void print(BufferType bt, raw_ostream &os) { +static void print(BufferType bt, DialectAsmPrinter &os) { os << "buffer<"; auto bs = bt.getBufferSize(); if (bs) { @@ -160,9 +161,10 @@ static void print(BufferType bt, raw_ostream &os) { } /// RangeType prints as just "range". 
-static void print(RangeType rt, raw_ostream &os) { os << "range"; } +static void print(RangeType rt, DialectAsmPrinter &os) { os << "range"; } -void mlir::linalg::LinalgDialect::printType(Type type, raw_ostream &os) const { +void mlir::linalg::LinalgDialect::printType(Type type, + DialectAsmPrinter &os) const { switch (type.getKind()) { default: llvm_unreachable("Unhandled Linalg type"); diff --git a/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp b/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp index 726c20c..360c1b5 100644 --- a/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/TypeParser.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/QuantOps/QuantOps.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Location.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" @@ -615,9 +616,10 @@ bool TypeParser::parseQuantParams(double &scale, int64_t &zeroPoint) { } /// Parse a type registered to this dialect. -Type QuantizationDialect::parseType(StringRef spec, Location loc) const { - TypeParser parser(spec, getContext(), loc); - Type parsedType = parser.parseType(); +Type QuantizationDialect::parseType(DialectAsmParser &parser, + Location loc) const { + TypeParser typeParser(parser.getFullSymbolSpec(), getContext(), loc); + Type parsedType = typeParser.parseType(); if (parsedType == nullptr) { // Error. // TODO(laurenzo): Do something? @@ -723,19 +725,20 @@ static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, } /// Print a type registered to this dialect. 
-void QuantizationDialect::printType(Type type, raw_ostream &os) const { +void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { switch (type.getKind()) { default: llvm_unreachable("Unhandled quantized type"); case QuantizationTypes::Any: - printAnyQuantizedType(type.cast(), os); + printAnyQuantizedType(type.cast(), os.getStream()); break; case QuantizationTypes::UniformQuantized: - printUniformQuantizedType(type.cast(), os); + printUniformQuantizedType(type.cast(), + os.getStream()); break; case QuantizationTypes::UniformQuantizedPerAxis: printUniformQuantizedPerAxisType(type.cast(), - os); + os.getStream()); break; } } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index 96777b1..26d1ff1 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Parser.h" @@ -609,7 +610,9 @@ static Type parseStructType(SPIRVDialect const &dialect, StringRef spec, // | pointer-type // | runtime-array-type // | struct-type -Type SPIRVDialect::parseType(StringRef spec, Location loc) const { +Type SPIRVDialect::parseType(DialectAsmParser &parser, Location loc) const { + StringRef spec = parser.getFullSymbolSpec(); + if (spec.startswith("array")) return parseArrayType(*this, spec, loc); if (spec.startswith("image")) @@ -629,7 +632,7 @@ Type SPIRVDialect::parseType(StringRef spec, Location loc) const { // Type Printing //===----------------------------------------------------------------------===// -static void print(ArrayType type, llvm::raw_ostream &os) { +static void print(ArrayType type, DialectAsmPrinter &os) { os << "array<" << type.getNumElements() << " x " << type.getElementType(); if (type.hasLayout()) { os << " 
[" << type.getArrayStride() << "]"; @@ -637,16 +640,16 @@ static void print(ArrayType type, llvm::raw_ostream &os) { os << ">"; } -static void print(RuntimeArrayType type, llvm::raw_ostream &os) { +static void print(RuntimeArrayType type, DialectAsmPrinter &os) { os << "rtarray<" << type.getElementType() << ">"; } -static void print(PointerType type, llvm::raw_ostream &os) { +static void print(PointerType type, DialectAsmPrinter &os) { os << "ptr<" << type.getPointeeType() << ", " << stringifyStorageClass(type.getStorageClass()) << ">"; } -static void print(ImageType type, llvm::raw_ostream &os) { +static void print(ImageType type, DialectAsmPrinter &os) { os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim()) << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", " << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", " @@ -655,7 +658,7 @@ static void print(ImageType type, llvm::raw_ostream &os) { << stringifyImageFormat(type.getImageFormat()) << ">"; } -static void print(StructType type, llvm::raw_ostream &os) { +static void print(StructType type, DialectAsmPrinter &os) { os << "struct<"; auto printMember = [&](unsigned i) { os << type.getElementType(i); @@ -680,7 +683,7 @@ static void print(StructType type, llvm::raw_ostream &os) { os << ">"; } -void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const { +void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { switch (type.getKind()) { case TypeKind::Array: print(type.cast(), os); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 0200e98..0e6b788 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" @@ -53,6 +54,8 @@ void OperationName::print(raw_ostream &os) const 
{ os << getStringRef(); } void OperationName::dump() const { print(llvm::errs()); } +DialectAsmPrinter::~DialectAsmPrinter() {} + OpAsmPrinter::~OpAsmPrinter() {} //===----------------------------------------------------------------------===// @@ -391,6 +394,9 @@ public: : os(printer.os), printerFlags(printer.printerFlags), state(printer.state) {} + /// Returns the output stream of the printer. + raw_ostream &getStream() { return os; } + template inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const { mlir::interleaveComma(c, os, each_fn); @@ -420,6 +426,9 @@ protected: void printLocationInternal(LocationAttr loc, bool pretty = false); void printDenseElementsAttr(DenseElementsAttr attr); + void printDialectAttribute(Attribute attr); + void printDialectType(Type type); + /// This enum is used to represent the binding strength of the enclosing /// context that an AffineExprStorage is being printed in, so we can /// intelligently produce parens. @@ -715,19 +724,9 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) { } switch (attr.getKind()) { - default: { - auto &dialect = attr.getDialect(); - - // Ask the dialect to serialize the attribute to a string. - std::string attrName; - { - llvm::raw_string_ostream attrNameStr(attrName); - dialect.printAttribute(attr, attrNameStr); - } + default: + return printDialectAttribute(attr); - printDialectSymbol(os, "#", dialect.getNamespace(), attrName); - break; - } case StandardAttributes::Opaque: { auto opaqueAttr = attr.cast(); printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(), @@ -950,19 +949,9 @@ void ModulePrinter::printType(Type type) { } switch (type.getKind()) { - default: { - auto &dialect = type.getDialect(); - - // Ask the dialect to serialize the type to a string. 
- std::string typeName; - { - llvm::raw_string_ostream typeNameStr(typeName); - dialect.printType(type, typeNameStr); - } + default: + return printDialectType(type); - printDialectSymbol(os, "!", dialect.getNamespace(), typeName); - return; - } case Type::Kind::Opaque: { auto opaqueTy = type.cast(); printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), @@ -1073,6 +1062,65 @@ void ModulePrinter::printType(Type type) { } //===----------------------------------------------------------------------===// +// CustomDialectAsmPrinter +//===----------------------------------------------------------------------===// + +namespace { +/// This class provides the main specialization of the DialectAsmPrinter that is +/// used to provide support for printing attributes and types. This hook allows +/// for dialects to hook into the main ModulePrinter. +struct CustomDialectAsmPrinter : public DialectAsmPrinter { +public: + CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {} + ~CustomDialectAsmPrinter() override {} + + raw_ostream &getStream() const override { return printer.getStream(); } + + /// Print the given attribute to the stream. + void printAttribute(Attribute attr) override { printer.printAttribute(attr); } + + /// Print the given floating point value in a stabilized form. + void printFloat(const APFloat &value) override { + printFloatValue(value, getStream()); + } + + /// Print the given type to the stream. + void printType(Type type) override { printer.printType(type); } + + /// The main module printer. + ModulePrinter &printer; +}; +} // end anonymous namespace + +void ModulePrinter::printDialectAttribute(Attribute attr) { + auto &dialect = attr.getDialect(); + + // Ask the dialect to serialize the attribute to a string. 
+ std::string attrName; + { + llvm::raw_string_ostream attrNameStr(attrName); + ModulePrinter subPrinter(attrNameStr, printerFlags, state); + CustomDialectAsmPrinter printer(subPrinter); + dialect.printAttribute(attr, printer); + } + printDialectSymbol(os, "#", dialect.getNamespace(), attrName); +} + +void ModulePrinter::printDialectType(Type type) { + auto &dialect = type.getDialect(); + + // Ask the dialect to serialize the type to a string. + std::string typeName; + { + llvm::raw_string_ostream typeNameStr(typeName); + ModulePrinter subPrinter(typeNameStr, printerFlags, state); + CustomDialectAsmPrinter printer(subPrinter); + dialect.printType(type, printer); + } + printDialectSymbol(os, "!", dialect.getNamespace(), typeName); +} + +//===----------------------------------------------------------------------===// // Affine expressions and maps //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index f8539c0..7882e4f 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectHooks.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectInterface.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" @@ -28,6 +29,8 @@ using namespace mlir; using namespace detail; +DialectAsmParser::~DialectAsmParser() {} + //===----------------------------------------------------------------------===// // Dialect Registration //===----------------------------------------------------------------------===// @@ -99,7 +102,7 @@ LogicalResult Dialect::verifyRegionResultAttribute(Operation *, unsigned, } /// Parse an attribute registered to this dialect. 
-Attribute Dialect::parseAttribute(StringRef attrData, Type type, +Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type, Location loc) const { emitError(loc) << "dialect '" << getNamespace() << "' provides no attribute parsing hook"; @@ -107,11 +110,11 @@ Attribute Dialect::parseAttribute(StringRef attrData, Type type, } /// Parse a type registered to this dialect. -Type Dialect::parseType(StringRef tyData, Location loc) const { +Type Dialect::parseType(DialectAsmParser &parser, Location loc) const { // If this dialect allows unknown types, then represent this with OpaqueType. if (allowsUnknownTypes()) { auto ns = Identifier::get(getNamespace(), getContext()); - return OpaqueType::get(ns, tyData, getContext()); + return OpaqueType::get(ns, parser.getFullSymbolSpec(), getContext()); } emitError(loc) << "dialect '" << getNamespace() diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index af7e0b6..a6e0227 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -27,6 +27,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" @@ -51,37 +52,43 @@ namespace { class Parser; //===----------------------------------------------------------------------===// -// ParserState +// AliasState //===----------------------------------------------------------------------===// -/// This class refers to all of the state maintained globally by the parser, -/// such as the current lexer position etc. The Parser base class provides -/// methods to access this. -class ParserState { -public: - ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx) - : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()) {} - +/// This class contains record of any parsed top-level aliases. 
+struct AliasState { // A map from attribute alias identifier to Attribute. llvm::StringMap attributeAliasDefinitions; // A map from type alias identifier to Type. llvm::StringMap typeAliasDefinitions; +}; -private: +//===----------------------------------------------------------------------===// +// ParserState +//===----------------------------------------------------------------------===// + +/// This class refers to all of the state maintained globally by the parser, +/// such as the current lexer position etc. +struct ParserState { + ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, + AliasState &aliases) + : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), + aliases(aliases) {} ParserState(const ParserState &) = delete; void operator=(const ParserState &) = delete; - friend class Parser; - - // The context we're parsing into. + /// The context we're parsing into. MLIRContext *const context; - // The lexer for the source file we're parsing. + /// The lexer for the source file we're parsing. Lexer lex; - // This is the next token that hasn't been consumed yet. + /// This is the next token that hasn't been consumed yet. Token curToken; + + /// Any parsed alias state. + AliasState &aliases; }; //===----------------------------------------------------------------------===// @@ -348,6 +355,55 @@ ParseResult Parser::parseCommaSeparatedListUntil( return success(); } +//===----------------------------------------------------------------------===// +// DialectAsmParser +//===----------------------------------------------------------------------===// + +namespace { +/// This class provides the main implementation of the DialectAsmParser that +/// allows for dialects to parse attributes and types. This allows for dialect +/// hooking into the main MLIR parsing logic. 
+class CustomDialectAsmParser : public DialectAsmParser { +public: + CustomDialectAsmParser(StringRef fullSpec, Parser &parser) + : fullSpec(fullSpec), nameLoc(parser.getToken().getLoc()), + parser(parser) {} + ~CustomDialectAsmParser() override {} + + /// Emit a diagnostic at the specified location and return failure. + InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override { + return parser.emitError(loc, message); + } + + /// Return a builder which provides useful access to MLIRContext, global + /// objects like types and attributes. + Builder &getBuilder() const override { return parser.builder; } + + /// Return the location of the parser's current token, as reported by + /// 'parser.getToken()'. + llvm::SMLoc getCurrentLocation() override { + return parser.getToken().getLoc(); + } + + /// Return the location of the original name token. + llvm::SMLoc getNameLoc() const override { return nameLoc; } + + /// Returns the full specification of the symbol being parsed. This allows + /// for using a separate parser if necessary. + StringRef getFullSymbolSpec() const override { return fullSpec; } + +private: + /// The full symbol specification. + StringRef fullSpec; + + /// The source location of the dialect symbol. + SMLoc nameLoc; + + /// The main parser. + Parser &parser; +}; +} // namespace + /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, /// and may be recursive. Return with the 'prettyName' StringRef encompassing /// the entire pretty name. @@ -486,8 +542,42 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, } // Call into the provided symbol construction function. - auto encodedLoc = p.getEncodedSourceLocation(loc); - return createSymbol(dialectName, symbolData, encodedLoc); + return createSymbol(dialectName, symbolData, loc); +} + +/// Parses a symbol, of type 'T', and returns it if parsing was successful. If +/// parsing failed, a null 'T' (i.e. 'T()') is returned.
The number of bytes read from the input +/// string is returned in 'numRead' when that pointer is non-null. +template +static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, + AliasState &aliasState, ParserFn &&parserFn, + size_t *numRead = nullptr) { + SourceMgr sourceMgr; + auto memBuffer = MemoryBuffer::getMemBuffer( + inputStr, /*BufferName=*/"", + /*RequiresNullTerminator=*/false); + sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); + ParserState state(sourceMgr, context, aliasState); + Parser parser(state); + + Token startTok = parser.getToken(); + T symbol = parserFn(parser); + if (!symbol) + return T(); + + // If 'numRead' is valid, then provide the number of bytes that were read. + Token endTok = parser.getToken(); + if (numRead) { + *numRead = static_cast(endTok.getLoc().getPointer() - + startTok.getLoc().getPointer()); + + // Otherwise, ensure that all of the tokens were parsed. + } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { + parser.emitError(endTok.getLoc(), + "encountered unexpected tokens after parsing"); + return T(); + } + return symbol; } //===----------------------------------------------------------------------===// @@ -611,16 +701,24 @@ Type Parser::parseComplexType() { /// Type Parser::parseExtendedType() { return parseExtendedSymbol( - *this, Token::exclamation_identifier, state.typeAliasDefinitions, - [&](StringRef dialectName, StringRef symbolData, Location loc) -> Type { + *this, Token::exclamation_identifier, state.aliases.typeAliasDefinitions, + [&](StringRef dialectName, StringRef symbolData, + llvm::SMLoc loc) -> Type { + Location encodedLoc = getEncodedSourceLocation(loc); + // If we found a registered dialect, then ask it to parse the type.
- if (auto *dialect = state.context->getRegisteredDialect(dialectName)) - return dialect->parseType(symbolData, loc); + if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { + return parseSymbol( + symbolData, state.context, state.aliases, [&](Parser &parser) { + CustomDialectAsmParser customParser(symbolData, parser); + return dialect->parseType(customParser, encodedLoc); + }); + } // Otherwise, form a new opaque type. return OpaqueType::getChecked( Identifier::get(dialectName, state.context), symbolData, - state.context, loc); + state.context, encodedLoc); }); } @@ -1217,22 +1315,29 @@ Parser::parseAttributeDict(SmallVectorImpl &attributes) { /// Attribute Parser::parseExtendedAttr(Type type) { Attribute attr = parseExtendedSymbol( - *this, Token::hash_identifier, state.attributeAliasDefinitions, + *this, Token::hash_identifier, state.aliases.attributeAliasDefinitions, [&](StringRef dialectName, StringRef symbolData, - Location loc) -> Attribute { + llvm::SMLoc loc) -> Attribute { // Parse an optional trailing colon type. Type attrType = type; if (consumeIf(Token::colon) && !(attrType = parseType())) return Attribute(); // If we found a registered dialect, then ask it to parse the attribute. - if (auto *dialect = state.context->getRegisteredDialect(dialectName)) - return dialect->parseAttribute(symbolData, attrType, loc); + if (auto *dialect = state.context->getRegisteredDialect(dialectName)) { + return parseSymbol( + symbolData, state.context, state.aliases, [&](Parser &parser) { + CustomDialectAsmParser customParser(symbolData, parser); + return dialect->parseAttribute(customParser, attrType, + getEncodedSourceLocation(loc)); + }); + } // Otherwise, form a new opaque attribute. return OpaqueAttr::getChecked( Identifier::get(dialectName, state.context), symbolData, - attrType ?
attrType : NoneType::get(state.context), loc); + attrType ? attrType : NoneType::get(state.context), + getEncodedSourceLocation(loc)); }); // Ensure that the attribute has the same type as requested. @@ -4137,7 +4242,7 @@ ParseResult ModuleParser::parseAttributeAliasDef() { StringRef aliasName = getTokenSpelling().drop_front(); // Check for redefinitions. - if (getState().attributeAliasDefinitions.count(aliasName) > 0) + if (getState().aliases.attributeAliasDefinitions.count(aliasName) > 0) return emitError("redefinition of attribute alias id '" + aliasName + "'"); // Make sure this isn't invading the dialect attribute namespace. @@ -4156,7 +4261,7 @@ ParseResult ModuleParser::parseAttributeAliasDef() { if (!attr) return failure(); - getState().attributeAliasDefinitions[aliasName] = attr; + getState().aliases.attributeAliasDefinitions[aliasName] = attr; return success(); } @@ -4169,7 +4274,7 @@ ParseResult ModuleParser::parseTypeAliasDef() { StringRef aliasName = getTokenSpelling().drop_front(); // Check for redefinitions. - if (getState().typeAliasDefinitions.count(aliasName) > 0) + if (getState().aliases.typeAliasDefinitions.count(aliasName) > 0) return emitError("redefinition of type alias id '" + aliasName + "'"); // Make sure this isn't invading the dialect type namespace. @@ -4190,7 +4295,7 @@ ParseResult ModuleParser::parseTypeAliasDef() { return failure(); // Register this alias with the parser state.
- getState().typeAliasDefinitions.try_emplace(aliasName, aliasedType); + getState().aliases.typeAliasDefinitions.try_emplace(aliasName, aliasedType); return success(); } @@ -4269,7 +4374,8 @@ OwningModuleRef mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, OwningModuleRef module(ModuleOp::create(FileLineColLoc::get( sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); - ParserState state(sourceMgr, context); + AliasState aliasState; + ParserState state(sourceMgr, context, aliasState); if (ModuleParser(state).parseModule(*module)) return nullptr; @@ -4334,23 +4440,16 @@ OwningModuleRef mlir::parseSourceString(StringRef moduleStr, template static T parseSymbol(llvm::StringRef inputStr, MLIRContext *context, size_t &numRead, ParserFn &&parserFn) { - SourceMgr sourceMgr; - auto memBuffer = MemoryBuffer::getMemBuffer( - inputStr, /*BufferName=*/"", - /*RequiresNullTerminator=*/false); - sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); - SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context); - ParserState state(sourceMgr, context); - Parser parser(state); - - auto start = parser.getToken().getLoc(); - T symbol = parserFn(parser); - if (!symbol) - return T(); - - auto end = parser.getToken().getLoc(); - numRead = static_cast(end.getPointer() - start.getPointer()); - return symbol; + AliasState aliasState; + return parseSymbol( + inputStr, context, aliasState, + [&](Parser &parser) { + SourceMgrDiagnosticHandler handler( + const_cast(parser.getSourceMgr()), + parser.getContext()); + return parserFn(parser); + }, + &numRead); } Attribute mlir::parseAttribute(llvm::StringRef attrStr, MLIRContext *context) { -- 2.7.4