Start defining a new operation 'FuncOp' that replicates all of the functionality...
authorRiver Riddle <riverriddle@google.com>
Mon, 3 Jun 2019 19:08:22 +0000 (12:08 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 4 Jun 2019 02:26:46 +0000 (19:26 -0700)
PiperOrigin-RevId: 251281612

mlir/include/mlir/IR/Function.h
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/IR/Function.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/func-op.mlir [new file with mode: 0644]
mlir/test/IR/invalid-func-op.mlir [new file with mode: 0644]
mlir/test/IR/invalid.mlir

index 81ba845..ed7083e 100644 (file)
 #ifndef MLIR_IR_FUNCTION_H
 #define MLIR_IR_FUNCTION_H
 
-#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Block.h"
-#include "mlir/IR/Identifier.h"
-#include "mlir/IR/Location.h"
+#include "mlir/IR/OpDefinition.h"
 
 namespace mlir {
 class BlockAndValueMapping;
@@ -320,6 +318,85 @@ private:
   friend struct llvm::ilist_traits<Function>;
 };
 
+//===--------------------------------------------------------------------===//
+// Function Operation.
+//===--------------------------------------------------------------------===//
+
+/// FuncOp represents a function, or a named operation containing one region
+/// that forms a CFG(Control Flow Graph). The region of a function is not
+/// allowed to implicitly capture global values, and all external references
+/// must use Function arguments or attributes.
+class FuncOp : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
+                         OpTrait::NthRegionIsIsolatedAbove<0>::Impl> {
+public:
+  using Op::Op;
+  static StringRef getOperationName() { return "func"; }
+
+  static void build(Builder *builder, OperationState *result, StringRef name,
+                    FunctionType type, ArrayRef<NamedAttribute> attrs);
+
+  /// Parsing/Printing methods.
+  static ParseResult parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p);
+
+  /// Returns the name of this function.
+  StringRef getName() { return getAttrOfType<StringAttr>("name").getValue(); }
+
+  /// Returns the type of this function.
+  FunctionType getType() {
+    return getAttrOfType<TypeAttr>("type").getValue().cast<FunctionType>();
+  }
+
+  /// Returns true if this function is external, i.e. it has no body.
+  bool isExternal() { return empty(); }
+
+  //===--------------------------------------------------------------------===//
+  // Body Handling
+  //===--------------------------------------------------------------------===//
+
+  Region &getBody() { return getOperation()->getRegion(0); }
+
+  /// This is the list of blocks in the function.
+  using RegionType = Region::RegionType;
+  RegionType &getBlocks() { return getBody().getBlocks(); }
+
+  // Iteration over the block in the function.
+  using iterator = RegionType::iterator;
+  using reverse_iterator = RegionType::reverse_iterator;
+
+  iterator begin() { return getBody().begin(); }
+  iterator end() { return getBody().end(); }
+  reverse_iterator rbegin() { return getBody().rbegin(); }
+  reverse_iterator rend() { return getBody().rend(); }
+
+  bool empty() { return getBody().empty(); }
+  void push_back(Block *block) { getBody().push_back(block); }
+  void push_front(Block *block) { getBody().push_front(block); }
+
+  Block &back() { return getBody().back(); }
+  Block &front() { return getBody().front(); }
+
+  //===--------------------------------------------------------------------===//
+  // Argument Handling
+  //===--------------------------------------------------------------------===//
+
+  /// Returns number of arguments.
+  unsigned getNumArguments() { return getType().getInputs().size(); }
+
+  /// Gets argument.
+  BlockArgument *getArgument(unsigned idx) {
+    return getBlocks().front().getArgument(idx);
+  }
+
+  // Supports non-const operand iteration.
+  using args_iterator = Block::args_iterator;
+  args_iterator args_begin() { return front().args_begin(); }
+  args_iterator args_end() { return front().args_end(); }
+  llvm::iterator_range<args_iterator> getArguments() {
+    return {args_begin(), args_end()};
+  }
+};
+
 } // end namespace mlir
 
 //===----------------------------------------------------------------------===//
index 9168d28..67a173f 100644 (file)
@@ -169,6 +169,9 @@ public:
   /// Parse a ')' token.
   virtual ParseResult parseRParen() = 0;
 
+  /// Parses a ')' if present.
+  virtual ParseResult parseOptionalRParen() = 0;
+
   /// This parses an equal(=) token!
   virtual ParseResult parseEqual() = 0;
 
@@ -199,6 +202,10 @@ public:
   /// Parse a colon followed by a type list, which must have at least one type.
   virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
 
