[flang] Allow scalar boxed record type in intrinsic elemental lowering
authorValentin Clement <clementval@gmail.com>
Wed, 1 Mar 2023 14:41:56 +0000 (15:41 +0100)
committerValentin Clement <clementval@gmail.com>
Wed, 1 Mar 2023 14:42:25 +0000 (15:42 +0100)
Relax a bit the condition added in D144417 and allow scalar polymorphic entities
and boxed scalar record type.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D145058

flang/include/flang/Optimizer/Dialect/FIRType.h
flang/lib/Optimizer/Builder/IntrinsicCall.cpp
flang/lib/Optimizer/Dialect/FIRType.cpp
flang/test/Lower/polymorphic-temp.f90
flang/unittests/Optimizer/FIRTypesTest.cpp

index 9d88445..3a0254a 100644 (file)
@@ -283,6 +283,12 @@ bool isBoxNone(mlir::Type ty);
 /// e.g. !fir.box<!fir.type<derived>>
 bool isBoxedRecordType(mlir::Type ty);
 
+/// Return true iff `ty` is a scalar boxed record type.
+/// e.g. !fir.box<!fir.type<derived>>
+///      !fir.box<!fir.heap<!fir.type<derived>>>
+///      !fir.class<!fir.type<derived>>
+bool isScalarBoxedRecordType(mlir::Type ty);
+
 /// Return the nested RecordType if one if found. Return ty otherwise.
 mlir::Type getDerivedType(mlir::Type ty);
 
index 0da45e8..5227fdf 100644 (file)
@@ -1708,7 +1708,7 @@ IntrinsicLibrary::genElementalCall<IntrinsicLibrary::ExtendedGenerator>(
   for (const fir::ExtendedValue &arg : args) {
     auto *box = arg.getBoxOf<fir::BoxValue>();
     if (!arg.getUnboxed() && !arg.getCharBox() &&
-        !(box && fir::isPolymorphicType(fir::getBase(*box).getType())))
+        !(box && fir::isScalarBoxedRecordType(fir::getBase(*box).getType())))
       fir::emitFatalError(loc, "nonscalar intrinsic argument");
   }
   if (outline)
index dac0e65..decb93f 100644 (file)
@@ -290,6 +290,20 @@ bool isBoxedRecordType(mlir::Type ty) {
   return false;
 }
 
