[flang] Hanlde disptach op in abstract result pass
authorValentin Clement <clementval@gmail.com>
Mon, 28 Nov 2022 15:49:25 +0000 (16:49 +0100)
committerValentin Clement <clementval@gmail.com>
Mon, 28 Nov 2022 15:49:51 +0000 (16:49 +0100)
Update the call conversion pattern to support fir.dispatch
operation as well. The first operand of fir.dispatch op is always the
polymoprhic object. The pass_arg_pos attribute needs to be shifted when
the result is added as argument.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D138799

flang/lib/Optimizer/Dialect/FIRType.cpp
flang/lib/Optimizer/Transforms/AbstractResult.cpp
flang/test/Fir/abstract-results.fir

index eb9b9af..89a806c 100644 (file)
@@ -960,7 +960,7 @@ bool fir::hasAbstractResult(mlir::FunctionType ty) {
   if (ty.getNumResults() == 0)
     return false;
   auto resultType = ty.getResult(0);
-  return resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>();
+  return resultType.isa<fir::SequenceType, fir::BaseBoxType, fir::RecordType>();
 }
 
 /// Convert llvm::Type::TypeID to mlir::Type. \p kind is provided for error
index dcc6e90..df00c17 100644 (file)
@@ -28,6 +28,8 @@ namespace fir {
 
 #define DEBUG_TYPE "flang-abstract-result-opt"
 
+using namespace mlir;
+
 namespace fir {
 namespace {
 
@@ -40,7 +42,7 @@ static mlir::Type getResultArgumentType(mlir::Type resultType,
               return fir::BoxType::get(type);
             return fir::ReferenceType::get(type);
           })
-      .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
+      .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
         return fir::ReferenceType::get(type);
       })
       .Default([](mlir::Type) -> mlir::Type {
@@ -75,16 +77,18 @@ static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
          shouldBoxResult;
 }
 
-class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
+template <typename Op>
+class CallConversion : public mlir::OpRewritePattern<Op> {
 public:
-  using OpRewritePattern::OpRewritePattern;
-  CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
-      : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
+  using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+  CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
+      : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
+
   mlir::LogicalResult
-  matchAndRewrite(fir::CallOp callOp,
-                  mlir::PatternRewriter &rewriter) const override {
-    auto loc = callOp.getLoc();
-    auto result = callOp->getResult(0);
+  matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto result = op->getResult(0);
     if (!result.hasOneUse()) {
       mlir::emitError(loc,
                       "calls with abstract result must have exactly one user");
@@ -109,50 +113,74 @@ public:
     // TODO: This should be generalized for derived types, and it is
     // architecture and OS dependent.
     bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
-    fir::CallOp newCallOp;
+    Op newOp;
     if (isResultBuiltinCPtr) {
-      auto recTy = result.getType().dyn_cast<fir::RecordType>();
+      auto recTy = result.getType().template dyn_cast<fir::RecordType>();
       newResultTypes.emplace_back(recTy.getTypeList()[0].second);
     }
-    if (callOp.getCallee()) {
+
+    // fir::CallOp specific handling.
+    if constexpr (std::is_same_v<Op, fir::CallOp>) {
+      if (op.getCallee()) {
+        llvm::SmallVector<mlir::Value> newOperands;
+        if (!isResultBuiltinCPtr)
+          newOperands.emplace_back(arg);
+        newOperands.append(op.getOperands().begin(), op.getOperands().end());
+        newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
+                                             newResultTypes, newOperands);
+      } else {
+        // Indirect calls.
+        llvm::SmallVector<mlir::Type> newInputTypes;
+        if (!isResultBuiltinCPtr)
+          newInputTypes.emplace_back(argType);
+        for (auto operand : op.getOperands().drop_front())
+          newInputTypes.push_back(operand.getType());
+        auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
+                                                 newResultTypes);
+
+        llvm::SmallVector<mlir::Value> newOperands;
+        newOperands.push_back(
+            rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
+        if (!isResultBuiltinCPtr)
+          newOperands.push_back(arg);
+        newOperands.append(op.getOperands().begin() + 1,
+                           op.getOperands().end());
+        newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
+                                             newResultTypes, newOperands);
+      }
+    }
+
+    // fir::DispatchOp specific handling.
+    if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
       llvm::SmallVector<mlir::Value> newOperands;
       if (!isResultBuiltinCPtr)
         newOperands.emplace_back(arg);
-      newOperands.append(callOp.getOperands().begin(),
-                         callOp.getOperands().end());
-      newCallOp = rewriter.create<fir::CallOp>(loc, *callOp.getCallee(),
-                                               newResultTypes, newOperands);
-    } else {
-      // Indirect calls.
-      llvm::SmallVector<mlir::Type> newInputTypes;
-      if (!isResultBuiltinCPtr)
-        newInputTypes.emplace_back(argType);
-      for (auto operand : callOp.getOperands().drop_front())
-        newInputTypes.push_back(operand.getType());
-      auto newFuncTy = mlir::FunctionType::get(callOp.getContext(),
-                                               newInputTypes, newResultTypes);
+      unsigned passArgShift = newOperands.size();
+      newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
 
-      llvm::SmallVector<mlir::Value> newOperands;
-      newOperands.push_back(rewriter.create<fir::ConvertOp>(
-          loc, newFuncTy, callOp.getOperand(0)));
-      if (!isResultBuiltinCPtr)
-        newOperands.push_back(arg);
-      newOperands.append(callOp.getOperands().begin() + 1,
-                         callOp.getOperands().end());
-      newCallOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
-                                               newResultTypes, newOperands);
+      fir::DispatchOp newDispatchOp;
+      if (op.getPassArgPos())
+        newOp = rewriter.create<fir::DispatchOp>(
+            loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
+            op.getOperands()[0], newOperands,
+            rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
+      else
+        newOp = rewriter.create<fir::DispatchOp>(
+            loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
+            op.getOperands()[0], newOperands, nullptr);
     }
+
     if (isResultBuiltinCPtr) {
       mlir::Value save = saveResult.getMemref();
-      auto module = callOp->getParentOfType<mlir::ModuleOp>();
+      auto module = op->template getParentOfType<mlir::ModuleOp>();
       fir::KindMapping kindMap = fir::getKindMapping(module);
       FirOpBuilder builder(rewriter, kindMap);
       mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
           builder, loc, save, result.getType());
-      rewriter.create<fir::StoreOp>(loc, newCallOp->getResult(0), saveAddr);
+      rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
     }
-    callOp->dropAllReferences();
-    rewriter.eraseOp(callOp);
+    op->dropAllReferences();
+    rewriter.eraseOp(op);
     return mlir::success();
   }
 
