[flang] Fix checking of argument passing for parameterized derived types
authorPeter Steinfeld <psteinfeld@nvidia.com>
Wed, 21 Apr 2021 19:12:26 +0000 (12:12 -0700)
committerPeter Steinfeld <psteinfeld@nvidia.com>
Thu, 22 Apr 2021 15:49:49 +0000 (08:49 -0700)
We were erroneously not taking into account the constant values of LEN type
parameters of parameterized derived types when checking for argument
compatibility.  The required checks are identical to those for assignment
compatibility.  Since argument compatibility is checked in .../lib/Evaluate and
assignment compatibility is checked in .../lib/Semantics, I moved the common
code into .../lib/Evaluate/tools.cpp and changed the assignment compatibility
checking code to call it.

After implementing these new checks, tests in resolve53.f90 were failing
because the tests were erroneous.  I fixed these tests and added new tests
to call03.f90 to test argument passing of parameterized derived types more
completely.

Differential Revision: https://reviews.llvm.org/D100989

flang/include/flang/Evaluate/tools.h
flang/lib/Evaluate/tools.cpp
flang/lib/Evaluate/type.cpp
flang/lib/Semantics/type.cpp
flang/test/Semantics/call03.f90
flang/test/Semantics/resolve53.f90

index bcfb1c5..906acdb 100644 (file)
@@ -965,6 +965,12 @@ const Symbol &GetAssociationRoot(const Symbol &);
 const Symbol *FindCommonBlockContaining(const Symbol &);
 int CountLenParameters(const DerivedTypeSpec &);
 int CountNonConstantLenParameters(const DerivedTypeSpec &);
+
+// 15.5.2.4(4), type compatibility for dummy and actual arguments.
+// Also used for assignment compatibility checking
+bool AreTypeParamCompatible(
+    const semantics::DerivedTypeSpec &, const semantics::DerivedTypeSpec &);
+
 const Symbol &GetUsedModule(const UseDetails &);
 const Symbol *FindFunctionResult(const Symbol &);
 
index 9fbf21e..a0057e8 100644 (file)
@@ -1174,6 +1174,31 @@ int CountNonConstantLenParameters(const DerivedTypeSpec &type) {
       });
 }
 
