#define DEBUG_TYPE "flang-abstract-result-opt"
+using namespace mlir;
+
namespace fir {
namespace {
return fir::BoxType::get(type);
return fir::ReferenceType::get(type);
})
- .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
+ .Case<fir::BaseBoxType>([](mlir::Type type) -> mlir::Type {
return fir::ReferenceType::get(type);
})
.Default([](mlir::Type) -> mlir::Type {
shouldBoxResult;
}
-class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
+template <typename Op>
+class CallConversion : public mlir::OpRewritePattern<Op> {
public:
- using OpRewritePattern::OpRewritePattern;
- CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
- : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
+ using mlir::OpRewritePattern<Op>::OpRewritePattern;
+
+ CallConversion(mlir::MLIRContext *context, bool shouldBoxResult)
+ : OpRewritePattern<Op>(context, 1), shouldBoxResult{shouldBoxResult} {}
+
mlir::LogicalResult
- matchAndRewrite(fir::CallOp callOp,
- mlir::PatternRewriter &rewriter) const override {
- auto loc = callOp.getLoc();
- auto result = callOp->getResult(0);
+ matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto result = op->getResult(0);
if (!result.hasOneUse()) {
mlir::emitError(loc,
"calls with abstract result must have exactly one user");
// TODO: This should be generalized for derived types, and it is
// architecture and OS dependent.
bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
- fir::CallOp newCallOp;
+ Op newOp;
if (isResultBuiltinCPtr) {
- auto recTy = result.getType().dyn_cast<fir::RecordType>();
+ auto recTy = result.getType().template dyn_cast<fir::RecordType>();
newResultTypes.emplace_back(recTy.getTypeList()[0].second);
}
- if (callOp.getCallee()) {
+
+ // fir::CallOp specific handling.
+ if constexpr (std::is_same_v<Op, fir::CallOp>) {
+ if (op.getCallee()) {
+ llvm::SmallVector<mlir::Value> newOperands;
+ if (!isResultBuiltinCPtr)
+ newOperands.emplace_back(arg);
+ newOperands.append(op.getOperands().begin(), op.getOperands().end());
+ newOp = rewriter.create<fir::CallOp>(loc, *op.getCallee(),
+ newResultTypes, newOperands);
+ } else {
+ // Indirect calls.
+ llvm::SmallVector<mlir::Type> newInputTypes;
+ if (!isResultBuiltinCPtr)
+ newInputTypes.emplace_back(argType);
+ for (auto operand : op.getOperands().drop_front())
+ newInputTypes.push_back(operand.getType());
+ auto newFuncTy = mlir::FunctionType::get(op.getContext(), newInputTypes,
+ newResultTypes);
+
+ llvm::SmallVector<mlir::Value> newOperands;
+ newOperands.push_back(
+ rewriter.create<fir::ConvertOp>(loc, newFuncTy, op.getOperand(0)));
+ if (!isResultBuiltinCPtr)
+ newOperands.push_back(arg);
+ newOperands.append(op.getOperands().begin() + 1,
+ op.getOperands().end());
+ newOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
+ newResultTypes, newOperands);
+ }
+ }
+
+ // fir::DispatchOp specific handling.
+ if constexpr (std::is_same_v<Op, fir::DispatchOp>) {
llvm::SmallVector<mlir::Value> newOperands;
if (!isResultBuiltinCPtr)
newOperands.emplace_back(arg);
- newOperands.append(callOp.getOperands().begin(),
- callOp.getOperands().end());
- newCallOp = rewriter.create<fir::CallOp>(loc, *callOp.getCallee(),
- newResultTypes, newOperands);
- } else {
- // Indirect calls.
- llvm::SmallVector<mlir::Type> newInputTypes;
- if (!isResultBuiltinCPtr)
- newInputTypes.emplace_back(argType);
- for (auto operand : callOp.getOperands().drop_front())
- newInputTypes.push_back(operand.getType());
- auto newFuncTy = mlir::FunctionType::get(callOp.getContext(),
- newInputTypes, newResultTypes);
+ unsigned passArgShift = newOperands.size();
+ newOperands.append(op.getOperands().begin() + 1, op.getOperands().end());
- llvm::SmallVector<mlir::Value> newOperands;
- newOperands.push_back(rewriter.create<fir::ConvertOp>(
- loc, newFuncTy, callOp.getOperand(0)));
- if (!isResultBuiltinCPtr)
- newOperands.push_back(arg);
- newOperands.append(callOp.getOperands().begin() + 1,
- callOp.getOperands().end());
- newCallOp = rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{},
- newResultTypes, newOperands);
+ fir::DispatchOp newDispatchOp;
+ if (op.getPassArgPos())
+ newOp = rewriter.create<fir::DispatchOp>(
+ loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
+ op.getOperands()[0], newOperands,
+ rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift));
+ else
+ newOp = rewriter.create<fir::DispatchOp>(
+ loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
+ op.getOperands()[0], newOperands, nullptr);
}
+
if (isResultBuiltinCPtr) {
mlir::Value save = saveResult.getMemref();
- auto module = callOp->getParentOfType<mlir::ModuleOp>();
+ auto module = op->template getParentOfType<mlir::ModuleOp>();
fir::KindMapping kindMap = fir::getKindMapping(module);
FirOpBuilder builder(rewriter, kindMap);
mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
builder, loc, save, result.getType());
- rewriter.create<fir::StoreOp>(loc, newCallOp->getResult(0), saveAddr);
+ rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
}
- callOp->dropAllReferences();
- rewriter.eraseOp(callOp);
+ op->dropAllReferences();
+ rewriter.eraseOp(op);
return mlir::success();
}
return true;
});
target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
- if (dispatch->getNumResults() != 1)
- return true;
- auto resultType = dispatch->getResult(0).getType();
- if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
- TODO(dispatch.getLoc(), "dispatchOp with abstract results");
- return false;
- }
- return true;
+ return !hasAbstractResult(dispatch.getFunctionType());
});
- patterns.insert<CallOpConversion>(context, shouldBoxResult);
+ patterns.insert<CallConversion<fir::CallOp>>(context, shouldBoxResult);
+ patterns.insert<CallConversion<fir::DispatchOp>>(context, shouldBoxResult);
patterns.insert<SaveResultOpConversion>(context);
patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
if (mlir::failed(
// FUNC-BOX: fir.store %[[VAL]] to %[[ADDR]] : !fir.ref<i64>
}
+// FUNC-REF-LABEL: func @dispatch(
+// FUNC-REF-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}
+// FUNC-BOX-LABEL: func @dispatch(
+// FUNC-BOX-SAME: %[[ARG0:.*]]: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}
+func.func @dispatch(%arg0: !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>> {fir.bindc_name = "a"}) {
+ %buffer = fir.alloca !fir.type<t{x:f32}>
+ %res = fir.dispatch "ret_array"(%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%arg0 : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) -> !fir.type<t{x:f32}> {pass_arg_pos = 0 : i32}
+ fir.save_result %res to %buffer : !fir.type<t{x:f32}>, !fir.ref<!fir.type<t{x:f32}>>
+ return
+ // FUNC-REF: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}>
+ // FUNC-REF: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%[[buffer]], %[[ARG0]] : !fir.ref<!fir.type<t{x:f32}>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+ // FUNC-REF-NOT: fir.save_result
+
+ // FUNC-BOX: %[[buffer:.*]] = fir.alloca !fir.type<t{x:f32}>
+ // FUNC-BOX: %[[box:.*]] = fir.embox %[[buffer]] : (!fir.ref<!fir.type<t{x:f32}>>) -> !fir.box<!fir.type<t{x:f32}>>
+ // FUNC-BOX: fir.dispatch "ret_array"(%[[ARG0]] : !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) (%[[box]], %[[ARG0]] : !fir.box<!fir.type<t{x:f32}>>, !fir.class<!fir.type<_QMpolymorphic_testTp1{a:i32,b:i32}>>) {pass_arg_pos = 1 : i32}
+ // FUNC-BOX-NOT: fir.save_result
+}
+
// ------------------------ Test fir.address_of rewrite ------------------------
func.func private @takesfuncarray((i32) -> !fir.array<?xf32>)