Dialect Conversion: convert regions of operations when cloning them
authorAlex Zinenko <zinenko@google.com>
Thu, 28 Mar 2019 22:58:53 +0000 (15:58 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 30 Mar 2019 00:52:04 +0000 (17:52 -0700)
Dialect conversion currently clones the operations that did not match any
pattern.  This includes cloning any regions that belong to these operations.
Instead, apply conversion recursively to the nested regions.

Note that if an operation matched one of the conversion patterns, it is up to
the pattern rewriter to fill in the regions of the converted operation.  This
may require calling back to the converter and is left for future work.

PiperOrigin-RevId: 240872410

mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/Operation.h
mlir/lib/IR/Operation.cpp
mlir/lib/Transforms/DialectConversion.cpp

index 65f986b..4cbf4ee 100644 (file)
@@ -282,6 +282,20 @@ public:
     return cloneOp;
   }
 
+  /// Creates a deep copy of this operation but keep the operation regions
+  /// empty. Operands are remapped using `mapper` (if present), and `mapper` is
+  /// updated to contain the results.
+  Operation *cloneWithoutRegions(Operation &op, BlockAndValueMapping &mapper) {
+    Operation *cloneOp = op.cloneWithoutRegions(mapper, getContext());
+    block->getOperations().insert(insertPoint, cloneOp);
+    return cloneOp;
+  }
+  Operation *cloneWithoutRegions(Operation &op) {
+    Operation *cloneOp = op.cloneWithoutRegions(getContext());
+    block->getOperations().insert(insertPoint, cloneOp);
+    return cloneOp;
+  }
+
 private:
   Function *function;
   Block *block = nullptr;
index 2ffb91a..74a21fd 100644 (file)
@@ -90,6 +90,13 @@ public:
   Operation *clone(BlockAndValueMapping &mapper, MLIRContext *context);
   Operation *clone(MLIRContext *context);
 
+  /// Create a deep copy of this operation but keep the operation regions empty.
+  /// Operands are remapped using `mapper` (if present), and `mapper` is updated
+  /// to contain the results.
+  Operation *cloneWithoutRegions(BlockAndValueMapping &mapper,
+                                 MLIRContext *context);
+  Operation *cloneWithoutRegions(MLIRContext *context);
+
   /// Returns the operation block that contains this operation.
   Block *getBlock() { return block; }
 
