From 62cbdd51fafa9cc6c9f351bdcfc7426c0195c9b8 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 3 Jun 2019 12:08:22 -0700 Subject: [PATCH] Start defining a new operation 'FuncOp' that replicates all of the functionality of 'Function', but with an operation. The pretty syntax for the operation is exactly the same as that of Function. This operation is currently builtin, but should hopefully be moved to a different dialect when it has been completely decoupled from IR/. This is the first patch in a large series that refactors Functions to be represented as operations. PiperOrigin-RevId: 251281612 --- mlir/include/mlir/IR/Function.h | 83 ++++++++++++- mlir/include/mlir/IR/OpImplementation.h | 15 +++ mlir/lib/IR/Function.cpp | 200 ++++++++++++++++++++++++++++++++ mlir/lib/IR/MLIRContext.cpp | 9 +- mlir/lib/Parser/Parser.cpp | 41 +++++-- mlir/test/IR/func-op.mlir | 29 +++++ mlir/test/IR/invalid-func-op.mlir | 37 ++++++ mlir/test/IR/invalid.mlir | 6 - 8 files changed, 401 insertions(+), 19 deletions(-) create mode 100644 mlir/test/IR/func-op.mlir create mode 100644 mlir/test/IR/invalid-func-op.mlir diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 81ba845..ed7083e 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -22,10 +22,8 @@ #ifndef MLIR_IR_FUNCTION_H #define MLIR_IR_FUNCTION_H -#include "mlir/IR/Attributes.h" #include "mlir/IR/Block.h" -#include "mlir/IR/Identifier.h" -#include "mlir/IR/Location.h" +#include "mlir/IR/OpDefinition.h" namespace mlir { class BlockAndValueMapping; @@ -320,6 +318,85 @@ private: friend struct llvm::ilist_traits; }; +//===--------------------------------------------------------------------===// +// Function Operation. +//===--------------------------------------------------------------------===// + +/// FuncOp represents a function, or a named operation containing one region +/// that forms a CFG(Control Flow Graph). The region of a function is not +/// allowed to implicitly capture global values, and all external references +/// must use Function arguments or attributes. +class FuncOp : public Op::Impl> { +public: + using Op::Op; + static StringRef getOperationName() { return "func"; } + + static void build(Builder *builder, OperationState *result, StringRef name, + FunctionType type, ArrayRef attrs); + + /// Parsing/Printing methods. + static ParseResult parse(OpAsmParser *parser, OperationState *result); + void print(OpAsmPrinter *p); + + /// Returns the name of this function. + StringRef getName() { return getAttrOfType("name").getValue(); } + + /// Returns the type of this function. + FunctionType getType() { + return getAttrOfType("type").getValue().cast(); + } + + /// Returns true if this function is external, i.e. it has no body. + bool isExternal() { return empty(); } + + //===--------------------------------------------------------------------===// + // Body Handling + //===--------------------------------------------------------------------===// + + Region &getBody() { return getOperation()->getRegion(0); } + + /// This is the list of blocks in the function. + using RegionType = Region::RegionType; + RegionType &getBlocks() { return getBody().getBlocks(); } + + // Iteration over the block in the function. + using iterator = RegionType::iterator; + using reverse_iterator = RegionType::reverse_iterator; + + iterator begin() { return getBody().begin(); } + iterator end() { return getBody().end(); } + reverse_iterator rbegin() { return getBody().rbegin(); } + reverse_iterator rend() { return getBody().rend(); } + + bool empty() { return getBody().empty(); } + void push_back(Block *block) { getBody().push_back(block); } + void push_front(Block *block) { getBody().push_front(block); } + + Block &back() { return getBody().back(); } + Block &front() { return getBody().front(); } + + //===--------------------------------------------------------------------===// + // Argument Handling + //===--------------------------------------------------------------------===// + + /// Returns number of arguments. + unsigned getNumArguments() { return getType().getInputs().size(); } + + /// Gets argument. + BlockArgument *getArgument(unsigned idx) { + return getBlocks().front().getArgument(idx); + } + + // Supports non-const operand iteration. + using args_iterator = Block::args_iterator; + args_iterator args_begin() { return front().args_begin(); } + args_iterator args_end() { return front().args_end(); } + llvm::iterator_range getArguments() { + return {args_begin(), args_end()}; + } +}; + } // end namespace mlir //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 9168d28..67a173f 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -169,6 +169,9 @@ public: /// Parse a ')' token. virtual ParseResult parseRParen() = 0; + /// Parses a ')' if present. + virtual ParseResult parseOptionalRParen() = 0; + /// This parses an equal(=) token! virtual ParseResult parseEqual() = 0; @@ -199,6 +202,10 @@ public: /// Parse a colon followed by a type list, which must have at least one type. virtual ParseResult parseColonTypeList(SmallVectorImpl &result) = 0; + /// Parse an optional arrow followed by a type list. + virtual ParseResult + parseOptionalArrowTypeList(SmallVectorImpl &result) = 0; + /// Parse a keyword followed by a type. ParseResult parseKeywordType(const char *keyword, Type &result) { return failure(parseKeyword(keyword) || parseType(result)); @@ -320,10 +327,18 @@ public: ArrayRef arguments, ArrayRef argTypes) = 0; + /// Parses an optional region. + virtual ParseResult parseOptionalRegion(Region ®ion, + ArrayRef arguments, + ArrayRef argTypes) = 0; + /// Parse a region argument. Region arguments define new values, so this also /// checks if the values with the same name has not been defined yet. virtual ParseResult parseRegionArgument(OperandType &argument) = 0; + /// Parse an optional region argument. + virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0; + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp index 61e2975..f4ee155 100644 --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -17,9 +17,11 @@ #include "mlir/IR/Function.h" #include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" +#include "mlir/IR/OpImplementation.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/Twine.h" @@ -216,3 +218,201 @@ void Function::walk(const std::function &callback) { for (auto &block : getBlocks()) block.walk(callback); } + +//===----------------------------------------------------------------------===// +// Function Operation. +//===----------------------------------------------------------------------===// + +void FuncOp::build(Builder *builder, OperationState *result, StringRef name, + FunctionType type, ArrayRef attrs) { + result->addAttribute("name", builder->getStringAttr(name)); + result->addAttribute("type", builder->getTypeAttr(type)); + result->attributes.append(attrs.begin(), attrs.end()); + result->addRegion(); +} + +/// Parsing/Printing methods. +static ParseResult +parseArgumentList(OpAsmParser *parser, SmallVectorImpl &argTypes, + SmallVectorImpl &argNames, + SmallVectorImpl> &argAttrs) { + if (parser->parseLParen()) + return failure(); + + // The argument list either has to consistently have ssa-id's followed by + // types, or just be a type list. It isn't ok to sometimes have SSA ID's and + // sometimes not. + auto parseArgument = [&]() -> ParseResult { + llvm::SMLoc loc; + parser->getCurrentLocation(&loc); + + // Parse argument name if present. + OpAsmParser::OperandType argument; + Type argumentType; + if (succeeded(parser->parseOptionalRegionArgument(argument)) && + !argument.name.empty()) { + // Reject this if the preceding argument was missing a name. + if (argNames.empty() && !argTypes.empty()) + return parser->emitError(loc, + "expected type instead of SSA identifier"); + argNames.push_back(argument); + + if (parser->parseColonType(argumentType)) + return failure(); + } else if (!argNames.empty()) { + // Reject this if the preceding argument had a name. + return parser->emitError(loc, "expected SSA identifier"); + } else if (parser->parseType(argumentType)) { + return failure(); + } + + // Add the argument type. + argTypes.push_back(argumentType); + + // TODO(riverriddle) Parse argument attributes. + // Parse the attribute dict. + // SmallVector attrs; + // if (parser->parseOptionalAttributeDict(attrs)) + // return failure(); + // argAttrs.push_back(attrs); + return success(); + }; + + // Parse the function arguments. + if (parser->parseOptionalRParen()) { + do { + if (parseArgument()) + return failure(); + } while (succeeded(parser->parseOptionalComma())); + parser->parseRParen(); + } + + return success(); +} + +/// Parse a function signature, starting with a name and including the +/// parameter list. +static ParseResult parseFunctionSignature( + OpAsmParser *parser, FunctionType &type, + SmallVectorImpl &argNames, + SmallVectorImpl> &argAttrs) { + SmallVector argTypes; + if (parseArgumentList(parser, argTypes, argNames, argAttrs)) + return failure(); + + // Parse the return types if present. + SmallVector results; + if (parser->parseOptionalArrowTypeList(results)) + return failure(); + type = parser->getBuilder().getFunctionType(argTypes, results); + return success(); +} + +ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) { + FunctionType type; + SmallVector entryArgs; + SmallVector, 4> argAttrs; + auto &builder = parser->getBuilder(); + + // Parse the name as a function attribute. + FunctionAttr nameAttr; + if (parser->parseAttribute(nameAttr, "name", result->attributes)) + return failure(); + // Convert the parsed function attr into a string attr. + result->attributes.back().second = builder.getStringAttr(nameAttr.getValue()); + + // Parse the function signature. + if (parseFunctionSignature(parser, type, entryArgs, argAttrs)) + return failure(); + result->addAttribute("type", builder.getTypeAttr(type)); + + // If function attributes are present, parse them. + if (succeeded(parser->parseOptionalKeyword("attributes"))) + if (parser->parseOptionalAttributeDict(result->attributes)) + return failure(); + + // TODO(riverriddle) Parse argument attributes. + // Add the attributes to the function arguments. + // for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) + // if (!argAttrs[i].empty()) + // result->addAttribute(("arg" + Twine(i)).str(), + // builder.getDictionaryAttr(argAttrs[i])); + + // Parse the optional function body. + auto *body = result->addRegion(); + if (parser->parseOptionalRegion( + *body, entryArgs, entryArgs.empty() ? llvm::None : type.getInputs())) + return failure(); + + return success(); +} + +static void printFunctionSignature(OpAsmPrinter *p, FuncOp op) { + *p << '('; + + auto fnType = op.getType(); + bool isExternal = op.isExternal(); + for (unsigned i = 0, e = op.getNumArguments(); i != e; ++i) { + if (i > 0) + *p << ", "; + + // If this is an external function, don't print argument labels. + if (!isExternal) { + p->printOperand(op.getArgument(i)); + *p << ": "; + } + + p->printType(fnType.getInput(i)); + + // TODO(riverriddle) Print argument attributes. + // Print the attributes for this argument. + // p->printOptionalAttrDict(op.getArgAttrs(i)); + } + *p << ')'; + + switch (fnType.getResults().size()) { + case 0: + break; + case 1: { + *p << " -> "; + auto resultType = fnType.getResults()[0]; + bool resultIsFunc = resultType.isa(); + if (resultIsFunc) + *p << '('; + p->printType(resultType); + if (resultIsFunc) + *p << ')'; + break; + } + default: + *p << " -> ("; + interleaveComma(fnType.getResults(), *p); + *p << ')'; + break; + } +} + +void FuncOp::print(OpAsmPrinter *p) { + *p << "func @" << getName(); + + // Print the signature. + printFunctionSignature(p, *this); + + // Print out function attributes, if present. + auto attrs = getAttrs(); + + // We must have more attributes than . + constexpr unsigned kNumHiddenAttrs = 2; + if (attrs.size() > kNumHiddenAttrs) { + *p << "\n attributes "; + p->printOptionalAttrDict(attrs, {"name", "type"}); + } + + // Print the body if this is not an external function. + if (!isExternal()) { + p->printRegion(getBody(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + *p << '\n'; + } + *p << '\n'; +} diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index bf1ae96..678848b 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -143,6 +143,10 @@ struct BuiltinDialect : public Dialect { addTypes(); + + // TODO: FuncOp should be moved to a different dialect when it has been + // fully decoupled from the core. + addOperations(); } }; @@ -484,8 +488,9 @@ std::vector MLIRContext::getRegisteredOperations() { } void Dialect::addOperation(AbstractOperation opInfo) { - assert(opInfo.name.split('.').first == getNamespace() && - "op name doesn't start with dialect namespace"); + assert(getNamespace().empty() || + opInfo.name.split('.').first == getNamespace() && + "op name doesn't start with dialect namespace"); assert(&opInfo.dialect == this && "Dialect object mismatch"); auto &impl = context->getImpl(); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 87f94f7..635adec 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3219,15 +3219,14 @@ public: ParseResult parseColonTypeList(SmallVectorImpl &result) override { if (parser.parseToken(Token::colon, "expected ':'")) return failure(); + return parser.parseTypeListNoParens(result); + } - do { - if (auto type = parser.parseType()) - result.push_back(type); - else - return failure(); - - } while (parser.consumeIf(Token::comma)); - return success(); + /// Parse an arrow followed by a type list, which must have at least one type. + ParseResult parseOptionalArrowTypeList(SmallVectorImpl &result) { + if (!parser.consumeIf(Token::arrow)) + return success(); + return parser.parseFunctionResultTypes(result); } ParseResult parseTrailingOperandList(SmallVectorImpl &result, @@ -3313,6 +3312,11 @@ public: return parser.parseToken(Token::r_paren, "expected ')'"); } + /// Parses a ')' if present. + ParseResult parseOptionalRParen() override { + return success(parser.consumeIf(Token::r_paren)); + } + ParseResult parseOperandList(SmallVectorImpl &result, int requiredOperandCount = -1, Delimiter delimiter = Delimiter::None) override { @@ -3411,6 +3415,20 @@ public: return parser.parseRegion(region, regionArguments); } + /// Parses an optional region. + ParseResult parseOptionalRegion(Region ®ion, + ArrayRef arguments, + ArrayRef argTypes) override { + if (parser.getToken().isNot(Token::l_brace)) { + if (!arguments.empty()) + return emitError( + parser.getToken().getLoc(), + "optional region with explicit entry arguments must be defined"); + return success(); + } + return parseRegion(region, arguments, argTypes); + } + /// Parse a region argument. Region arguments define new values, so this also /// checks if the values with the same name has not been defined yet. The /// type of the argument will be resolved later by a call to `parseRegion`. @@ -3428,6 +3446,13 @@ public: return success(); } + /// Parse an optional region argument. + ParseResult parseOptionalRegionArgument(OperandType &argument) override { + if (parser.getToken().isNot(Token::percent_identifier)) + return success(); + return parseRegionArgument(argument); + } + //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// diff --git a/mlir/test/IR/func-op.mlir b/mlir/test/IR/func-op.mlir new file mode 100644 index 0000000..f264dc2 --- /dev/null +++ b/mlir/test/IR/func-op.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s | FileCheck %s + +// CHECK-LABEL: func @external_func +func @external_func() { + // CHECK-NEXT: func @external_func(i32, i64) + func @external_func(i32, i64) -> () + + // CHECK: func @external_func_with_result() -> (i1, i32) + func @external_func_with_result() -> (i1, i32) + return +} + +// CHECK-LABEL: func @complex_func +func @complex_func() { + // CHECK-NEXT: func @test_dimop(%i0: tensor<4x4x?xf32>) -> index { + func @test_dimop(%i0: tensor<4x4x?xf32>) -> index { + %0 = dim %i0, 2 : tensor<4x4x?xf32> + "foo.return"(%0) : (index) -> () + } + return +} + +// CHECK-LABEL: func @func_attributes +func @func_attributes() { + // CHECK-NEXT: func @foo() + // CHECK-NEXT: attributes {foo: true} + func @foo() attributes {foo: true} + return +} diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir new file mode 100644 index 0000000..a6e84f8 --- /dev/null +++ b/mlir/test/IR/invalid-func-op.mlir @@ -0,0 +1,37 @@ +// RUN: mlir-opt %s -split-input-file -verify + +// ----- + +func @func_op() { + // expected-error@+1 {{expected non-function type}} + func missingsigil() -> (i1, index, f32) + return +} + +// ----- + +func @func_op() { + // expected-error@+1 {{expected type instead of SSA identifier}} + func @mixed_named_arguments(f32, %a : i32) { + return + } + return +} + +// ----- + +func @func_op() { + // expected-error@+1 {{expected SSA identifier}} + func @mixed_named_arguments(%a : i32, f32) -> () { + return + } + return +} + +// ----- + +func @func_op() { + // expected-error@+2 {{optional region with explicit entry arguments must be defined}} + func @mixed_named_arguments(%a : i32) + return +} diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index a7e7035..9825648 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -310,12 +310,6 @@ func @undef() { // ----- -func @missing_rbrace() { - return -func @d() {return} // expected-error {{custom op 'func' is unknown}} - -// ----- - func @malformed_type(%a : intt) { // expected-error {{expected non-function type}} } -- 2.7.4