[mlir][bufferization] Restrict function boundary buffer. to func.call.
authorIngo Müller <ingomueller@google.com>
Fri, 10 Feb 2023 11:31:35 +0000 (11:31 +0000)
committerIngo Müller <ingomueller@google.com>
Fri, 10 Feb 2023 11:59:06 +0000 (11:59 +0000)
The current bufferization on function boundaries works on `func.func`
and any call op implementing `CallOpInterface`. Then, an error is thrown
if there is a `CallOpInterface` op that is not `func.call`. This is
unnecessary and breaks the pass whenever such an op occurs (such as
`llvm.call`). This PR simply restricts the handling of call ops to
`func.call`.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D143724

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

index 9562ac5..c96a507 100644 (file)
@@ -237,7 +237,7 @@ static void removeBufferizationAttributes(BlockArgument bbArg) {
 }
 
 /// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(CallOpInterface callOp) {
+static func::FuncOp getCalledFunction(func::CallOp callOp) {
   SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
   if (!sym)
     return nullptr;
@@ -278,15 +278,15 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
 /// callee-caller order (i.e. callees without callers first).
 /// Store the map of FuncOp to all its callers in `callerMap`.
 /// Return `failure()` if a cycle of calls is detected or if we are unable to
-/// retrieve the called FuncOp from any CallOpInterface.
+/// retrieve the called FuncOp from any func::CallOp.
 static LogicalResult
 getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                          SmallVectorImpl<func::FuncOp> &orderedFuncOps,
                          FuncCallerMap &callerMap) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
-  // symbols of all nested CallOpInterfaceOp).
+  // symbols of all nested func::CallOp).
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
-  // For each FuncOp, the number of CallOpInterface it contains.
+  // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
     if (!funcOp.getBody().empty()) {
@@ -298,10 +298,7 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
     }
 
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
-      // Only support CallOp for now.
-      if (!isa<func::CallOp>(callOp.getOperation()))
-        return callOp->emitError() << "expected a CallOp";
+    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
       func::FuncOp calledFunction = getCalledFunction(callOp);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       callerMap[calledFunction].insert(callOp);
index 2c0c8d7..759f4f3 100644 (file)
@@ -1,16 +1,5 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
 
-func.func private @foo() -> tensor<?xf32>
-
-func.func @bar() -> tensor<?xf32> {
-  %foo = constant @foo : () -> (tensor<?xf32>)
-// expected-error @+1 {{expected a CallOp}}
-  %res = call_indirect %foo() : () -> (tensor<?xf32>)
-  return %res : tensor<?xf32>
-}
-
-// -----
-
 // expected-error @+2 {{cannot bufferize bodiless function that returns a tensor}}
 // expected-error @+1 {{failed to bufferize op}}
 func.func private @foo() -> tensor<?xf32>
index 1980991..4103a4c 100644 (file)
@@ -625,3 +625,14 @@ func.func @main() {
 
 // This function may write to buffer(%ptr).
 func.func private @maybe_writing_func(%ptr : tensor<*xf32>)
+
+// -----
+
+// Test if other callables are left intact and don't cause trouble.
+
+llvm.func @llvm_func()
+
+func.func @call_llvm_func() {
+  llvm.call @llvm_func() : () -> ()
+  return
+}