[mlir][CallOpInterface] Add `setCalleeFromCallable` method
authorWhitney Tsang <whitney.tsang@intel.com>
Mon, 8 May 2023 13:07:10 +0000 (06:07 -0700)
committerWhitney Tsang <whitney.tsang@intel.com>
Mon, 8 May 2023 13:07:10 +0000 (06:07 -0700)
Currently `CallOpInterface` has a method `getCallableForCallee` to have a consistent way to get the callee from an operation with `CallOpInterface`, but missing a consistent way to set a callee for an operation with `CallOpInterface`.

A set callee method is useful for transformations that operate on `CallOpInterface`, and change the callee, e.g., a pass that specialize function, which clone the callee, and change the `CallOpInterface`'s callee to the cloned version. Without such method, transformation would need to understand the implementation for every operations with `CallOpInterface`, and have a type switch to handle them.

This review adds a method to set callee for operation with `CallOpInterface`.

Reviewed By: gysit, zero9178o

Differential Revision: https://reviews.llvm.org/D149763

14 files changed:
flang/include/flang/Optimizer/Dialect/FIROps.td
mlir/docs/Interfaces.md
mlir/docs/Tutorials/Toy/Ch-4.md
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
mlir/include/mlir/Dialect/Func/IR/FuncOps.td
mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/Interfaces/CallInterfaces.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
mlir/test/lib/Dialect/Test/TestOps.td

index 2bc4ec0..0e07e6f 100644 (file)
@@ -2357,6 +2357,14 @@ def fir_CallOp : fir_Op<"call",
         return calling;
       return getOperand(0);
     }
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(mlir::CallInterfaceCallable callee) {
+      if (auto calling =
+          (*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
+        (*this)->setAttr(getCalleeAttrName(), callee.get<mlir::SymbolRefAttr>());
+      setOperand(0, callee.get<mlir::Value>());
+    }
   }];
 }
 
index b51adec..a299feb 100644 (file)
@@ -728,6 +728,7 @@ interface section goes as follows:
 
 *   `CallOpInterface` - Used to represent operations like 'call'
     -   `CallInterfaceCallable getCallableForCallee()`
+    -   `void setCalleeFromCallable(CallInterfaceCallable)`
 *   `CallableOpInterface` - Used to represent the target callee of call.
     -   `Region * getCallableRegion()`
     -   `ArrayRef<Type> getCallableResults()`
index f462274..9ca9706 100644 (file)
@@ -189,6 +189,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
   return getAttrOfType<SymbolRefAttr>("callee");
 }
 
+/// Set the callee for the generic call operation, this is required by the call
+/// interface.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+}
+
 /// Get the argument operands to the called function, this is required by the
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
index 75a5171..d533e58 100644 (file)
@@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
   return (*this)->getAttrOfType<SymbolRefAttr>("callee");
 }
 
+/// Set the callee for the generic call operation, this is required by the call
+/// interface.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+}
+
 /// Get the argument operands to the called function, this is required by the
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
index 98c8eb5..4f03266 100644 (file)
@@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
   return (*this)->getAttrOfType<SymbolRefAttr>("callee");
 }
 
+/// Set the callee for the generic call operation, this is required by the call
+/// interface.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+}
+
 /// Get the argument operands to the called function, this is required by the
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
index 98c8eb5..4f03266 100644 (file)
@@ -338,6 +338,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
   return (*this)->getAttrOfType<SymbolRefAttr>("callee");
 }
 
+/// Set the callee for the generic call operation, this is required by the call
+/// interface.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+}
+
 /// Get the argument operands to the called function, this is required by the
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
index 5fcb0be..6432403 100644 (file)
@@ -367,6 +367,12 @@ CallInterfaceCallable GenericCallOp::getCallableForCallee() {
   return (*this)->getAttrOfType<SymbolRefAttr>("callee");
 }
 
+/// Set the callee for the generic call operation, this is required by the call
+/// interface.
+void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+}
+
 /// Get the argument operands to the called function, this is required by the
 /// call interface.
 Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