+  /// Parse an optional arrow followed by a type list.
+  virtual ParseResult
+  parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
+
   /// Parse a keyword followed by a type.
   ParseResult parseKeywordType(const char *keyword, Type &result) {
     return failure(parseKeyword(keyword) || parseType(result));
@@ -320,10 +327,18 @@ public:
                                   ArrayRef<OperandType> arguments,
                                   ArrayRef<Type> argTypes) = 0;
 
+  /// Parses an optional region.
+  virtual ParseResult parseOptionalRegion(Region &region,
+                                          ArrayRef<OperandType> arguments,
+                                          ArrayRef<Type> argTypes) = 0;
+
   /// Parse a region argument.  Region arguments define new values, so this also
   /// checks if the values with the same name has not been defined yet.
   virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
 
+  /// Parse an optional region argument.
+  virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0;
+
   //===--------------------------------------------------------------------===//
   // Methods for interacting with the parser
   //===--------------------------------------------------------------------===//
index 61e2975..f4ee155 100644 (file)
 
 #include "mlir/IR/Function.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
+#include "mlir/IR/OpImplementation.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/Twine.h"
@@ -216,3 +218,201 @@ void Function::walk(const std::function<void(Operation *)> &callback) {
   for (auto &block : getBlocks())
     block.walk(callback);
 }
