Adding additional dialect parsing utilities, conversion wrappers, and traversal helpers.
authorBen Vanik <benvanik@google.com>
Tue, 4 Jun 2019 18:35:21 +0000 (11:35 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:16:59 +0000 (16:16 -0700)
- added a typed walk to Block (matching the equivalent on Function)
- added token parsers (incl optional variants) for : and (
- added applyConversionPatterns that takes a list of functions to apply patterns to

PiperOrigin-RevId: 251481608

mlir/include/mlir/IR/Block.h
mlir/include/mlir/IR/OpImplementation.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Parser/Parser.cpp
mlir/lib/Transforms/DialectConversion.cpp

index 93d4aff..3c627b4 100644 (file)
@@ -249,6 +249,14 @@ public:
   /// each operation.
   void walk(const std::function<void(Operation *)> &callback);
 
+  /// Specialization of walk to only visit operations of 'OpTy'.
+  template <typename OpTy> void walk(std::function<void(OpTy)> callback) {
+    walk([&](Operation *opInst) {
+      if (auto op = dyn_cast<OpTy>(opInst))
+        callback(op);
+    });
+  }
+
   /// Walk the operations in the specified [begin, end) range of this block in
   /// postorder, calling the callback for each operation.
   void walk(Block::iterator begin, Block::iterator end,
index 67a173f..1ebed7c 100644 (file)
@@ -163,16 +163,22 @@ public:
   /// Parse a `:` token.
   virtual ParseResult parseColon() = 0;
 
+  /// Parse a `:` token if present.
+  virtual ParseResult parseOptionalColon() = 0;
+
   /// Parse a '(' token.
   virtual ParseResult parseLParen() = 0;
 
+  /// Parse a '(' token if present.
+  virtual ParseResult parseOptionalLParen() = 0;
+
   /// Parse a ')' token.
   virtual ParseResult parseRParen() = 0;
 
-  /// Parses a ')' if present.
+  /// Parse a ')' token if present.
   virtual ParseResult parseOptionalRParen() = 0;
 
-  /// This parses an equal(=) token!
+  /// Parse a '=' token.
   virtual ParseResult parseEqual() = 0;
 
   /// Parse a type.
index ac24252..3886f0c 100644 (file)
@@ -234,6 +234,16 @@ LLVM_NODISCARD LogicalResult applyConversionPatterns(
     Module &module, ConversionTarget &target, TypeConverter &converter,
     OwningRewritePatternList &&patterns);
 
+/// Convert the given functions with the provided conversion patterns. This will
+/// convert as many of the operations within each function as possible given the
+/// set of patterns. If conversion fails for specific functions, those functions
+/// remains unmodified.
+LLVM_NODISCARD
+LogicalResult applyConversionPatterns(ArrayRef<Function *> fns,
+                                      ConversionTarget &target,
+                                      TypeConverter &converter,
+                                      OwningRewritePatternList &&patterns);
+
 /// Convert the given function with the provided conversion patterns. This will
 /// convert as many of the operations within 'fn' as possible given the set of
 /// patterns.
index 635adec..6cc933a 100644 (file)
@@ -3197,12 +3197,23 @@ public:
     *loc = parser.getToken().getLoc();
     return success();
   }
+
   ParseResult parseComma() override {
     return parser.parseToken(Token::comma, "expected ','");
   }
+
+  ParseResult parseOptionalComma() override {
+    return success(parser.consumeIf(Token::comma));
+  }
+
   ParseResult parseColon() override {
     return parser.parseToken(Token::colon, "expected ':'");
   }
+
+  ParseResult parseOptionalColon() override {
+    return success(parser.consumeIf(Token::colon));
+  }
+
   ParseResult parseEqual() override {
     return parser.parseToken(Token::equal, "expected '='");
   }
@@ -3242,10 +3253,6 @@ public:
     return success();
   }
 
-  ParseResult parseOptionalComma() override {
-    return success(parser.consumeIf(Token::comma));
-  }
-
   /// Parse an optional keyword.
   ParseResult parseOptionalKeyword(const char *keyword) override {
     // Check that the current token is a bare identifier or keyword.
@@ -3308,6 +3315,11 @@ public:
     return parser.parseToken(Token::l_paren, "expected '('");
   }
 
+  /// Parses a '(' if present.
+  ParseResult parseOptionalLParen() override {
+    return success(parser.consumeIf(Token::l_paren));
+  }
+
   ParseResult parseRParen() override {
     return parser.parseToken(Token::r_paren, "expected ')'");
   }
index 4c110d1..6002cad 100644 (file)
@@ -642,8 +642,26 @@ LogicalResult
 mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
                               TypeConverter &converter,
                               OwningRewritePatternList &&patterns) {
+  std::vector<Function *> allFunctions;
+  allFunctions.reserve(module.getFunctions().size());
+  for (auto &func : module)
+    allFunctions.push_back(&func);
+  return applyConversionPatterns(allFunctions, target, converter,
+                                 std::move(patterns));
+}
+
+/// Convert the given functions with the provided conversion patterns. This will
+/// convert as many of the operations within each function as possible given the
+/// set of patterns. If conversion fails for specific functions, those functions
+// remains unmodified.
+LogicalResult mlir::applyConversionPatterns(
+    ArrayRef<Function *> fns, ConversionTarget &target,
+    TypeConverter &converter, OwningRewritePatternList &&patterns) {
+  if (fns.empty())
+    return success();
+
   // Build the function converter.
-  FunctionConverter funcConverter(module.getContext(), target, patterns,
+  FunctionConverter funcConverter(fns.front()->getContext(), target, patterns,
                                   &converter);
 
   // Try to convert each of the functions within the module. Defer updating the
@@ -652,21 +670,21 @@ mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
   // public signatures of the functions within the module before they are
   // updated.
   std::vector<ConvertedFunction> toConvert;
-  toConvert.reserve(module.getFunctions().size());
-  for (auto &func : module) {
+  toConvert.reserve(fns.size());
+  for (auto *func : fns) {
     // Convert the function type using the dialect converter.
     SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
     FunctionType newType = converter.convertFunctionSignatureType(
-        func.getType(), func.getAllArgAttrs(), newFunctionArgAttrs);
+        func->getType(), func->getAllArgAttrs(), newFunctionArgAttrs);
     if (!newType || !newType.isa<FunctionType>())
-      return func.emitError("could not convert function type");
+      return func->emitError("could not convert function type");
 
     // Convert the body of this function.
-    if (failed(funcConverter.convertFunction(&func)))
+    if (failed(funcConverter.convertFunction(func)))
       return failure();
 
     // Add function signature to be updated.
-    toConvert.emplace_back(&func, newType.cast<FunctionType>(),
+    toConvert.emplace_back(func, newType.cast<FunctionType>(),
                            newFunctionArgAttrs);
   }