[mlir][bufferization][NFC] Make getEnclosingRepetitiveRegion public
authorMatthias Springer <springerm@google.com>
Fri, 13 Jan 2023 15:31:01 +0000 (16:31 +0100)
committerMatthias Springer <springerm@google.com>
Fri, 13 Jan 2023 15:39:41 +0000 (16:39 +0100)
These functions are generally useful and not specific to One-Shot Analysis. Move them to `BufferizableOpInterface.h` and make them public.

Differential Revision: https://reviews.llvm.org/D141685

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

index 24754d5..799aff9 100644 (file)
@@ -522,6 +522,19 @@ getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
 /// owner of the block. In case of an OpResult that is the defining op.
 Operation *getOwnerOfValue(Value value);
 
+/// Return the closest enclosing repetitive region around the given op.
+Region *getEnclosingRepetitiveRegion(Operation *op,
+                                     const BufferizationOptions &options);
+
+/// Return the closest enclosing repetitive region around the place where the
+/// given value is defined.
+Region *getEnclosingRepetitiveRegion(Value value,
+                                     const BufferizationOptions &options);
+
+/// Return the closest enclosing repetitive region around the given block.
+Region *getEnclosingRepetitiveRegion(Block *block,
+                                     const BufferizationOptions &options);
+
 namespace detail {
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
index af0d48a..9e7dbf5 100644 (file)
@@ -41,6 +41,39 @@ MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
 using namespace mlir;
 using namespace bufferization;
 
+Region *bufferization::getEnclosingRepetitiveRegion(
+    Operation *op, const BufferizationOptions &options) {
+  if (!op->getBlock())
+    return nullptr;
+  return getEnclosingRepetitiveRegion(op->getBlock(), options);
+}
+
+Region *bufferization::getEnclosingRepetitiveRegion(
+    Value value, const BufferizationOptions &options) {
+  Region *region = value.getParentRegion();
+  while (region) {
+    Operation *op = region->getParentOp();
+    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
+      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
+        return region;
+    region = op->getParentRegion();
+  }
+  return nullptr;
+}
+
+Region *bufferization::getEnclosingRepetitiveRegion(
+    Block *block, const BufferizationOptions &options) {
+  Region *region = block->getParent();
+  Operation *op = nullptr;
+  do {
+    op = region->getParentOp();
+    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
+      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
+        return region;
+  } while ((region = op->getParentRegion()));
+  return nullptr;
+}
+
 Operation *bufferization::getOwnerOfValue(Value value) {
   if (auto opResult = value.dyn_cast<OpResult>())
     return opResult.getDefiningOp();
index cd06899..7dfd626 100644 (file)
@@ -355,31 +355,6 @@ static bool happensBefore(Operation *a, Operation *b,
   return false;
 }
 
-static Region *
-getEnclosingRepetitiveRegion(Operation *op,
-                             const BufferizationOptions &options) {
-  while (Region *region = op->getParentRegion()) {
-    op = region->getParentOp();
-    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
-      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
-        return region;
-  }
-  return nullptr;
-}
-
-static Region *
-getEnclosingRepetitiveRegion(Value value, const BufferizationOptions &options) {
-  Region *region = value.getParentRegion();
-  while (region) {
-    Operation *op = region->getParentOp();
-    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
-      if (bufferizableOp.isRepetitiveRegion(region->getRegionNumber()))
-        return region;
-    region = op->getParentRegion();
-  }
-  return nullptr;
-}
-
 /// Return `true` if the given tensor value is a memory write. Most values are
 /// tensor writes, but ops that define a tensor SSA value without specifying its
 /// contents (e.g., alloc_tensor) are not.