Refactor DialectConversion to operate on functions in-place *without* any cloning...
authorRiver Riddle <riverriddle@google.com>
Wed, 22 May 2019 18:49:04 +0000 (11:49 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:56:24 +0000 (19:56 -0700)
--

PiperOrigin-RevId: 249490306

mlir/include/mlir/IR/Function.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/DialectConversion.cpp

index 1a7332d..4a9de2b 100644 (file)
@@ -178,6 +178,15 @@ public:
     assert(index < getNumArguments() && "invalid argument number");
     argAttrs[index].setAttrs(attributes);
   }
+  void setArgAttrs(unsigned index, NamedAttributeList attributes) {
+    assert(index < getNumArguments() && "invalid argument number");
+    argAttrs[index] = attributes;
+  }
+  void setAllArgAttrs(ArrayRef<NamedAttributeList> attributes) {
+    assert(attributes.size() == getNumArguments());
+    for (unsigned i = 0, e = attributes.size(); i != e; ++i)
+      argAttrs[i] = attributes[i];
+  }
 
   /// Return all argument attributes of this function.
   MutableArrayRef<NamedAttributeList> getAllArgAttrs() { return argAttrs; }
index 8bbce1d..3facb7a 100644 (file)
@@ -86,10 +86,9 @@ public:
   }
 
   /// Rewrite the IR rooted at the specified operation with the result of
-  /// this pattern, generating any new operations with the specified
-  /// builder. If an unexpected error is encountered (an internal
-  /// compiler error), it is emitted through the normal MLIR diagnostic
-  /// hooks and the IR is left in a valid state.
+  /// this pattern. If an unexpected error is encountered (an internal compiler
+  /// error), it is emitted through the normal MLIR diagnostic hooks and the IR
+  /// is left in a valid state.
   void rewrite(Operation *op, PatternRewriter &rewriter) const final;
 
 private:
@@ -118,7 +117,9 @@ private:
 /// If any error happened during the conversion, the pass fails as soon as
 /// possible.
 ///