@@ -289,17 +317,11 @@ public:
       return true;
     });
     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
-      if (dispatch->getNumResults() != 1)
-        return true;
-      auto resultType = dispatch->getResult(0).getType();
-      if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
-        TODO(dispatch.getLoc(), "dispatchOp with abstract results");
-        return false;
-      }
-      return true;
+      return !hasAbstractResult(dispatch.getFunctionType());
     });
 
-    patterns.insert<CallOpConversion>(context, shouldBoxResult);
+    patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
+    patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
     patterns.insert<SaveResultOpConversion>(context);
     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
     if (mlir::failed(
index 14c59a6..374c0d1 100644 (file)
@@ -244,6 +244,25 @@ func.func @_QPtest_return_cptr() {
   // FUNC-BOX: fir.store %[[VAL]] to %[[ADDR]] : !fir.ref<i64>
 }
 
+// FUNC-REF-LABEL: func @dispatch(
+// FUNC-REF-SAME:    %[[ARG0:.*]]: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}
+// FUNC-BOX-LABEL: func @dispatch(
+// FUNC-BOX-SAME:    %[[ARG0:.*]]: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}
+func.func @dispatch(%arg0: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}) {
+  %buffer = fir.alloca !fir.type<t{x:f32}>
+  %res = fir.dispatch "ret_array"(%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) -> !fir.type<t{x:f32}> {pass_arg_pos = 0 : i32}
+  fir.save_result %res to %buffer : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>
+  return
+  // FUNC-REF: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}>
+  // FUNC-REF: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%[[buffer]], %[[ARG0]] : !fir.ref<!fir.type<t{x:f32}>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+  // FUNC-REF-NOT: fir.save_result
+
+  // FUNC-BOX: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}>
+  // FUNC-BOX: %[[box:.*]] = fir.embox %[[buffer]] : (!fir.ref<!fir.type<t{x:f32}>>) -> !fir.box<!fir.type<t{x:f32}>>
+  // FUNC-BOX: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%[[box]], %[[ARG0]] : !fir.box<!fir.type<t{x:f32}>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+  // FUNC-BOX-NOT: fir.save_result
+}
+
 // ------------------------ Test fir.address_of rewrite ------------------------
 
 func.func private @takesfuncarray((i32) -> !fir.array<?xf32>)