index 30147b8..9824238 100644 (file)
@@ -271,6 +271,11 @@ def Async_CallOp : Async_Op<"call",
     CallInterfaceCallable getCallableForCallee() {
       return (*this)->getAttrOfType<SymbolRefAttr>("callee");
     }
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+    }
   }];
 
   let assemblyFormat = [{
index 4204bc5..fb206f1 100644 (file)
@@ -91,6 +91,11 @@ def CallOp : Func_Op<"call",
     CallInterfaceCallable getCallableForCallee() {
       return (*this)->getAttrOfType<SymbolRefAttr>("callee");
     }
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+    }
   }];
 
   let assemblyFormat = [{
@@ -153,6 +158,11 @@ def CallIndirectOp : Func_Op<"call_indirect", [
 
     /// Return the callee of this operation.
     CallInterfaceCallable getCallableForCallee() { return getCallee(); }
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(CallInterfaceCallable callee) {
+      setOperand(0, callee.get<Value>());
+    }
   }];
 
   let hasCanonicalizeMethod = 1;
index 53c1c0a..8154835 100644 (file)
@@ -372,6 +372,10 @@ def IncludeOp : TransformDialectOp<"include",
       return getTarget();
     }
 
+    void setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
+      setTargetAttr(callee.get<SymbolRefAttr>());
+    }
+
     ::mlir::Operation::operand_range getArgOperands() {
       return getOperands();
     }
index cd37222..328b3d5 100644 (file)
@@ -41,6 +41,15 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
       "::mlir::CallInterfaceCallable", "getCallableForCallee"
     >,
     InterfaceMethod<[{
+        Sets the callee of this call-like operation. A `callee` is either a
+        reference to a symbol, via SymbolRefAttr, or a reference to a defined
+        SSA value. The type of the `callee` is expected to be the same as the
+        return type of `getCallableForCallee`, e.g., `callee` should be
+        SymbolRefAttr for `func.call`.
+      }],
+      "void", "setCalleeFromCallable", (ins "::mlir::CallInterfaceCallable":$callee)
+    >,
+    InterfaceMethod<[{
         Returns the operands within this call that are used as arguments to the
         callee.
       }],
index 9595c18..5380ba0 100644 (file)
@@ -933,6 +933,16 @@ CallInterfaceCallable CallOp::getCallableForCallee() {
   return getOperand(0);
 }
 
+void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  // Direct call.
+  if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
+    auto symRef = callee.get<SymbolRefAttr>();
+    return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
+  }
+  // Indirect call, callee Value is the first operand.
+  return setOperand(0, callee.get<Value>());
+}
+
 Operation::operand_range CallOp::getArgOperands() {
   return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
@@ -1157,6 +1167,16 @@ CallInterfaceCallable InvokeOp::getCallableForCallee() {
   return getOperand(0);
 }
 
+void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
+  // Direct call.
+  if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
+    auto symRef = callee.get<SymbolRefAttr>();
+    return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
+  }
+  // Indirect call, callee Value is the first operand.
+  return setOperand(0, callee.get<Value>());
+}
+
 Operation::operand_range InvokeOp::getArgOperands() {
   return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
 }
index 181c9e0..2ad2497 100644 (file)
@@ -2576,6 +2576,11 @@ CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
   return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
 }
 
+void spirv::FunctionCallOp::setCalleeFromCallable(
+    CallInterfaceCallable callee) {
+  (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
+}
+
 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
   return getArguments();
 }
index 60faf6d..507f4aa 100644 (file)
@@ -495,11 +495,18 @@ def ConversionCallOp : TEST_Op<"conversion_call_op",
   let extraClassDeclaration = [{
     /// Return the callee of this operation.
     ::mlir::CallInterfaceCallable getCallableForCallee();
+
+    /// Set the callee for this operation.
+    void setCalleeFromCallable(::mlir::CallInterfaceCallable);
   }];
   let extraClassDefinition = [{
     ::mlir::CallInterfaceCallable $cppClass::getCallableForCallee() {
       return (*this)->getAttrOfType<::mlir::SymbolRefAttr>("callee");
     }
+
+    void $cppClass::setCalleeFromCallable(::mlir::CallInterfaceCallable callee) {
+      (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
+    }
   }];
 }