} // end namespace llvm
namespace mlir {
+class Location;
class Module;
class MLIRContext;
+class Type;
/// This parses the file specified by the indicated SourceMgr and returns an
/// MLIR module if it was valid. If not, the error message is emitted through
Module *parseSourceFile(llvm::StringRef filename, MLIRContext *context);
/// This parses the module string to a MLIR module if it was valid. If not, the
-/// error message is emitted through / the error handler registered in the
+/// error message is emitted through the error handler registered in the
/// context, and a null pointer is returned.
Module *parseSourceString(llvm::StringRef moduleStr, MLIRContext *context);
+/// This parses a single MLIR type to an MLIR context if it was valid. If not,
+/// an error message is emitted through a new SourceMgrDiagnosticHandler
+/// constructed from a new SourceMgr with a single a MemoryBuffer wrapping
+/// `typeStr`.
+// TODO(ntv) Improve diagnostic reporting.
+Type parseType(llvm::StringRef typeStr, MLIRContext *context);
+
} // end namespace mlir
#endif // MLIR_PARSER_H
auto indexTy = parser->getBuilder().getIndexType();
if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
return failure();
- if (bufferType.getElementType() != parser->getBuilder().getF32Type())
- return parser->emitError(parser->getNameLoc(),
- "Only buffer<f32> supported until "
- "mlir::linalg::Parser pieces are exposed");
return failure(parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
parser->addTypeToList(bufferType, result->types));
}
#include "mlir/IR/Dialect.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Linalg/IR/LinalgOps.h"
+#include "mlir/Parser.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/Support/raw_ostream.h"
+
using namespace mlir;
using namespace mlir::linalg;
Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
Location loc) const {
+ StringRef origSpec = spec;
MLIRContext *context = getContext();
if (spec == "range")
return RangeType::get(getContext());
- // TODO(ntv): reuse mlir Parser once exposed.
- if (spec == "buffer<f32>")
- return BufferType::get(getContext(), FloatType::getF32(getContext()));
- // TODO(ntv): reuse mlir Parser once exposed.
- if (spec.startswith("view")) {
- spec.consume_front("view");
- // Just count the number of ? to get the rank, the type must be f32 for now.
- unsigned rank = 0;
- for (unsigned i = 0, e = spec.size(); i < e; ++i)
- if (spec[i] == '?')
- ++rank;
- return ViewType::get(context, FloatType::getF32(context), rank);
+ else if (spec.consume_front("buffer")) {
+ if (spec.consume_front("<") && spec.consume_back(">")) {
+ if (auto t = mlir::parseType(spec, context))
+ return BufferType::get(getContext(), t);
+ }
+ } else if (spec.consume_front("view")) {
+ if (spec.consume_front("<") && spec.consume_back(">")) {
+ // Just count the number of ? to get the rank.
+ unsigned rank = 0;
+ for (unsigned i = 0, e = spec.size(); i < e; ++i) {
+ if (spec.consume_front("?")) {
+ ++rank;
+ if (!spec.consume_front("x")) {
+ context->emitError(loc,
+ "expected a list of '?x' dimension specifiers: ")
+ << spec;
+ return Type();
+ }
+ }
+ }
+ if (auto t = mlir::parseType(spec, context))
+ return ViewType::get(context, t, rank);
+ }
}
- return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
+ return (context->emitError(loc, "unknown Linalg type: " + origSpec), Type());
}
struct mlir::linalg::ViewTypeStorage : public TypeStorage {
/// methods to access this.
class ParserState {
public:
- ParserState(const llvm::SourceMgr &sourceMgr, Module *module)
- : context(module->getContext()), module(module), lex(sourceMgr, context),
- curToken(lex.lexToken()) {}
+ ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
+ : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()) {}
// A map from attribute alias identifier to Attribute.
llvm::StringMap<Attribute> attributeAliasDefinitions;
// The context we're parsing into.
MLIRContext *const context;
- // This is the module we are parsing into.
- Module *const module;
-
// The lexer for the source file we're parsing.
Lexer lex;
// Helper methods to get stuff from the parser-global state.
ParserState &getState() const { return state; }
MLIRContext *getContext() const { return state.context; }
- Module *getModule() { return state.module; }
const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
/// Parse a comma-separated list of elements up until the specified end token.
public:
explicit ModuleParser(ParserState &state) : Parser(state) {}
- ParseResult parseModule();
+ ParseResult parseModule(Module *module);
private:
/// Parse an attribute alias declaration.
StringRef &name, FunctionType &type,
SmallVectorImpl<std::pair<SMLoc, StringRef>> &argNames,
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
- ParseResult parseFunc();
+ ParseResult parseFunc(Module *module);
};
} // end anonymous namespace
/// function-body ::= `{` block+ `}`
/// function-attributes ::= `attributes` attribute-dict
///
-ParseResult ModuleParser::parseFunc() {
+ParseResult ModuleParser::parseFunc(Module *module) {
consumeToken();
StringRef name;
// Okay, the function signature was parsed correctly, create the function now.
auto *function =
new Function(getEncodedSourceLocation(loc), name, type, attrs);
- getModule()->getFunctions().push_back(function);
+ module->getFunctions().push_back(function);
// Verify no name collision / redefinition.
if (function->getName() != name)
}
/// This is the top-level module parser.
-ParseResult ModuleParser::parseModule() {
+ParseResult ModuleParser::parseModule(Module *module) {
while (1) {
switch (getToken().getKind()) {
default:
break;
case Token::kw_func:
- if (parseFunc())
+ if (parseFunc(module))
return failure();
break;
}
// This is the result module we are parsing into.
std::unique_ptr<Module> module(new Module(context));
- ParserState state(sourceMgr, module.get());
- if (ModuleParser(state).parseModule()) {
+ ParserState state(sourceMgr, context);
+ if (ModuleParser(state).parseModule(module.get())) {
return nullptr;
}
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
return parseSourceFile(sourceMgr, context);
}
+
+Type mlir::parseType(llvm::StringRef typeStr, MLIRContext *context) {
+ SourceMgr sourceMgr;
+ auto memBuffer = MemoryBuffer::getMemBuffer(typeStr, /*BufferName=*/"",
+ /*RequiresNullTerminator=*/false);
+ sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
+ SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
+ ParserState state(sourceMgr, context);
+ return Parser(state).parseType();
+}
func @buffer(%arg0: index, %arg1: index) {
%0 = muli %arg0, %arg0 : index
- %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
- linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+ %1 = linalg.buffer_alloc %0 : !linalg.buffer<vector<4xi8>>
+ linalg.buffer_dealloc %1 : !linalg.buffer<vector<4xi8>>
return
}
// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
-// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
-// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<vector<4xi8>>
+// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<vector<4xi8>>
+
+func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
+ return
+}
+// CHECK-LABEL: func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
%0 = muli %arg0, %arg0 : index