index 3581038..6d727fe 100644 (file)
@@ -558,13 +558,11 @@ bool Operation::emitOpError(const Twine &message) {
 // Operation Cloning
 //===----------------------------------------------------------------------===//
 
-/// Create a deep copy of this operation, remapping any operands that use
-/// values outside of the operation using the map that is provided (leaving
-/// them alone if no entry is present).  Replaces references to cloned
-/// sub-operations to the corresponding operation that is copied, and adds
-/// those mappings to the map.
-Operation *Operation::clone(BlockAndValueMapping &mapper,
-                            MLIRContext *context) {
+/// Create a deep copy of this operation but keep the operation regions empty.
+/// Operands are remapped using `mapper` (if present), and `mapper` is updated
+/// to contain the results.
+Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper,
+                                          MLIRContext *context) {
   SmallVector<Value *, 8> operands;
   SmallVector<Block *, 2> successors;
 
@@ -607,13 +605,31 @@ Operation *Operation::clone(BlockAndValueMapping &mapper,
                                   attrs, successors, numRegions,
                                   hasResizableOperandsList(), context);
 
+  // Remember the mapping of any results.
+  for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+    mapper.map(getResult(i), newOp->getResult(i));
+
+  return newOp;
+}
+
+Operation *Operation::cloneWithoutRegions(MLIRContext *context) {
+  BlockAndValueMapping mapper;
+  return cloneWithoutRegions(mapper, context);
+}
+
+/// Create a deep copy of this operation, remapping any operands that use
+/// values outside of the operation using the map that is provided (leaving
+/// them alone if no entry is present).  Replaces references to cloned
+/// sub-operations to the corresponding operation that is copied, and adds
+/// those mappings to the map.
+Operation *Operation::clone(BlockAndValueMapping &mapper,
+                            MLIRContext *context) {
+  auto *newOp = cloneWithoutRegions(mapper, context);
+
   // Clone the regions.
   for (unsigned i = 0; i != numRegions; ++i)
     getRegion(i).cloneInto(&newOp->getRegion(i), mapper, context);
 
-  // Remember the mapping of any results.
-  for (unsigned i = 0, e = getNumResults(); i != e; ++i)
-    mapper.map(getResult(i), newOp->getResult(i));
   return newOp;
 }
 
index 2d16f23..49fb937 100644 (file)
@@ -56,6 +56,12 @@ private:
   // `dialectConversion`.  Returns the converted function or `nullptr` on error.
   Function *convertFunction(Function *f);
 
+  // Converts the given region starting from the entry block and following the
+  // block successors.  Returns the converted region or `nullptr` on error.
+  template <typename RegionParent>
+  std::unique_ptr<Region> convertRegion(MLIRContext *context, Region *region,
+                                        RegionParent *parent);
+
   // Converts an operation with successors.  Extracts the converted operands
   // from `valueRemapping` and the converted blocks from `blockRemapping`, and
   // passes them to `converter->rewriteTerminator` function defined in the
@@ -171,11 +177,6 @@ impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder,
 
   // Iterate over ops and convert them.
   for (Operation &op : *block) {
-    if (op.getNumRegions() != 0) {
-      op.emitError("unsupported region operation");
-      return failure();
-    }
-
     // Find the first matching conversion and apply it.
     bool converted = false;
     for (auto *conversion : conversions) {
@@ -191,9 +192,15 @@ impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder,
       converted = true;
       break;
     }
-    // If there is no conversion provided for the op, clone the op as is.
-    if (!converted)
-      builder.clone(op, mapping);
+    // If there is no conversion provided for the op, clone the op and convert
+    // its regions, if any.
+    if (!converted) {
+      auto *newOp = builder.cloneWithoutRegions(op, mapping);
+      for (int i = 0, e = op.getNumRegions(); i < e; ++i) {
+        auto newRegion = convertRegion(op.getContext(), &op.getRegion(i), &op);
+        newOp->getRegion(i).takeBody(*newRegion);
+      }
+    }
   }
 
   // Recurse to children unless they have been already visited.
@@ -206,6 +213,49 @@ impl::FunctionConversion::convertBlock(Block *block, FuncBuilder &builder,
   return success();
 }
 
+template <typename RegionParent>
+std::unique_ptr<Region>
+impl::FunctionConversion::convertRegion(MLIRContext *context, Region *region,
+                                        RegionParent *parent) {
+  assert(region && "expected a region");
+  auto newRegion = llvm::make_unique<Region>(parent);
+  if (region->empty())
+    return newRegion;
+
+  auto emitError = [context](llvm::Twine f) -> std::unique_ptr<Region> {
+    context->emitError(UnknownLoc::get(context), f.str());
+    return nullptr;
+  };
+
+  // Create new blocks and convert their arguments.
+  for (Block &block : *region) {
+    auto *newBlock = new Block;
+    newRegion->push_back(newBlock);
+    mapping.map(&block, newBlock);
+    for (auto *arg : block.getArguments()) {
+      auto convertedType = dialectConversion->convertType(arg->getType());
+      if (!convertedType)
+        return emitError("could not convert block argument type");
+      newBlock->addArgument(convertedType);
+      mapping.map(arg, *newBlock->args_rbegin());
+    }
+  }
+
+  // Start a DFS-order traversal of the CFG to make sure defs are converted
+  // before uses in dominated blocks.
+  llvm::DenseSet<Block *> visitedBlocks;
+  FuncBuilder builder(&newRegion->front());
+  if (failed(convertBlock(&region->front(), builder, visitedBlocks)))
+    return nullptr;
+
+  // If some blocks are not reachable through successor chains, they should have
+  // been removed by the DCE before this.
+
+  if (visitedBlocks.size() != std::distance(region->begin(), region->end()))
+    return emitError("unreachable blocks were not converted");
+  return newRegion;
+}
+
 Function *impl::FunctionConversion::convertFunction(Function *f) {
   assert(f && "expected function");
   MLIRContext *context = f->getContext();
@@ -229,30 +279,10 @@ Function *impl::FunctionConversion::convertFunction(Function *f) {
   if (f->getBlocks().empty())
     return newFunction.release();
 
-  // Create blocks in the new function and convert types of their arguments.
-  FuncBuilder builder(newFunction.get());
-  for (Block &block : *f) {
-    auto *newBlock = builder.createBlock();
-    mapping.map(&block, newBlock);
-    for (auto *arg : block.getArguments()) {
-      auto convertedType = dialectConversion->convertType(arg->getType());
-      if (!convertedType)
-        return emitError("could not convert block argument type");
-      newBlock->addArgument(convertedType);
-      mapping.map(arg, *newBlock->args_rbegin());
-    }
-  }
-
-  // Start a DFS-order traversal of the CFG to make sure defs are converted
-  // before uses in dominated blocks.
-  llvm::DenseSet<Block *> visitedBlocks;
-  if (failed(convertBlock(&f->front(), builder, visitedBlocks)))
-    return nullptr;
-
-  // If some blocks are not reachable through successor chains, they should have
-  // been removed by the DCE before this.
-  if (visitedBlocks.size() != f->getBlocks().size())
-    return emitError("unreachable blocks were not converted");
+  auto newBody = convertRegion(context, &f->getBody(), f);
+  if (!newBody)
+    return emitError("could not convert function body");
+  newFunction->getBody().takeBody(*newBody);
 
   return newFunction.release();
 }