+bool isScalarBoxedRecordType(mlir::Type ty) {
+  if (auto refTy = fir::dyn_cast_ptrEleTy(ty))
+    ty = refTy;
+  if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>()) {
+    if (boxTy.getEleTy().isa<fir::RecordType>())
+      return true;
+    if (auto heapTy = boxTy.getEleTy().dyn_cast<fir::HeapType>())
+      return heapTy.getEleTy().isa<fir::RecordType>();
+    if (auto ptrTy = boxTy.getEleTy().dyn_cast<fir::PointerType>())
+      return ptrTy.getEleTy().isa<fir::RecordType>();
+  }
+  return false;
+}
+
 static bool isAssumedType(mlir::Type ty) {
   if (auto boxTy = ty.dyn_cast<fir::BoxType>()) {
     if (boxTy.getEleTy().isa<mlir::NoneType>())
index f8627ef..5dfd36a 100644 (file)
@@ -207,4 +207,23 @@ contains
 ! CHECK: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[ARG0]], %[[ARG1]] : !fir.class<!fir.type<_QMpoly_tmpTp1{a:i32}>>
 ! CHECK: fir.call @_QMpoly_tmpPcheck_scalar(%[[SELECT]]) {{.*}} : (!fir.class<!fir.type<_QMpoly_tmpTp1{a:i32}>>) -> ()
 
+  subroutine test_merge_intrinsic2(a, b, i)
+    class(p1), allocatable, intent(in) :: a
+    type(p1), allocatable :: b
+    integer, intent(in) :: i
+
+    call check_scalar(merge(a, b, i==1))
+  end subroutine
+
+
+! CHECK-LABEL: func.func @_QMpoly_tmpPtest_merge_intrinsic2(
+! CHECK-SAME: %[[A:.*]]: !fir.ref<!fir.class<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>> {fir.bindc_name = "a"}, %[[B:.*]]: !fir.ref<!fir.box<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>> {fir.bindc_name = "b"}, %[[I:.*]]: !fir.ref<i32> {fir.bindc_name = "i"}) {
+! CHECK: %[[LOAD_A:.*]] = fir.load %[[A]] : !fir.ref<!fir.class<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>>
+! CHECK: %[[LOAD_B:.*]] = fir.load %[[B]] : !fir.ref<!fir.box<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>>
+! CHECK: %[[LOAD_I:.*]] = fir.load %[[I]] : !fir.ref<i32>
+! CHECK: %[[C1:.*]] = arith.constant 1 : i32
+! CHECK: %[[CMPI:.*]] = arith.cmpi eq, %[[LOAD_I]], %[[C1]] : i32
+! CHECK: %[[B_CONV:.*]] = fir.convert %[[LOAD_B]] : (!fir.box<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>) -> !fir.class<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>
+! CHECK: %{{.*}} = arith.select %[[CMPI]], %[[LOAD_A]], %[[B_CONV]] : !fir.class<!fir.heap<!fir.type<_QMpoly_tmpTp1{a:i32}>>>
+
 end module
index e30800a..41588e2 100644 (file)
@@ -147,6 +147,43 @@ TEST_F(FIRTypesTest, isBoxedRecordType) {
       fir::ReferenceType::get(mlir::IntegerType::get(&context, 32)))));
 }
 
+// Test fir::isScalarBoxedRecordType from flang/Optimizer/Dialect/FIRType.h.
+TEST_F(FIRTypesTest, isScalarBoxedRecordType) {
+  mlir::Type recTy = fir::RecordType::get(&context, "dt");
+  mlir::Type seqRecTy =
+      fir::SequenceType::get({fir::SequenceType::getUnknownExtent()}, recTy);
+  mlir::Type ty = fir::BoxType::get(recTy);
+  EXPECT_TRUE(fir::isScalarBoxedRecordType(ty));
+  EXPECT_TRUE(fir::isScalarBoxedRecordType(fir::ReferenceType::get(ty)));
+
+  // CLASS(T), ALLOCATABLE
+  ty = fir::ClassType::get(fir::HeapType::get(recTy));
+  EXPECT_TRUE(fir::isScalarBoxedRecordType(ty));
+
+  // TYPE(T), ALLOCATABLE
+  ty = fir::BoxType::get(fir::HeapType::get(recTy));
+  EXPECT_TRUE(fir::isScalarBoxedRecordType(ty));
+
+  // TYPE(T), POINTER
+  ty = fir::BoxType::get(fir::PointerType::get(recTy));
+  EXPECT_TRUE(fir::isScalarBoxedRecordType(ty));
+
+  // CLASS(T), POINTER
+  ty = fir::ClassType::get(fir::PointerType::get(recTy));
+  EXPECT_TRUE(fir::isScalarBoxedRecordType(ty));
+
+  // TYPE(T), DIMENSION(10)
+  ty = fir::BoxType::get(fir::SequenceType::get({10}, recTy));
+  EXPECT_FALSE(fir::isScalarBoxedRecordType(ty));
+
+  // TYPE(T), DIMENSION(:)
+  ty = fir::BoxType::get(seqRecTy);
+  EXPECT_FALSE(fir::isScalarBoxedRecordType(ty));
+
+  EXPECT_FALSE(fir::isScalarBoxedRecordType(fir::BoxType::get(
+      fir::ReferenceType::get(mlir::IntegerType::get(&context, 32)))));
+}
+
 TEST_F(FIRTypesTest, updateTypeForUnlimitedPolymorphic) {
   // RecordType are not changed.