-/// If the conversion fails, the module is not modified.
+/// If conversion fails for a specific function, that functions remains
+/// unmodified. Otherwise, successfully converted functions will remain
+/// converted.
 class DialectConversion {
 public:
   virtual ~DialectConversion() = default;
index ff53880..de66d01 100644 (file)
@@ -26,71 +26,65 @@ using namespace mlir;
 using namespace mlir::impl;
 
 //===----------------------------------------------------------------------===//
-// ProducerGenerator
+// ArgConverter
 //===----------------------------------------------------------------------===//
 namespace {
-/// This class provides a simple interface for generating fake producers during
-/// the conversion process. These fake producers are used when replacing the
-/// results of an operation with values of a new, legal, type. The producer
-/// provides a definition for the remaining uses of the old value while they
-/// await conversion.
-struct ProducerGenerator {
-  ProducerGenerator(MLIRContext *ctx)
-      : producerOpName(kProducerName, ctx), loc(UnknownLoc::get(ctx)) {}
-
-  /// Cleanup any generated conversion values. Returns failure if there are any
-  /// dangling references to a producer operation, success otherwise.
-  LogicalResult cleanupGeneratedOps() {
-    for (auto *op : producerOps) {
-      if (!op->use_empty()) {
-        auto diag = op->getContext()->emitError(loc)
-                    << "Converter did not convert all uses of replaced value "
-                       "with illegal type";
-        for (auto *user : op->getResult(0)->getUsers())
-          diag.attachNote(user->getLoc())
-              << "user was not converted : " << *user;
-        return diag;
-      }
+/// This class provides a simple interface for converting the types of block
+/// arguments. This is done by inserting fake cast operations for the illegal
+/// type that allow for updating the real type to return the correct type.
+struct ArgConverter {
+  ArgConverter(MLIRContext *ctx)
+      : castOpName(kCastName, ctx), loc(UnknownLoc::get(ctx)) {}
+
+  /// Cleanup and undo any generated conversion values.
+  void discardRewrites() {
+    // On failure drop all uses of the cast operation and destroy it.
+    for (auto *op : castOps) {
+      op->getResult(0)->dropAllUses();
+      op->destroy();
+    }
+    castOps.clear();
+  }
+
+  /// Replace usages of the cast operations with the argument directly.
+  void applyRewrites() {
+    // On success, we update the type of the block argument and replace uses of
+    // the cast.
+    for (auto *op : castOps) {
+      op->getOperand(0)->setType(op->getResult(0)->getType());
+      op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
       op->destroy();
     }
-    return success();
   }
 
-  /// Generate a producer value for 'oldValue'. These new producers replace all
-  /// of the current uses of the original value, and record a mapping between
-  /// for replacement with the 'newValue'.
-  void generateAndReplace(Value *oldValue, Value *newValue,
-                          BlockAndValueMapping &mapping) {
-    if (oldValue->use_empty())
-      return;
-
-    // Otherwise, generate a new producer operation for the given value type.
-    auto *producer = Operation::create(
-        loc, producerOpName, llvm::None, oldValue->getType(), llvm::None,
-        llvm::None, 0, false, oldValue->getContext());
-
-    // Replace the uses of the old value and record the mapping.
-    oldValue->replaceAllUsesWith(producer->getResult(0));
-    mapping.map(producer->getResult(0), newValue);
-    producerOps.push_back(producer);
+  /// Generate a cast operation for 'arg' that produces the new, legal, type.
+  void castArgument(BlockArgument *arg, Type newType,
+                    BlockAndValueMapping &mapping) {
+    // Otherwise, generate a new cast operation for the given value type.
+    auto *cast = Operation::create(loc, castOpName, arg, newType, llvm::None,
+                                   llvm::None, 0, false, arg->getContext());
+
+    // Replace the uses of the argument and record the mapping.
+    mapping.map(arg, cast->getResult(0));
+    castOps.push_back(cast);
   }
 
   /// This is an operation name for a fake operation that is inserted during the
   /// conversion process. Operations of this type are guaranteed to never escape
   /// the converter.
-  static constexpr StringLiteral kProducerName = "__mlir_conversion.producer";
-  OperationName producerOpName;
+  static constexpr StringLiteral kCastName = "__mlir_conversion.cast";
+  OperationName castOpName;
 
-  /// This is a collection of producer values that were generated during the
+  /// This is a collection of cast values that were generated during the
   /// conversion process.
-  std::vector<Operation *> producerOps;
+  std::vector<Operation *> castOps;
 
   /// An instance of the unknown location that is used when generating
   /// producers.
   UnknownLoc loc;
 };
 
-constexpr StringLiteral ProducerGenerator::kProducerName;
+constexpr StringLiteral ArgConverter::kCastName;
 
 //===----------------------------------------------------------------------===//
 // DialectConversionRewriter
@@ -99,43 +93,91 @@ constexpr StringLiteral ProducerGenerator::kProducerName;
 /// This class implements a pattern rewriter for DialectConversionPattern
 /// patterns. It automatically performs remapping of replaced operation values.
 struct DialectConversionRewriter final : public PatternRewriter {
+  /// This class represents one requested operation replacement via 'replaceOp'.
+  struct OpReplacement {
+    OpReplacement() = default;
+    OpReplacement(Operation *op, ArrayRef<Value *> newValues)
+        : op(op), newValues(newValues.begin(), newValues.end()) {}
+
+    Operation *op;
+    SmallVector<Value *, 2> newValues;
+  };
+
   DialectConversionRewriter(Function *fn)
-      : PatternRewriter(fn), tempGenerator(fn->getContext()) {}
+      : PatternRewriter(fn), argConverter(fn->getContext()) {}
   ~DialectConversionRewriter() = default;
 
-  // Implement the hook for replacing an operation with new values.
+  /// Cleanup and destroy any generated rewrite operations. This method is
+  /// invoked when the conversion process fails.
+  void discardRewrites() {
+    argConverter.discardRewrites();
+    for (auto *op : createdOps) {
+      op->dropAllDefinedValueUses();
+      op->erase();
+    }
+  }
+
+  /// Apply all requested operation rewrites. This method is invoked when the
+  /// conversion process succeeds.
+  void applyRewrites() {
+    argConverter.applyRewrites();
+
+    // Apply all of the rewrites replacements requested during conversion.
+    for (auto &repl : replacements) {
+      for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i)
+        repl.op->getResult(i)->replaceAllUsesWith(repl.newValues[i]);
+      repl.op->erase();
+    }
+  }
+
+  /// PatternRewriter hook for replacing the results of an operation.
   void replaceOp(Operation *op, ArrayRef<Value *> newValues,
                  ArrayRef<Value *> valuesToRemoveIfDead) override {
     assert(newValues.size() == op->getNumResults());
-    for (unsigned i = 0, e = newValues.size(); i < e; ++i) {
-      Value *result = op->getResult(i);
-      if (result->getType() != newValues[i]->getType())
-        tempGenerator.generateAndReplace(result, newValues[i], mapping);
-      else
-        result->replaceAllUsesWith(newValues[i]);
-    }
-    op->erase();
+    // Create mappings for any type changes.
+    for (unsigned i = 0, e = newValues.size(); i < e; ++i)
+      if (op->getResult(i)->getType() != newValues[i]->getType())
+        mapping.map(op->getResult(i), newValues[i]);
+
+    // Record the requested operation replacement.
+    replacements.emplace_back(op, newValues);
   }
 
-  // Implement the hook for creating operations, and make sure that newly
-  // created ops are added to the worklist for processing.
+  /// PatternRewriter hook for creating a new operation.
   Operation *createOperation(const OperationState &state) override {
-    return FuncBuilder::createOperation(state);
+    auto *result = FuncBuilder::createOperation(state);
+    createdOps.push_back(result);
+    return result;
+  }
+
+  /// PatternRewriter hook for updating the root operation in-place.
+  void notifyRootUpdated(Operation *op) override {
+    // The rewriter caches changes to the IR to allow for operating in-place and
+    // backtracking. The rewrite is currently not capable of backtracking
+    // in-place modifications.
+    llvm_unreachable("in-place operation updates are not supported");
   }
 
-  void lookupValues(Operation::operand_range operands,
-                    SmallVectorImpl<Value *> &remapped) {
+  /// Remap the given operands to those with potentially different types.
+  void remapValues(Operation::operand_range operands,
+                   SmallVectorImpl<Value *> &remapped) {
     remapped.reserve(llvm::size(operands));
     for (Value *operand : operands)
       remapped.push_back(mapping.lookupOrDefault(operand));
   }
 
-  // Mapping between values(blocks) in the original function and in the new
-  // function.
+  // Mapping between replaced values that differ in type. This happens when
+  // replacing a value with one of a different type.
   BlockAndValueMapping mapping;
 
-  /// Utility used to create temporary producers operations.
-  ProducerGenerator tempGenerator;
+  /// Utility used to convert block arguments.
+  ArgConverter argConverter;
+
+  /// Ordered vector of all of the newly created operations during conversion.
+  SmallVector<Operation *, 4> createdOps;
+
+  /// Ordered vector of any requested operation replacements.
+  SmallVector<OpReplacement, 4> replacements;
 };
 } // end anonymous namespace
 
@@ -143,16 +185,15 @@ struct DialectConversionRewriter final : public PatternRewriter {
 // DialectConversionPattern
 //===----------------------------------------------------------------------===//
 
-/// Rewrite the IR rooted at the specified operation with the result of
-/// this pattern, generating any new operations with the specified
-/// builder.  If an unexpected error is encountered (an internal
-/// compiler error), it is emitted through the normal MLIR diagnostic
-/// hooks and the IR is left in a valid state.
+/// Rewrite the IR rooted at the specified operation with the result of this
+/// pattern.  If an unexpected error is encountered (an internal compiler
+/// error), it is emitted through the normal MLIR diagnostic hooks and the IR is
+/// left in a valid state.
 void DialectConversionPattern::rewrite(Operation *op,
                                        PatternRewriter &rewriter) const {
   SmallVector<Value *, 4> operands;
   auto &dialectRewriter = static_cast<DialectConversionRewriter &>(rewriter);
-  dialectRewriter.lookupValues(op->getOperands(), operands);
+  dialectRewriter.remapValues(op->getOperands(), operands);
 
   // If this operation has no successors, invoke the rewrite directly.
   if (op->getNumSuccessors() == 0)
@@ -185,10 +226,8 @@ void DialectConversionPattern::rewrite(Operation *op,
 // FunctionConverter
 //===----------------------------------------------------------------------===//
 namespace {
-// Implementation detail class of the DialectConversion utility.  Performs
-// function-by-function conversions by creating new functions, filling them in
-// with converted blocks, updating the function attributes, and replacing the
-// old functions with the new ones in the module.
+// This class converts a single function using a given DialectConversion
+// structure.
 class FunctionConverter {
 public:
   // Constructs a FunctionConverter.
@@ -196,31 +235,31 @@ public:
                              RewritePatternMatcher &matcher)
       : dialectConversion(conversion), matcher(matcher) {}
 
-  // Converts the given function to the dialect using hooks defined in
-  // `dialectConversion`.  Returns the converted function or `nullptr` on error.
-  Function *convertFunction(Function *f);
+  /// Converts the given function to the dialect using hooks defined in
+  /// `dialectConversion`. Returns failure on error, success otherwise.
+  LogicalResult convertFunction(Function *f);
 
-  // Converts the given region starting from the entry block and following the
-  // block successors. Returns failure on error, success otherwise.
+  /// Converts the given region starting from the entry block and following the
+  /// block successors. Returns failure on error, success otherwise.
   template <typename RegionParent>
   LogicalResult convertRegion(DialectConversionRewriter &rewriter,
                               Region &region, RegionParent *parent);
 
-  // Converts a block by traversing its operations sequentially, attempting to
-  // match a pattern. If there is no match, recurses the operations regions if
-  // it has any.
+  /// Converts a block by traversing its operations sequentially, attempting to
+  /// match a pattern. If there is no match, recurses the operations regions if
+  /// it has any.
   //
-  // After converting operations, traverses the successor blocks unless they
-  // have been visited already as indicated in `visitedBlocks`.
+  /// After converting operations, traverses the successor blocks unless they
+  /// have been visited already as indicated in `visitedBlocks`.
   LogicalResult convertBlock(DialectConversionRewriter &rewriter, Block *block,
                              DenseSet<Block *> &visitedBlocks);
 
-  // Converts the type of the given block argument. Returns success if the
-  // argument type could be successfully converted, failure otherwise.
+  /// Converts the type of the given block argument. Returns success if the
+  /// argument type could be successfully converted, failure otherwise.
   LogicalResult convertArgument(DialectConversionRewriter &rewriter,
                                 BlockArgument *arg, Location loc);
 
-  // Pointer to a specific dialect conversion info.
+  /// Pointer to a specific dialect conversion info.
   DialectConversion *dialectConversion;
 
   /// The matcher to use when converting operations.
@@ -237,10 +276,8 @@ FunctionConverter::convertArgument(DialectConversionRewriter &rewriter,
            << "could not convert block argument of type : " << arg->getType();
 
   // Generate a replacement value, with the new type, for this argument.
-  if (convertedType != arg->getType()) {
-    rewriter.tempGenerator.generateAndReplace(arg, arg, rewriter.mapping);
-    arg->setType(convertedType);
-  }
+  if (convertedType != arg->getType())
+    rewriter.argConverter.castArgument(arg, convertedType, rewriter.mapping);
   return success();
 }
 
@@ -260,11 +297,6 @@ FunctionConverter::convertBlock(DialectConversionRewriter &rewriter,
     if (matcher.matchAndRewrite(&op, rewriter))
       continue;
 
-    // If a rewrite wasn't matched, update any mapped operands in place.
-    for (auto &operand : op.getOpOperands())
-      if (auto *newOperand = rewriter.mapping.lookupOrNull(operand.get()))
-        operand.set(newOperand);
-
     // Traverse any held regions.
     for (auto &region : op.getRegions())
       if (!region.empty() && failed(convertRegion(rewriter, region, &op)))
@@ -306,44 +338,46 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter,
   return success();
 }
 
-Function *FunctionConverter::convertFunction(Function *f) {
-  // Convert the function type using the dialect converter.
-  SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
-  Type newFunctionType = dialectConversion->convertFunctionSignatureType(
-      f->getType(), f->getAllArgAttrs(), newFunctionArgAttrs);
-  if (!newFunctionType)
-    return f->emitError("could not convert function type"), nullptr;
-
-  // Create a new function using the mapped function type and arg attributes.
-  auto *newFunc = new Function(f->getLoc(), f->getName().strref(),
-                               newFunctionType.cast<FunctionType>(),
-                               f->getAttrs(), newFunctionArgAttrs);
-  f->getModule()->getFunctions().push_back(newFunc);
-
-  // If this is not an external function, we need to convert the body.
-  if (!f->isExternal()) {
-    DialectConversionRewriter rewriter(f);
-    f->getBody().cloneInto(&newFunc->getBody(), rewriter.mapping,
-                           f->getContext());
-    rewriter.mapping.clear();
-    if (failed(convertRegion(rewriter, newFunc->getBody(), &*newFunc))) {
-      f->getModule()->getFunctions().pop_back();
-      return nullptr;
-    }
+LogicalResult FunctionConverter::convertFunction(Function *f) {
+  // If this is an external function, there is nothing else to do.
+  if (f->isExternal())
+    return success();
 
-    // Cleanup any temp producer operations that were generated by the rewriter.
-    if (failed(rewriter.tempGenerator.cleanupGeneratedOps())) {
-      f->getModule()->getFunctions().pop_back();
-      return nullptr;
-    }
+  // Rewrite the function body.
+  DialectConversionRewriter rewriter(f);
+  if (failed(convertRegion(rewriter, f->getBody(), f))) {
+    // Reset any of the converted arguments.
+    rewriter.argConverter.discardRewrites();
+    return failure();
   }
-  return newFunc;
+
+  // Otherwise the conversion succeeded, so apply all rewrites.
+  rewriter.applyRewrites();
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // DialectConversion
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// This class represents a function to be converted. It allows for converting
+/// the body of functions and the signature in two phases.
+struct ConvertedFunction {
+  ConvertedFunction(Function *fn, FunctionType newType,
+                    ArrayRef<NamedAttributeList> newFunctionArgAttrs)
+      : fn(fn), newType(newType),
+        newFunctionArgAttrs(newFunctionArgAttrs.begin(),
+                            newFunctionArgAttrs.end()) {}
+
+  /// The function to convert.
+  Function *fn;
+  /// The new type and argument attributes for the function.
+  FunctionType newType;
+  SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
+};
+} // end anonymous namespace
+
 // Create a function type with arguments and results converted, and argument
 // attributes passed through.
 FunctionType DialectConversion::convertFunctionSignatureType(
@@ -386,40 +420,35 @@ LogicalResult DialectConversion::convert(Module *module) {
   initConverters(patterns, context);
   RewritePatternMatcher matcher(std::move(patterns));
 
-  SmallVector<Function *, 0> originalFuncs, convertedFuncs;
-  DenseMap<Attribute, FunctionAttr> functionAttrRemapping;
-  originalFuncs.reserve(module->getFunctions().size());
-  for (auto &func : *module)
-    originalFuncs.push_back(&func);
-  convertedFuncs.reserve(originalFuncs.size());
-
-  // Convert each function.
-  FunctionConverter converter(context, this, matcher);
-  for (auto *func : originalFuncs) {
-    Function *converted = converter.convertFunction(func);
-    if (!converted) {
-      // Make sure to erase any previously converted functions.
-      while (!convertedFuncs.empty())
-        convertedFuncs.pop_back_val()->erase();
+  // Try to convert each of the functions within the module. Defer updating the
+  // signatures of the functions until after all of the bodies have been
+  // converted. This allows for the conversion patterns to still rely on the
+  // public signatures of the functions within the module before they are
+  // updated.
+  std::vector<ConvertedFunction> toConvert;
+  toConvert.reserve(module->getFunctions().size());
+  for (auto &func : *module) {
+    // Convert the function type using the dialect converter.
+    SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
+    FunctionType newType = convertFunctionSignatureType(
+        func.getType(), func.getAllArgAttrs(), newFunctionArgAttrs);
+    if (!newType || !newType.isa<FunctionType>())
+      return func.emitError("could not convert function type");
+
+    // Convert the body of this function.
+    FunctionConverter converter(context, this, matcher);
+    if (failed(converter.convertFunction(&func)))
       return failure();
-    }
 
-    convertedFuncs.push_back(converted);
-    auto origFuncAttr = FunctionAttr::get(func);
-    auto convertedFuncAttr = FunctionAttr::get(converted);
-    functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
+    // Add function signature to be updated.
+    toConvert.emplace_back(&func, newType.cast<FunctionType>(),
+                           newFunctionArgAttrs);
   }
 
-  // Remap function attributes in the converted functions. Original functions
-  // will disappear anyway so there is no need to remap attributes in them.
-  for (const auto &funcPair : functionAttrRemapping)
-    remapFunctionAttrs(*funcPair.getSecond().getValue(), functionAttrRemapping);
-
-  // Remove the original functions from the module and update the names of the
-  // converted functions.
-  for (unsigned i = 0, e = originalFuncs.size(); i != e; ++i) {
-    convertedFuncs[i]->takeName(*originalFuncs[i]);
-    originalFuncs[i]->erase();
+  // Finally, update the signatures of all of the converted functions.
+  for (auto &it : toConvert) {
+    it.fn->setType(it.newType);
+    it.fn->setAllArgAttrs(it.newFunctionArgAttrs);
   }
 
   return success();