[flang] Represent (parentheses around derived types)
authorpeter klausler <pklausler@nvidia.com>
Wed, 22 Sep 2021 23:49:09 +0000 (16:49 -0700)
committerpeter klausler <pklausler@nvidia.com>
Thu, 23 Sep 2021 20:03:13 +0000 (13:03 -0700)
The strongly typed expression representation classes supported
a representation of parentheses only around intrinsic types
with specific kinds.  Parentheses around derived type variables
must also be preserved so that expressions may be distinguished
from variables; this distinction matters for actual arguments &
construct associations.

Differential Revision: https://reviews.llvm.org/D110355

flang/include/flang/Evaluate/expression.h
flang/lib/Evaluate/expression.cpp
flang/lib/Evaluate/tools.cpp
flang/test/Evaluate/expr01.f90 [new file with mode: 0644]

index 8eacdef..ea68f6e 100644 (file)
@@ -116,8 +116,10 @@ class Operation {
 public:
   using Derived = DERIVED;
   using Result = RESULT;
-  static_assert(IsSpecificIntrinsicType<Result>);
   static constexpr std::size_t operands{sizeof...(OPERANDS)};
+  // Allow specific intrinsic types and Parentheses<SomeDerived>
+  static_assert(IsSpecificIntrinsicType<Result> ||
+      (operands == 1 && std::is_same_v<Result, SomeDerived>));
   template <int J> using Operand = std::tuple_element_t<J, OperandTypes>;
 
   // Unary operations wrap a single Expr with a CopyableIndirection.
@@ -172,7 +174,9 @@ public:
     }
   }
 
-  static constexpr std::optional<DynamicType> GetType() {
+  static constexpr std::conditional_t<Result::category != TypeCategory::Derived,
+      std::optional<DynamicType>, void>
+  GetType() {
     return Result::GetType();
   }
   int Rank() const {
@@ -222,6 +226,17 @@ struct Parentheses : public Operation<Parentheses<A>, A, A> {
   using Base::Base;
 };
 
+template <>
+struct Parentheses<SomeDerived>
+    : public Operation<Parentheses<SomeDerived>, SomeDerived, SomeDerived> {
+public:
+  using Result = SomeDerived;
+  using Operand = SomeDerived;
+  using Base = Operation<Parentheses, SomeDerived, SomeDerived>;
+  using Base::Base;
+  DynamicType GetType() const;
+};
+
 template <typename A> struct Negate : public Operation<Negate<A>, A, A> {
   using Result = A;
   using Operand = A;
@@ -730,7 +745,7 @@ public:
   using Result = SomeDerived;
   EVALUATE_UNION_CLASS_BOILERPLATE(Expr)
   std::variant<Constant<Result>, ArrayConstructor<Result>, StructureConstructor,
-      Designator<Result>, FunctionRef<Result>>
+      Designator<Result>, FunctionRef<Result>, Parentheses<Result>>
       u;
 };
 
index 7f8c9eb..c08e977 100644 (file)
@@ -107,6 +107,10 @@ template <typename A> int ExpressionBase<A>::Rank() const {
       derived().u);
 }
 
+DynamicType Parentheses<SomeDerived>::GetType() const {
+  return left().GetType().value();
+}
+
 // Equality testing
 
 bool ImpliedDoIndex::operator==(const ImpliedDoIndex &that) const {
index dd66259..bf50eb9 100644 (file)
@@ -35,9 +35,10 @@ Expr<SomeType> Parenthesize(Expr<SomeType> &&expr) {
   return std::visit(
       [&](auto &&x) {
         using T = std::decay_t<decltype(x)>;
-        if constexpr (common::HasMember<T, TypelessExpression> ||
-            std::is_same_v<T, Expr<SomeDerived>>) {
-          return expr; // no parentheses around typeless or derived type
+        if constexpr (common::HasMember<T, TypelessExpression>) {
+          return expr; // no parentheses around typeless
+        } else if constexpr (std::is_same_v<T, Expr<SomeDerived>>) {
+          return AsGenericExpr(Parentheses<SomeDerived>{std::move(x)});
         } else {
           return std::visit(
               [](auto &&y) {
diff --git a/flang/test/Evaluate/expr01.f90 b/flang/test/Evaluate/expr01.f90
new file mode 100644 (file)
index 0000000..c0f8437
--- /dev/null
@@ -0,0 +1,34 @@
+! RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
+! Ensures that parentheses are preserved with derived types
+module m
+  type :: t
+    integer :: n
+  end type
+ contains
+  subroutine sub(x)
+    type(t), intent(in) :: x
+  end subroutine
+  function f(m)
+    type(t), pointer :: f
+    integer, intent(in) :: m
+    type(t), save, target :: res
+    res%n = m
+    f => res
+  end function
+  subroutine test
+    type(t) :: x
+    x = t(1)
+    !CHECK: CALL sub(t(n=1_4))
+    call sub(t(1))
+    !CHECK: CALL sub((t(n=1_4)))
+    call sub((t(1)))
+    !CHECK: CALL sub(x)
+    call sub(x)
+    !CHECK: CALL sub((x))
+    call sub((x))
+    !CHECK: CALL sub(f(2_4))
+    call sub(f(2))
+    !CHECK: CALL sub((f(2_4)))
+    call sub((f(2)))
+  end subroutine
+end module