#ifndef MLIR_IR_FUNCTION_H
#define MLIR_IR_FUNCTION_H
-#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
-#include "mlir/IR/Identifier.h"
-#include "mlir/IR/Location.h"
+#include "mlir/IR/OpDefinition.h"
namespace mlir {
class BlockAndValueMapping;
friend struct llvm::ilist_traits<Function>;
};
+//===--------------------------------------------------------------------===//
+// Function Operation.
+//===--------------------------------------------------------------------===//
+
+/// FuncOp represents a function, or a named operation containing one region
+/// that forms a CFG(Control Flow Graph). The region of a function is not
+/// allowed to implicitly capture global values, and all external references
+/// must use Function arguments or attributes.
+class FuncOp : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
+ OpTrait::NthRegionIsIsolatedAbove<0>::Impl> {
+public:
+ using Op::Op;
+ static StringRef getOperationName() { return "func"; }
+
+ static void build(Builder *builder, OperationState *result, StringRef name,
+ FunctionType type, ArrayRef<NamedAttribute> attrs);
+
+ /// Parsing/Printing methods.
+ static ParseResult parse(OpAsmParser *parser, OperationState *result);
+ void print(OpAsmPrinter *p);
+
+ /// Returns the name of this function.
+ StringRef getName() { return getAttrOfType<StringAttr>("name").getValue(); }
+
+ /// Returns the type of this function.
+ FunctionType getType() {
+ return getAttrOfType<TypeAttr>("type").getValue().cast<FunctionType>();
+ }
+
+ /// Returns true if this function is external, i.e. it has no body.
+ bool isExternal() { return empty(); }
+
+ //===--------------------------------------------------------------------===//
+ // Body Handling
+ //===--------------------------------------------------------------------===//
+
+ Region &getBody() { return getOperation()->getRegion(0); }
+
+ /// This is the list of blocks in the function.
+ using RegionType = Region::RegionType;
+ RegionType &getBlocks() { return getBody().getBlocks(); }
+
+ // Iteration over the block in the function.
+ using iterator = RegionType::iterator;
+ using reverse_iterator = RegionType::reverse_iterator;
+
+ iterator begin() { return getBody().begin(); }
+ iterator end() { return getBody().end(); }
+ reverse_iterator rbegin() { return getBody().rbegin(); }
+ reverse_iterator rend() { return getBody().rend(); }
+
+ bool empty() { return getBody().empty(); }
+ void push_back(Block *block) { getBody().push_back(block); }
+ void push_front(Block *block) { getBody().push_front(block); }
+
+ Block &back() { return getBody().back(); }
+ Block &front() { return getBody().front(); }
+
+ //===--------------------------------------------------------------------===//
+ // Argument Handling
+ //===--------------------------------------------------------------------===//
+
+ /// Returns number of arguments.
+ unsigned getNumArguments() { return getType().getInputs().size(); }
+
+ /// Gets argument.
+ BlockArgument *getArgument(unsigned idx) {
+ return getBlocks().front().getArgument(idx);
+ }
+
+ // Supports non-const operand iteration.
+ using args_iterator = Block::args_iterator;
+ args_iterator args_begin() { return front().args_begin(); }
+ args_iterator args_end() { return front().args_end(); }
+ llvm::iterator_range<args_iterator> getArguments() {
+ return {args_begin(), args_end()};
+ }
+};
+
} // end namespace mlir
//===----------------------------------------------------------------------===//
/// Parse a ')' token.
virtual ParseResult parseRParen() = 0;
+ /// Parses a ')' if present.
+ virtual ParseResult parseOptionalRParen() = 0;
+
/// This parses an equal(=) token!
virtual ParseResult parseEqual() = 0;
/// Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
+ /// Parse an optional arrow followed by a type list.
+ virtual ParseResult
+ parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
+
/// Parse a keyword followed by a type.
ParseResult parseKeywordType(const char *keyword, Type &result) {
return failure(parseKeyword(keyword) || parseType(result));
ArrayRef<OperandType> arguments,
ArrayRef<Type> argTypes) = 0;
+ /// Parses an optional region.
+ virtual ParseResult parseOptionalRegion(Region ®ion,
+ ArrayRef<OperandType> arguments,
+ ArrayRef<Type> argTypes) = 0;
+
/// Parse a region argument. Region arguments define new values, so this also
/// checks if the values with the same name has not been defined yet.
virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
+ /// Parse an optional region argument.
+ virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0;
+
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//
#include "mlir/IR/Function.h"
#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/Twine.h"
for (auto &block : getBlocks())
block.walk(callback);
}
+
+//===----------------------------------------------------------------------===//
+// Function Operation.
+//===----------------------------------------------------------------------===//
+
+void FuncOp::build(Builder *builder, OperationState *result, StringRef name,
+ FunctionType type, ArrayRef<NamedAttribute> attrs) {
+ result->addAttribute("name", builder->getStringAttr(name));
+ result->addAttribute("type", builder->getTypeAttr(type));
+ result->attributes.append(attrs.begin(), attrs.end());
+ result->addRegion();
+}
+
+/// Parsing/Printing methods.
+static ParseResult
+parseArgumentList(OpAsmParser *parser, SmallVectorImpl<Type> &argTypes,
+ SmallVectorImpl<OpAsmParser::OperandType> &argNames,
+ SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
+ if (parser->parseLParen())
+ return failure();
+
+ // The argument list either has to consistently have ssa-id's followed by
+ // types, or just be a type list. It isn't ok to sometimes have SSA ID's and
+ // sometimes not.
+ auto parseArgument = [&]() -> ParseResult {
+ llvm::SMLoc loc;
+ parser->getCurrentLocation(&loc);
+
+ // Parse argument name if present.
+ OpAsmParser::OperandType argument;
+ Type argumentType;
+ if (succeeded(parser->parseOptionalRegionArgument(argument)) &&
+ !argument.name.empty()) {
+ // Reject this if the preceding argument was missing a name.
+ if (argNames.empty() && !argTypes.empty())
+ return parser->emitError(loc,
+ "expected type instead of SSA identifier");
+ argNames.push_back(argument);
+
+ if (parser->parseColonType(argumentType))
+ return failure();
+ } else if (!argNames.empty()) {
+ // Reject this if the preceding argument had a name.
+ return parser->emitError(loc, "expected SSA identifier");
+ } else if (parser->parseType(argumentType)) {
+ return failure();
+ }
+
+ // Add the argument type.
+ argTypes.push_back(argumentType);
+
+ // TODO(riverriddle) Parse argument attributes.
+ // Parse the attribute dict.
+ // SmallVector<NamedAttribute, 2> attrs;
+ // if (parser->parseOptionalAttributeDict(attrs))
+ // return failure();
+ // argAttrs.push_back(attrs);
+ return success();
+ };
+
+ // Parse the function arguments.
+ if (parser->parseOptionalRParen()) {
+ do {
+ if (parseArgument())
+ return failure();
+ } while (succeeded(parser->parseOptionalComma()));
+ parser->parseRParen();
+ }
+
+ return success();
+}
+
+/// Parse a function signature, starting with a name and including the
+/// parameter list.
+static ParseResult parseFunctionSignature(
+ OpAsmParser *parser, FunctionType &type,
+ SmallVectorImpl<OpAsmParser::OperandType> &argNames,
+ SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
+ SmallVector<Type, 4> argTypes;
+ if (parseArgumentList(parser, argTypes, argNames, argAttrs))
+ return failure();
+
+ // Parse the return types if present.
+ SmallVector<Type, 4> results;
+ if (parser->parseOptionalArrowTypeList(results))
+ return failure();
+ type = parser->getBuilder().getFunctionType(argTypes, results);
+ return success();
+}
+
+ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) {
+ FunctionType type;
+ SmallVector<OpAsmParser::OperandType, 4> entryArgs;
+ SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
+ auto &builder = parser->getBuilder();
+
+ // Parse the name as a function attribute.
+ FunctionAttr nameAttr;
+ if (parser->parseAttribute(nameAttr, "name", result->attributes))
+ return failure();
+ // Convert the parsed function attr into a string attr.
+ result->attributes.back().second = builder.getStringAttr(nameAttr.getValue());
+
+ // Parse the function signature.
+ if (parseFunctionSignature(parser, type, entryArgs, argAttrs))
+ return failure();
+ result->addAttribute("type", builder.getTypeAttr(type));
+
+ // If function attributes are present, parse them.
+ if (succeeded(parser->parseOptionalKeyword("attributes")))
+ if (parser->parseOptionalAttributeDict(result->attributes))
+ return failure();
+
+ // TODO(riverriddle) Parse argument attributes.
+ // Add the attributes to the function arguments.
+ // for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+ // if (!argAttrs[i].empty())
+ // result->addAttribute(("arg" + Twine(i)).str(),
+ // builder.getDictionaryAttr(argAttrs[i]));
+
+ // Parse the optional function body.
+ auto *body = result->addRegion();
+ if (parser->parseOptionalRegion(
+ *body, entryArgs, entryArgs.empty() ? llvm::None : type.getInputs()))
+ return failure();
+
+ return success();
+}
+
+static void printFunctionSignature(OpAsmPrinter *p, FuncOp op) {
+ *p << '(';
+
+ auto fnType = op.getType();
+ bool isExternal = op.isExternal();
+ for (unsigned i = 0, e = op.getNumArguments(); i != e; ++i) {
+ if (i > 0)
+ *p << ", ";
+
+ // If this is an external function, don't print argument labels.
+ if (!isExternal) {
+ p->printOperand(op.getArgument(i));
+ *p << ": ";
+ }
+
+ p->printType(fnType.getInput(i));
+
+ // TODO(riverriddle) Print argument attributes.
+ // Print the attributes for this argument.
+ // p->printOptionalAttrDict(op.getArgAttrs(i));
+ }
+ *p << ')';
+
+ switch (fnType.getResults().size()) {
+ case 0:
+ break;
+ case 1: {
+ *p << " -> ";
+ auto resultType = fnType.getResults()[0];
+ bool resultIsFunc = resultType.isa<FunctionType>();
+ if (resultIsFunc)
+ *p << '(';
+ p->printType(resultType);
+ if (resultIsFunc)
+ *p << ')';
+ break;
+ }
+ default:
+ *p << " -> (";
+ interleaveComma(fnType.getResults(), *p);
+ *p << ')';
+ break;
+ }
+}
+
+void FuncOp::print(OpAsmPrinter *p) {
+ *p << "func @" << getName();
+
+ // Print the signature.
+ printFunctionSignature(p, *this);
+
+ // Print out function attributes, if present.
+ auto attrs = getAttrs();
+
+ // We must have more attributes than <name, type>.
+ constexpr unsigned kNumHiddenAttrs = 2;
+ if (attrs.size() > kNumHiddenAttrs) {
+ *p << "\n attributes ";
+ p->printOptionalAttrDict(attrs, {"name", "type"});
+ }
+
+ // Print the body if this is not an external function.
+ if (!isExternal()) {
+ p->printRegion(getBody(), /*printEntryBlockArgs=*/false,
+ /*printBlockTerminators=*/true);
+ *p << '\n';
+ }
+ *p << '\n';
+}
addTypes<ComplexType, FloatType, FunctionType, IndexType, IntegerType,
MemRefType, NoneType, OpaqueType, RankedTensorType, TupleType,
UnrankedTensorType, VectorType>();
+
+ // TODO: FuncOp should be moved to a different dialect when it has been
+ // fully decoupled from the core.
+ addOperations<FuncOp>();
}
};
}
void Dialect::addOperation(AbstractOperation opInfo) {
- assert(opInfo.name.split('.').first == getNamespace() &&
- "op name doesn't start with dialect namespace");
+ assert(getNamespace().empty() ||
+ opInfo.name.split('.').first == getNamespace() &&
+ "op name doesn't start with dialect namespace");
assert(&opInfo.dialect == this && "Dialect object mismatch");
auto &impl = context->getImpl();
ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
if (parser.parseToken(Token::colon, "expected ':'"))
return failure();
+ return parser.parseTypeListNoParens(result);
+ }
- do {
- if (auto type = parser.parseType())
- result.push_back(type);
- else
- return failure();
-
- } while (parser.consumeIf(Token::comma));
- return success();
+ /// Parse an arrow followed by a type list, which must have at least one type.
+ ParseResult parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) {
+ if (!parser.consumeIf(Token::arrow))
+ return success();
+ return parser.parseFunctionResultTypes(result);
}
ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
return parser.parseToken(Token::r_paren, "expected ')'");
}
+ /// Parses a ')' if present.
+ ParseResult parseOptionalRParen() override {
+ return success(parser.consumeIf(Token::r_paren));
+ }
+
ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
int requiredOperandCount = -1,
Delimiter delimiter = Delimiter::None) override {
return parser.parseRegion(region, regionArguments);
}
+ /// Parses an optional region.
+ ParseResult parseOptionalRegion(Region ®ion,
+ ArrayRef<OperandType> arguments,
+ ArrayRef<Type> argTypes) override {
+ if (parser.getToken().isNot(Token::l_brace)) {
+ if (!arguments.empty())
+ return emitError(
+ parser.getToken().getLoc(),
+ "optional region with explicit entry arguments must be defined");
+ return success();
+ }
+ return parseRegion(region, arguments, argTypes);
+ }
+
/// Parse a region argument. Region arguments define new values, so this also
/// checks if the values with the same name has not been defined yet. The
/// type of the argument will be resolved later by a call to `parseRegion`.
return success();
}
+ /// Parse an optional region argument.
+ ParseResult parseOptionalRegionArgument(OperandType &argument) override {
+ if (parser.getToken().isNot(Token::percent_identifier))
+ return success();
+ return parseRegionArgument(argument);
+ }
+
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//
--- /dev/null
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK-LABEL: func @external_func
+func @external_func() {
+ // CHECK-NEXT: func @external_func(i32, i64)
+ func @external_func(i32, i64) -> ()
+
+ // CHECK: func @external_func_with_result() -> (i1, i32)
+ func @external_func_with_result() -> (i1, i32)
+ return
+}
+
+// CHECK-LABEL: func @complex_func
+func @complex_func() {
+ // CHECK-NEXT: func @test_dimop(%i0: tensor<4x4x?xf32>) -> index {
+ func @test_dimop(%i0: tensor<4x4x?xf32>) -> index {
+ %0 = dim %i0, 2 : tensor<4x4x?xf32>
+ "foo.return"(%0) : (index) -> ()
+ }
+ return
+}
+
+// CHECK-LABEL: func @func_attributes
+func @func_attributes() {
+ // CHECK-NEXT: func @foo()
+ // CHECK-NEXT: attributes {foo: true}
+ func @foo() attributes {foo: true}
+ return
+}
--- /dev/null
+// RUN: mlir-opt %s -split-input-file -verify
+
+// -----
+
+func @func_op() {
+ // expected-error@+1 {{expected non-function type}}
+ func missingsigil() -> (i1, index, f32)
+ return
+}
+
+// -----
+
+func @func_op() {
+ // expected-error@+1 {{expected type instead of SSA identifier}}
+ func @mixed_named_arguments(f32, %a : i32) {
+ return
+ }
+ return
+}
+
+// -----
+
+func @func_op() {
+ // expected-error@+1 {{expected SSA identifier}}
+ func @mixed_named_arguments(%a : i32, f32) -> () {
+ return
+ }
+ return
+}
+
+// -----
+
+func @func_op() {
+ // expected-error@+2 {{optional region with explicit entry arguments must be defined}}
+ func @mixed_named_arguments(%a : i32)
+ return
+}
// -----
-func @missing_rbrace() {
- return
-func @d() {return} // expected-error {{custom op 'func' is unknown}}
-
-// -----
-
func @malformed_type(%a : intt) { // expected-error {{expected non-function type}}
}