[flang] Handle fir.class pointer and allocatable in fir.dispatch code gen
authorValentin Clement <clementval@gmail.com>
Fri, 21 Oct 2022 12:34:37 +0000 (14:34 +0200)
committerValentin Clement <clementval@gmail.com>
Fri, 21 Oct 2022 12:35:26 +0000 (14:35 +0200)
fir.dispatch code generation was not handling fir.class pointer and
allocatable types. Update the code generation part to rertieve correctly the
the type info from those types.

Depends on D136426

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D136429

flang/lib/Lower/ConvertExpr.cpp
flang/lib/Optimizer/CodeGen/CodeGen.cpp
flang/test/Lower/allocatable-polymorphic.f90

index 8122ed4..403e81b 100644 (file)
@@ -2754,9 +2754,15 @@ public:
         assert(component && "expect component for type-bound procedure call.");
         fir::ExtendedValue pass =
             symMap.lookupSymbol(component->GetFirstSymbol()).toExtendedValue();
+        mlir::Value passObject = fir::getBase(pass);
+        if (fir::isa_ref_type(passObject.getType()))
+          passObject = builder.create<fir::ConvertOp>(
+              loc,
+              passObject.getType().dyn_cast<fir::ReferenceType>().getEleTy(),
+              passObject);
         dispatch = builder.create<fir::DispatchOp>(
             loc, funcType.getResults(), builder.getStringAttr(procName),
-            fir::getBase(pass), operands, nullptr);
+            passObject, operands, nullptr);
       }
       callResult = dispatch.getResult(0);
       callNumResults = dispatch.getNumResults();
index 9f4d112..a4ec625 100644 (file)
@@ -35,6 +35,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace fir {
 #define GEN_PASS_DEF_FIRTOLLVMLOWERING
@@ -898,15 +899,13 @@ struct DispatchOpConversion : public FIROpConversion<fir::DispatchOp> {
     if (bindingTables.empty())
       return emitError(loc) << "no binding tables found";
 
-    if (dispatch.getObject()
-            .getType()
-            .getEleTy()
-            .isa<fir::HeapType, fir::PointerType>())
-      TODO(loc,
-           "fir.dispatch with allocatable or pointer polymorphic entities");
-
     // Get derived type information.
-    auto declaredType = dispatch.getObject().getType().getEleTy();
+    auto declaredType = llvm::TypeSwitch<mlir::Type, mlir::Type>(
+                            dispatch.getObject().getType().getEleTy())
+                            .Case<fir::PointerType, fir::HeapType>(
+                                [](auto p) { return p.getEleTy(); })
+                            .Default([](mlir::Type t) { return t; });
+
     assert(declaredType.isa<fir::RecordType>() && "expecting fir.type");
     auto recordType = declaredType.dyn_cast<fir::RecordType>();
     std::string typeDescName =
index 8f04348..3883a35 100644 (file)
@@ -5,11 +5,24 @@ module poly
   type p1
     integer :: a
     integer :: b
+  contains
+    procedure, nopass :: proc1 => proc1_p1
   end type
 
   type, extends(p1) :: p2
     integer :: c
+  contains
+      procedure, nopass :: proc1 => proc1_p2
   end type
+
+contains
+  subroutine proc1_p1()
+    print*, 'call proc1_p1'
+  end subroutine
+
+  subroutine proc1_p2()
+    print*, 'call proc1_p2'
+  end subroutine
 end module
 
 program test_allocatable
@@ -27,6 +40,8 @@ program test_allocatable
   allocate(p1::c3(10))
   allocate(p2::c4(20))
 
+  call c1%proc1()
+  call c2%proc1()
 end
 
 ! CHECK-LABEL: func.func @_QQmain()