+
+//===----------------------------------------------------------------------===//
+// Function Operation.
+//===----------------------------------------------------------------------===//
+
+void FuncOp::build(Builder *builder, OperationState *result, StringRef name,
+                   FunctionType type, ArrayRef<NamedAttribute> attrs) {
+  result->addAttribute("name", builder->getStringAttr(name));
+  result->addAttribute("type", builder->getTypeAttr(type));
+  result->attributes.append(attrs.begin(), attrs.end());
+  result->addRegion();
+}
+
+/// Parsing/Printing methods.
+static ParseResult
+parseArgumentList(OpAsmParser *parser, SmallVectorImpl<Type> &argTypes,
+                  SmallVectorImpl<OpAsmParser::OperandType> &argNames,
+                  SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
+  if (parser->parseLParen())
+    return failure();
+
+  // The argument list either has to consistently have ssa-id's followed by
+  // types, or just be a type list.  It isn't ok to sometimes have SSA ID's and
+  // sometimes not.
+  auto parseArgument = [&]() -> ParseResult {
+    llvm::SMLoc loc;
+    parser->getCurrentLocation(&loc);
+
+    // Parse argument name if present.
+    OpAsmParser::OperandType argument;
+    Type argumentType;
+    if (succeeded(parser->parseOptionalRegionArgument(argument)) &&
+        !argument.name.empty()) {
+      // Reject this if the preceding argument was missing a name.
+      if (argNames.empty() && !argTypes.empty())
+        return parser->emitError(loc,
+                                 "expected type instead of SSA identifier");
+      argNames.push_back(argument);
+
+      if (parser->parseColonType(argumentType))
+        return failure();
+    } else if (!argNames.empty()) {
+      // Reject this if the preceding argument had a name.
+      return parser->emitError(loc, "expected SSA identifier");
+    } else if (parser->parseType(argumentType)) {
+      return failure();
+    }
+
+    // Add the argument type.
+    argTypes.push_back(argumentType);
+
+    // TODO(riverriddle) Parse argument attributes.
+    // Parse the attribute dict.
+    // SmallVector<NamedAttribute, 2> attrs;
+    // if (parser->parseOptionalAttributeDict(attrs))
+    //  return failure();
+    // argAttrs.push_back(attrs);
+    return success();
+  };
+
+  // Parse the function arguments.
+  if (parser->parseOptionalRParen()) {
+    do {
+      if (parseArgument())
+        return failure();
+    } while (succeeded(parser->parseOptionalComma()));
+    parser->parseRParen();
+  }
+
+  return success();
+}
+
+/// Parse a function signature, starting with a name and including the
+/// parameter list.
+static ParseResult parseFunctionSignature(
+    OpAsmParser *parser, FunctionType &type,
+    SmallVectorImpl<OpAsmParser::OperandType> &argNames,
+    SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
+  SmallVector<Type, 4> argTypes;
+  if (parseArgumentList(parser, argTypes, argNames, argAttrs))
+    return failure();
+
+  // Parse the return types if present.
+  SmallVector<Type, 4> results;
+  if (parser->parseOptionalArrowTypeList(results))
+    return failure();
+  type = parser->getBuilder().getFunctionType(argTypes, results);
+  return success();
+}
+
+ParseResult FuncOp::parse(OpAsmParser *parser, OperationState *result) {
+  FunctionType type;
+  SmallVector<OpAsmParser::OperandType, 4> entryArgs;
+  SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
+  auto &builder = parser->getBuilder();
+
+  // Parse the name as a function attribute.
+  FunctionAttr nameAttr;
+  if (parser->parseAttribute(nameAttr, "name", result->attributes))
+    return failure();
+  // Convert the parsed function attr into a string attr.
+  result->attributes.back().second = builder.getStringAttr(nameAttr.getValue());
+
+  // Parse the function signature.
+  if (parseFunctionSignature(parser, type, entryArgs, argAttrs))
+    return failure();
+  result->addAttribute("type", builder.getTypeAttr(type));
+
+  // If function attributes are present, parse them.
+  if (succeeded(parser->parseOptionalKeyword("attributes")))
+    if (parser->parseOptionalAttributeDict(result->attributes))
+      return failure();
+
+  // TODO(riverriddle) Parse argument attributes.
+  // Add the attributes to the function arguments.
+  // for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i)
+  //   if (!argAttrs[i].empty())
+  //     result->addAttribute(("arg" + Twine(i)).str(),
+  //                          builder.getDictionaryAttr(argAttrs[i]));
+
+  // Parse the optional function body.
+  auto *body = result->addRegion();
+  if (parser->parseOptionalRegion(
+          *body, entryArgs, entryArgs.empty() ? llvm::None : type.getInputs()))
+    return failure();
+
+  return success();
+}
+
+static void printFunctionSignature(OpAsmPrinter *p, FuncOp op) {
+  *p << '(';
+
+  auto fnType = op.getType();
+  bool isExternal = op.isExternal();
+  for (unsigned i = 0, e = op.getNumArguments(); i != e; ++i) {
+    if (i > 0)
+      *p << ", ";
+
+    // If this is an external function, don't print argument labels.
+    if (!isExternal) {
+      p->printOperand(op.getArgument(i));
+      *p << ": ";
+    }
+
+    p->printType(fnType.getInput(i));
+
+    // TODO(riverriddle) Print argument attributes.
+    // Print the attributes for this argument.
+    // p->printOptionalAttrDict(op.getArgAttrs(i));
+  }
+  *p << ')';
+
+  switch (fnType.getResults().size()) {
+  case 0:
+    break;
+  case 1: {
+    *p << " -> ";
+    auto resultType = fnType.getResults()[0];
+    bool resultIsFunc = resultType.isa<FunctionType>();
+    if (resultIsFunc)
+      *p << '(';
+    p->printType(resultType);
+    if (resultIsFunc)
+      *p << ')';
+    break;
+  }
+  default:
+    *p << " -> (";
+    interleaveComma(fnType.getResults(), *p);
+    *p << ')';
+    break;
+  }
+}
+
+void FuncOp::print(OpAsmPrinter *p) {
+  *p << "func @" << getName();
+
+  // Print the signature.
+  printFunctionSignature(p, *this);
+
+  // Print out function attributes, if present.
+  auto attrs = getAttrs();
+
+  // We must have more attributes than <name, type>.
+  constexpr unsigned kNumHiddenAttrs = 2;
+  if (attrs.size() > kNumHiddenAttrs) {
+    *p << "\n  attributes ";
+    p->printOptionalAttrDict(attrs, {"name", "type"});
+  }
+
+  // Print the body if this is not an external function.
+  if (!isExternal()) {
+    p->printRegion(getBody(), /*printEntryBlockArgs=*/false,
+                   /*printBlockTerminators=*/true);
+    *p << '\n';
+  }
+  *p << '\n';
+}
index bf1ae96..678848b 100644 (file)
@@ -143,6 +143,10 @@ struct BuiltinDialect : public Dialect {
     addTypes<ComplexType, FloatType, FunctionType, IndexType, IntegerType,
              MemRefType, NoneType, OpaqueType, RankedTensorType, TupleType,
              UnrankedTensorType, VectorType>();
+
+    // TODO: FuncOp should be moved to a different dialect when it has been
+    // fully decoupled from the core.
+    addOperations<FuncOp>();
   }
 };
 
@@ -484,8 +488,9 @@ std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
 }
 
 void Dialect::addOperation(AbstractOperation opInfo) {
-  assert(opInfo.name.split('.').first == getNamespace() &&
-         "op name doesn't start with dialect namespace");
+  assert(getNamespace().empty() ||
+         opInfo.name.split('.').first == getNamespace() &&
+             "op name doesn't start with dialect namespace");
   assert(&opInfo.dialect == this && "Dialect object mismatch");
   auto &impl = context->getImpl();
 
index 87f94f7..635adec 100644 (file)
@@ -3219,15 +3219,14 @@ public:
   ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) override {
     if (parser.parseToken(Token::colon, "expected ':'"))
       return failure();
+    return parser.parseTypeListNoParens(result);
+  }
 