+// Are the type parameters of type1 compile-time compatible with the
+// corresponding kind type parameters of type2?  Return true if all constant
+// valued parameters are equal.
+// Used to check assignment statements and argument passing.  See 15.5.2.4(4)
+bool AreTypeParamCompatible(const semantics::DerivedTypeSpec &type1,
+    const semantics::DerivedTypeSpec &type2) {
+  for (const auto &[name, param1] : type1.parameters()) {
+    if (semantics::MaybeIntExpr paramExpr1{param1.GetExplicit()}) {
+      if (IsConstantExpr(*paramExpr1)) {
+        const semantics::ParamValue *param2{type2.FindParameter(name)};
+        if (param2) {
+          if (semantics::MaybeIntExpr paramExpr2{param2->GetExplicit()}) {
+            if (IsConstantExpr(*paramExpr2)) {
+              if (ToInt64(*paramExpr1) != ToInt64(*paramExpr2)) {
+                return false;
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  return true;
+}
+
 const Symbol &GetUsedModule(const UseDetails &details) {
   return DEREF(details.symbol().owner().symbol());
 }
index 1d5f720..0d2004d 100644 (file)
@@ -316,21 +316,6 @@ static bool AreCompatibleDerivedTypes(const semantics::DerivedTypeSpec *x,
   }
 }
 
-// Do the kind type parameters of type1 have the same values as the
-// corresponding kind type parameters of type2?
-static bool AreKindCompatible(const semantics::DerivedTypeSpec &type1,
-    const semantics::DerivedTypeSpec &type2) {
-  for (const auto &[name, param1] : type1.parameters()) {
-    if (param1.isKind()) {
-      const semantics::ParamValue *param2{type2.FindParameter(name)};
-      if (!PointeeComparison(&param1, param2)) {
-        return false;
-      }
-    }
-  }
-  return true;
-}
-
 // See 7.3.2.3 (5) & 15.5.2.4
 bool DynamicType::IsTkCompatibleWith(const DynamicType &that) const {
   if (IsUnlimitedPolymorphic()) {
@@ -342,7 +327,7 @@ bool DynamicType::IsTkCompatibleWith(const DynamicType &that) const {
   } else if (derived_) {
     return that.derived_ &&
         AreCompatibleDerivedTypes(derived_, that.derived_, IsPolymorphic()) &&
-        AreKindCompatible(*derived_, *that.derived_);
+        AreTypeParamCompatible(*derived_, *that.derived_);
   } else {
     return kind_ == that.kind_;
   }
index 16625c0..6f76f0b 100644 (file)
@@ -10,6 +10,7 @@
 #include "check-declarations.h"
 #include "compute-offsets.h"
 #include "flang/Evaluate/fold.h"
+#include "flang/Evaluate/tools.h"
 #include "flang/Parser/characters.h"
 #include "flang/Semantics/scope.h"
 #include "flang/Semantics/symbol.h"
@@ -197,26 +198,7 @@ bool DerivedTypeSpec::MightBeAssignmentCompatibleWith(
   if (!RawEquals(that)) {
     return false;
   }
-  const std::map<SourceName, ParamValue> &theseParams{this->parameters()};
-  const std::map<SourceName, ParamValue> &thoseParams{that.parameters()};
-  auto thatIter{thoseParams.begin()};
-  for (const auto &[thisName, thisValue] : theseParams) {
-    CHECK(thatIter != thoseParams.end());
-    const ParamValue &thatValue{thatIter->second};
-    if (MaybeIntExpr thisExpr{thisValue.GetExplicit()}) {
-      if (evaluate::IsConstantExpr(*thisExpr)) {
-        if (MaybeIntExpr thatExpr{thatValue.GetExplicit()}) {
-          if (evaluate::IsConstantExpr(*thatExpr)) {
-            if (evaluate::ToInt64(*thisExpr) != evaluate::ToInt64(*thatExpr)) {
-              return false;
-            }
-          }
-        }
-      }
-    }
-    thatIter++;
-  }
-  return true;
+  return AreTypeParamCompatible(*this, that);
 }
 
 class InstantiateHelper {
index f76e421..23e8dc8 100644 (file)
@@ -8,6 +8,9 @@ module m01
   type :: pdt(n)
     integer, len :: n
   end type
+  type :: pdtWithDefault(n)
+    integer, len :: n = 3
+  end type
   type :: tbp
    contains
     procedure :: binding => subr01
@@ -120,11 +123,59 @@ module m01
   subroutine ch2(x)
     character(2), intent(in out) :: x
   end subroutine
+  subroutine pdtdefault (derivedArg)
+    !ERROR: Type parameter 'n' lacks a value and has no default
+    type(pdt) :: derivedArg
+  end subroutine pdtdefault
+  subroutine pdt3 (derivedArg)
+    type(pdt(4)) :: derivedArg
+  end subroutine pdt3
+  subroutine pdt4 (derivedArg)
+    type(pdt(*)) :: derivedArg
+  end subroutine pdt4
+  subroutine pdtWithDefaultDefault (derivedArg)
+    type(pdtWithDefault) :: derivedArg
+  end subroutine pdtWithDefaultdefault
+  subroutine pdtWithDefault3 (derivedArg)
+    type(pdtWithDefault(4)) :: derivedArg
+  end subroutine pdtWithDefault3
+  subroutine pdtWithDefault4 (derivedArg)
+    type(pdtWithDefault(*)) :: derivedArg
+  end subroutine pdtWithDefault4
   subroutine test06 ! 15.5.2.4(4)
+    !ERROR: Type parameter 'n' lacks a value and has no default
+    type(pdt) :: vardefault
+    type(pdt(3)) :: var3
+    type(pdt(4)) :: var4
+    type(pdtWithDefault) :: defaultVardefault
+    type(pdtWithDefault(3)) :: defaultVar3
+    type(pdtWithDefault(4)) :: defaultVar4
     character :: ch1
     ! The actual argument is converted to a padded expression.
     !ERROR: Actual argument associated with INTENT(IN OUT) dummy argument 'x=' must be definable
     call ch2(ch1)
+    call pdtdefault(vardefault)
+    call pdtdefault(var3)
+    call pdtdefault(var4) ! error
+    call pdt3(vardefault) ! error
+    !ERROR: Actual argument type 'pdt(n=3_4)' is not compatible with dummy argument type 'pdt(n=4_4)'
+    call pdt3(var3) ! error
+    call pdt3(var4)
+    call pdt4(vardefault)
+    call pdt4(var3)
+    call pdt4(var4)
+    call pdtWithDefaultdefault(defaultVardefault)
+    call pdtWithDefaultdefault(defaultVar3)
+    !ERROR: Actual argument type 'pdtwithdefault(n=4_4)' is not compatible with dummy argument type 'pdtwithdefault(n=3_4)'
+    call pdtWithDefaultdefault(defaultVar4) ! error
+    !ERROR: Actual argument type 'pdtwithdefault(n=3_4)' is not compatible with dummy argument type 'pdtwithdefault(n=4_4)'
+    call pdtWithDefault3(defaultVardefault) ! error
+    !ERROR: Actual argument type 'pdtwithdefault(n=3_4)' is not compatible with dummy argument type 'pdtwithdefault(n=4_4)'
+    call pdtWithDefault3(defaultVar3) ! error
+    call pdtWithDefault3(defaultVar4)
+    call pdtWithDefault4(defaultVardefault)
+    call pdtWithDefault4(defaultVar3)
+    call pdtWithDefault4(defaultVar4)
   end subroutine
 
   subroutine out01(x)
index c2cbe38..89d07cb 100644 (file)
@@ -304,7 +304,7 @@ module m15
 
 contains
   subroutine s1(x)
-    type(t1(1, 4)) :: x
+    type(t1(1, 5)) :: x
   end
   subroutine s2(x)
     type(t1(2, 4)) :: x
@@ -319,7 +319,7 @@ contains
     type(t3) :: x
   end subroutine
   subroutine s6(x)
-    type(t3(1, 99, k2b=2, k2a=3, l2=*, l3=97, k3=4)) :: x
+    type(t3(1, 99, k2b=2, k2a=3, l2=*, l3=103, k3=4)) :: x
   end subroutine
   subroutine s7(x)
     type(t3(k1=1, l1=99, k2a=3, k2b=2, k3=4)) :: x