-    do {
-      if (auto type = parser.parseType())
-        result.push_back(type);
-      else
-        return failure();
-
-    } while (parser.consumeIf(Token::comma));
-    return success();
+  /// Parse an arrow followed by a type list, which must have at least one type.
+  ParseResult parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) {
+    if (!parser.consumeIf(Token::arrow))
+      return success();
+    return parser.parseFunctionResultTypes(result);
   }
 
   ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
@@ -3313,6 +3312,11 @@ public:
     return parser.parseToken(Token::r_paren, "expected ')'");
   }
 
+  /// Parses a ')' if present.
+  ParseResult parseOptionalRParen() override {
+    return success(parser.consumeIf(Token::r_paren));
+  }
+
   ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
                                int requiredOperandCount = -1,
                                Delimiter delimiter = Delimiter::None) override {
@@ -3411,6 +3415,20 @@ public:
     return parser.parseRegion(region, regionArguments);
   }
 
+  /// Parses an optional region.
+  ParseResult parseOptionalRegion(Region &region,
+                                  ArrayRef<OperandType> arguments,
+                                  ArrayRef<Type> argTypes) override {
+    if (parser.getToken().isNot(Token::l_brace)) {
+      if (!arguments.empty())
+        return emitError(
+            parser.getToken().getLoc(),
+            "optional region with explicit entry arguments must be defined");
+      return success();
+    }
+    return parseRegion(region, arguments, argTypes);
+  }
+
   /// Parse a region argument.  Region arguments define new values, so this also
   /// checks if the values with the same name has not been defined yet.  The
   /// type of the argument will be resolved later by a call to `parseRegion`.
@@ -3428,6 +3446,13 @@ public:
     return success();
   }
 
+  /// Parse an optional region argument.
+  ParseResult parseOptionalRegionArgument(OperandType &argument) override {
+    if (parser.getToken().isNot(Token::percent_identifier))
+      return success();
+    return parseRegionArgument(argument);
+  }
+
   //===--------------------------------------------------------------------===//
   // Methods for interacting with the parser
   //===--------------------------------------------------------------------===//
diff --git a/mlir/test/IR/func-op.mlir b/mlir/test/IR/func-op.mlir
new file mode 100644 (file)
index 0000000..f264dc2
--- /dev/null
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK-LABEL: func @external_func
+func @external_func() {
+  // CHECK-NEXT: func @external_func(i32, i64)
+  func @external_func(i32, i64) -> ()
+
+  // CHECK: func @external_func_with_result() -> (i1, i32)
+  func @external_func_with_result() -> (i1, i32)
+  return
+}
+
+// CHECK-LABEL: func @complex_func
+func @complex_func() {
+  // CHECK-NEXT: func @test_dimop(%i0: tensor<4x4x?xf32>) -> index {
+  func @test_dimop(%i0: tensor<4x4x?xf32>) -> index {
+    %0 = dim %i0, 2 : tensor<4x4x?xf32>
+    "foo.return"(%0) : (index) -> ()
+  }
+  return
+}
+
+// CHECK-LABEL: func @func_attributes
+func @func_attributes() {
+  // CHECK-NEXT: func @foo()
+  // CHECK-NEXT:   attributes {foo: true}
+  func @foo() attributes {foo: true}
+  return
+}
diff --git a/mlir/test/IR/invalid-func-op.mlir b/mlir/test/IR/invalid-func-op.mlir
new file mode 100644 (file)
index 0000000..a6e84f8
--- /dev/null
@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s -split-input-file -verify
+
+// -----
+
+func @func_op() {
+  // expected-error@+1 {{expected non-function type}}
+  func missingsigil() -> (i1, index, f32)
+  return
+}
+
+// -----
+
+func @func_op() {
+  // expected-error@+1 {{expected type instead of SSA identifier}}
+  func @mixed_named_arguments(f32, %a : i32) {
+    return
+  }
+  return
+}
+
+// -----
+
+func @func_op() {
+  // expected-error@+1 {{expected SSA identifier}}
+  func @mixed_named_arguments(%a : i32, f32) -> () {
+    return
+  }
+  return
+}
+
+// -----
+
+func @func_op() {
+  // expected-error@+2 {{optional region with explicit entry arguments must be defined}}
+  func @mixed_named_arguments(%a : i32)
+  return
+}
index a7e7035..9825648 100644 (file)
@@ -310,12 +310,6 @@ func @undef() {
 
 // -----
 
-func @missing_rbrace() {
-  return
-func @d() {return} // expected-error {{custom op 'func' is unknown}}
-
-// -----
-
 func @malformed_type(%a : intt) { // expected-error {{expected non-function type}}
 }