From: peter klausler Date: Tue, 21 May 2019 23:58:46 +0000 (-0700) Subject: [flang] Defer conversions to objects; fix some intrinsic table entries X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d29530e1c419d8347cbd3c345fef43a6322491e4;p=platform%2Fupstream%2Fllvm.git [flang] Defer conversions to objects; fix some intrinsic table entries more fixes Access components of constant structures Apply implicit typing to dummy args used in automatic array dimensions SELECTED_INT_KIND and SELECTED_REAL_KIND Finish SELECTED_{INT,REAL}_KIND and common cases of ALL()/ANY() Original-commit: flang-compiler/f18@e9f8e53e55867863ca06461fb6edd965c86352c7 Reviewed-on: https://github.com/flang-compiler/f18/pull/472 Tree-same-pre-rewrite: false --- diff --git a/flang/lib/common/template.h b/flang/lib/common/template.h index 48f31f2..471a15b 100644 --- a/flang/lib/common/template.h +++ b/flang/lib/common/template.h @@ -295,25 +295,28 @@ std::optional MapOptional(R (*f)(A &&...), std::optional &&... x) { // SearchTypes will traverse the element types in the tuple in order // and invoke VISITOR::Test() on each until it returns a value that // casts to true. If no invocation of Test succeeds, SearchTypes will -// return a default-constructed value VISITOR::Result{}. +// return a default value. template common::IfNoLvalue SearchTypesHelper( - VISITOR &&visitor) { + VISITOR &&visitor, typename VISITOR::Result &&defaultResult) { using Tuple = typename VISITOR::Types; if constexpr (J < std::tuple_size_v) { if (auto result{visitor.template Test>()}) { return result; } - return SearchTypesHelper(std::move(visitor)); + return SearchTypesHelper(std::move(visitor), + std::move(defaultResult)); } else { - return typename VISITOR::Result{}; + return std::move(defaultResult); } } template common::IfNoLvalue SearchTypes( - VISITOR &&visitor) { - return SearchTypesHelper<0, VISITOR>(std::move(visitor)); + VISITOR &&visitor, + typename VISITOR::Result defaultResult = typename VISITOR::Result{}) { + return SearchTypesHelper<0, VISITOR>( + std::move(visitor), std::move(defaultResult)); } } #endif // FORTRAN_COMMON_TEMPLATE_H_ diff --git a/flang/lib/evaluate/call.cc b/flang/lib/evaluate/call.cc index 873131a..1e7dd1a 100644 --- a/flang/lib/evaluate/call.cc +++ b/flang/lib/evaluate/call.cc @@ -42,7 +42,7 @@ ActualArgument &ActualArgument::operator=(Expr &&expr) { } std::optional ActualArgument::GetType() const { - if (const auto *expr{GetExpr()}) { + if (const auto *expr{UnwrapExpr()}) { return expr->GetType(); } else { return std::nullopt; @@ -50,7 +50,7 @@ std::optional ActualArgument::GetType() const { } int ActualArgument::Rank() const { - if (const auto *expr{GetExpr()}) { + if (const auto *expr{UnwrapExpr()}) { return expr->Rank(); } else { return std::get(u_).Rank(); diff --git a/flang/lib/evaluate/call.h b/flang/lib/evaluate/call.h index 5bd38e9..74785da 100644 --- a/flang/lib/evaluate/call.h +++ b/flang/lib/evaluate/call.h @@ -73,7 +73,7 @@ public: ~ActualArgument(); ActualArgument &operator=(Expr &&); - Expr *GetExpr() { + Expr *UnwrapExpr() { if (auto *p{ std::get_if>>(&u_)}) { return &p->value(); @@ -81,7 +81,7 @@ public: return nullptr; } } - const Expr *GetExpr() const { + const Expr *UnwrapExpr() const { if (const auto *p{ std::get_if>>(&u_)}) { return &p->value(); diff --git a/flang/lib/evaluate/characteristics.h b/flang/lib/evaluate/characteristics.h index 47ab784..7d8757c 100644 --- a/flang/lib/evaluate/characteristics.h +++ b/flang/lib/evaluate/characteristics.h @@ -56,7 +56,7 @@ public: bool operator==(const TypeAndShape &) const; bool IsAssumedRank() const { return isAssumedRank_; } - int Rank() const { return static_cast(shape().size()); } + int Rank() const { return GetRank(shape_); } bool IsCompatibleWith( parser::ContextualMessages &, const TypeAndShape &) const; diff --git a/flang/lib/evaluate/constant.cc b/flang/lib/evaluate/constant.cc index e1db2df..8be8c3c 100644 --- a/flang/lib/evaluate/constant.cc +++ b/flang/lib/evaluate/constant.cc @@ -14,6 +14,7 @@ #include "constant.h" #include "expression.h" +#include "shape.h" #include "type.h" #include @@ -30,9 +31,9 @@ std::size_t TotalElementCount(const ConstantSubscripts &shape) { bool IncrementSubscripts( ConstantSubscripts &indices, const ConstantSubscripts &shape) { - auto rank{shape.size()}; - CHECK(indices.size() == rank); - for (std::size_t j{0}; j < rank; ++j) { + int rank{GetRank(shape)}; + CHECK(GetRank(indices) == rank); + for (int j{0}; j < rank; ++j) { CHECK(indices[j] >= 1); if (++indices[j] <= shape[j]) { return true; @@ -45,6 +46,13 @@ bool IncrementSubscripts( } template +ConstantBase::ConstantBase( + std::vector &&x, ConstantSubscripts &&dims, Result res) + : result_{res}, values_(std::move(x)), shape_(std::move(dims)) { + CHECK(size() == TotalElementCount(shape_)); +} + +template ConstantBase::~ConstantBase() {} template @@ -54,7 +62,7 @@ bool ConstantBase::operator==(const ConstantBase &that) const { static ConstantSubscript SubscriptsToOffset( const ConstantSubscripts &index, const ConstantSubscripts &shape) { - CHECK(index.size() == shape.size()); + CHECK(GetRank(index) == GetRank(shape)); ConstantSubscript stride{1}, offset{0}; int dim{0}; for (auto j : index) { @@ -66,20 +74,25 @@ static ConstantSubscript SubscriptsToOffset( return offset; } -static Constant ShapeAsConstant( - const ConstantSubscripts &shape) { - using IntType = Scalar; - std::vector result; - for (auto dim : shape) { - result.emplace_back(dim); - } - return {std::move(result), - ConstantSubscripts{static_cast(shape.size())}}; +template +Constant ConstantBase::SHAPE() const { + return AsConstantShape(shape_); } template -Constant ConstantBase::SHAPE() const { - return ShapeAsConstant(shape_); +auto ConstantBase::Reshape( + const ConstantSubscripts &dims) const -> std::vector { + std::size_t n{TotalElementCount(dims)}; + CHECK(!empty() || n == 0); + std::vector elements; + auto iter{values().cbegin()}; + while (n-- > 0) { + elements.push_back(*iter); + if (++iter == values().cend()) { + iter = values().cbegin(); + } + } + return elements; } template @@ -87,6 +100,11 @@ auto Constant::At(const ConstantSubscripts &index) const -> Element { return Base::values_.at(SubscriptsToOffset(index, Base::shape_)); } +template +auto Constant::Reshape(ConstantSubscripts &&dims) const -> Constant { + return {Base::Reshape(dims), std::move(dims)}; +} + // Constant specializations template Constant>::Constant( @@ -102,6 +120,7 @@ template Constant>::Constant(std::int64_t len, std::vector> &&strings, ConstantSubscripts &&dims) : length_{len}, shape_{std::move(dims)} { + CHECK(strings.size() == TotalElementCount(shape_)); values_.assign(strings.size() * length_, static_cast::value_type>(' ')); std::int64_t at{0}; @@ -119,14 +138,6 @@ Constant>::Constant(std::int64_t len, template Constant>::~Constant() {} -static ConstantSubscript ShapeElements(const ConstantSubscripts &shape) { - ConstantSubscript elements{1}; - for (auto dim : shape) { - elements *= dim; - } - return elements; -} - template bool Constant>::empty() const { return size() == 0; @@ -135,7 +146,7 @@ bool Constant>::empty() const { template std::size_t Constant>::size() const { if (length_ == 0) { - return ShapeElements(shape_); + return TotalElementCount(shape_); } else { return static_cast(values_.size()) / length_; } @@ -149,9 +160,26 @@ auto Constant>::At( } template +auto Constant>::Reshape( + ConstantSubscripts &&dims) const -> Constant { + std::size_t n{TotalElementCount(dims)}; + CHECK(!empty() || n == 0); + std::vector elements; + std::int64_t at{0}, limit{static_cast(values_.size())}; + while (n-- > 0) { + elements.push_back(values_.substr(at, length_)); + at += length_; + if (at == limit) { // subtle: at > limit somehow? substr() will catch it + at = 0; + } + } + return {length_, std::move(elements), std::move(dims)}; +} + +template Constant Constant>::SHAPE() const { - return ShapeAsConstant(shape_); + return AsConstantShape(shape_); } // Constant specialization @@ -193,5 +221,10 @@ StructureConstructor Constant::At( values_.at(SubscriptsToOffset(index, shape_))}; } +auto Constant::Reshape(ConstantSubscripts &&dims) const + -> Constant { + return {result().derivedTypeSpec(), Base::Reshape(dims), std::move(dims)}; +} + INSTANTIATE_CONSTANT_TEMPLATES } diff --git a/flang/lib/evaluate/constant.h b/flang/lib/evaluate/constant.h index 064a774..8efa9d5 100644 --- a/flang/lib/evaluate/constant.h +++ b/flang/lib/evaluate/constant.h @@ -19,6 +19,7 @@ #include "type.h" #include #include +#include namespace Fortran::evaluate { @@ -36,6 +37,9 @@ template class Constant; // values as indices into constants, use a vector of integers. using ConstantSubscript = std::int64_t; using ConstantSubscripts = std::vector; +inline int GetRank(const ConstantSubscripts &s) { + return static_cast(s.size()); +} std::size_t TotalElementCount(const ConstantSubscripts &); @@ -43,7 +47,7 @@ inline ConstantSubscripts InitialSubscripts(int rank) { return ConstantSubscripts(rank, 1); // parens, not braces: "rank" copies of 1 } inline ConstantSubscripts InitialSubscripts(const ConstantSubscripts &shape) { - return InitialSubscripts(static_cast(shape.size())); + return InitialSubscripts(GetRank(shape)); } // Increments a vector of subscripts in Fortran array order (first dimension @@ -66,20 +70,18 @@ public: template> ConstantBase(A &&x, Result res = Result{}) : result_{res}, values_{std::move(x)} {} - ConstantBase(std::vector &&x, ConstantSubscripts &&dims, - Result res = Result{}) - : result_{res}, values_(std::move(x)), shape_(std::move(dims)) {} + ConstantBase( + std::vector &&, ConstantSubscripts &&, Result = Result{}); DEFAULT_CONSTRUCTORS_AND_ASSIGNMENTS(ConstantBase) ~ConstantBase(); - int Rank() const { return static_cast(shape_.size()); } + int Rank() const { return GetRank(shape_); } bool operator==(const ConstantBase &) const; bool empty() const { return values_.empty(); } std::size_t size() const { return values_.size(); } const std::vector &values() const { return values_; } const ConstantSubscripts &shape() const { return shape_; } - ConstantSubscripts &shape() { return shape_; } constexpr Result result() const { return result_; } constexpr DynamicType GetType() const { return result_.GetType(); } @@ -87,6 +89,8 @@ public: std::ostream &AsFortran(std::ostream &) const; protected: + std::vector Reshape(const ConstantSubscripts &) const; + Result result_; std::vector values_; ConstantSubscripts shape_; @@ -111,6 +115,7 @@ public: // Apply 1-based subscripts Element At(const ConstantSubscripts &) const; + Constant Reshape(ConstantSubscripts &&) const; }; template class Constant> { @@ -124,14 +129,13 @@ public: Constant(std::int64_t, std::vector &&, ConstantSubscripts &&); ~Constant(); - int Rank() const { return static_cast(shape_.size()); } + int Rank() const { return GetRank(shape_); } bool operator==(const Constant &that) const { return shape_ == that.shape_ && values_ == that.values_; } bool empty() const; std::size_t size() const; const ConstantSubscripts &shape() const { return shape_; } - ConstantSubscripts &shape() { return shape_; } std::int64_t LEN() const { return length_; } @@ -145,6 +149,7 @@ public: // Apply 1-based subscripts Scalar At(const ConstantSubscripts &) const; + Constant Reshape(ConstantSubscripts &&) const; Constant SHAPE() const; std::ostream &AsFortran(std::ostream &) const; @@ -180,6 +185,7 @@ public: std::optional GetScalarValue() const; StructureConstructor At(const ConstantSubscripts &) const; + Constant Reshape(ConstantSubscripts &&) const; }; FOR_EACH_LENGTHLESS_INTRINSIC_KIND(extern template class ConstantBase, ) diff --git a/flang/lib/evaluate/descender.h b/flang/lib/evaluate/descender.h index f43fee6..40f45ba 100644 --- a/flang/lib/evaluate/descender.h +++ b/flang/lib/evaluate/descender.h @@ -309,7 +309,7 @@ public: template void Descend(Variable &var) { Visit(var.u); } void Descend(const ActualArgument &arg) { - if (const auto *expr{arg.GetExpr()}) { + if (const auto *expr{arg.UnwrapExpr()}) { Visit(*expr); } else { const semantics::Symbol *aType{arg.GetAssumedTypeDummy()}; @@ -317,7 +317,7 @@ public: } } void Descend(ActualArgument &arg) { - if (auto *expr{arg.GetExpr()}) { + if (auto *expr{arg.UnwrapExpr()}) { Visit(*expr); } else { const semantics::Symbol *aType{arg.GetAssumedTypeDummy()}; diff --git a/flang/lib/evaluate/expression.cc b/flang/lib/evaluate/expression.cc index c1662d2..944f460 100644 --- a/flang/lib/evaluate/expression.cc +++ b/flang/lib/evaluate/expression.cc @@ -142,6 +142,15 @@ bool StructureConstructor::operator==(const StructureConstructor &that) const { DynamicType StructureConstructor::GetType() const { return result_.GetType(); } +const Expr *StructureConstructor::Find( + const Symbol *component) const { + if (auto iter{values_.find(component)}; iter != values_.end()) { + return &iter->second.value(); + } else { + return nullptr; + } +} + StructureConstructor &StructureConstructor::Add( const Symbol &symbol, Expr &&expr) { values_.emplace(&symbol, std::move(expr)); diff --git a/flang/lib/evaluate/expression.h b/flang/lib/evaluate/expression.h index 7139341..808a894 100644 --- a/flang/lib/evaluate/expression.h +++ b/flang/lib/evaluate/expression.h @@ -752,6 +752,8 @@ public: return values_.end(); } + const Expr *Find(const Symbol *) const; // can return null + StructureConstructor &Add(const semantics::Symbol &, Expr &&); int Rank() const { return 0; } DynamicType GetType() const; diff --git a/flang/lib/evaluate/fold.cc b/flang/lib/evaluate/fold.cc index 44ac74c..d60f5e0 100644 --- a/flang/lib/evaluate/fold.cc +++ b/flang/lib/evaluate/fold.cc @@ -40,24 +40,33 @@ namespace Fortran::evaluate { +// FoldOperation() rewrites expression tree nodes. +// If there is any possibility that the rewritten node will +// not have the same representation type, the result of +// FoldOperation() will be packaged in an Expr<> of the same +// specific type. + // no-op base case template common::IfNoLvalue>, A> FoldOperation( FoldingContext &, A &&x) { + static_assert(!std::is_same_v>> && + "call Fold() instead for Expr<>"); return Expr>{std::move(x)}; } // Forward declarations of overloads, template instantiations, and template // specializations of FoldOperation() to enable mutual recursion between them. -BaseObject FoldOperation(FoldingContext &, BaseObject &&); -Component FoldOperation(FoldingContext &, Component &&); -Triplet FoldOperation(FoldingContext &, Triplet &&); -Subscript FoldOperation(FoldingContext &, Subscript &&); -ArrayRef FoldOperation(FoldingContext &, ArrayRef &&); -CoarrayRef FoldOperation(FoldingContext &, CoarrayRef &&); -DataRef FoldOperation(FoldingContext &, DataRef &&); -Substring FoldOperation(FoldingContext &, Substring &&); -ComplexPart FoldOperation(FoldingContext &, ComplexPart &&); +static Component FoldOperation(FoldingContext &, Component &&); +static Triplet FoldOperation( + FoldingContext &, Triplet &&, const Symbol &, int dim); +static Subscript FoldOperation( + FoldingContext &, Subscript &&, const Symbol &, int dim); +static ArrayRef FoldOperation(FoldingContext &, ArrayRef &&); +static CoarrayRef FoldOperation(FoldingContext &, CoarrayRef &&); +static DataRef FoldOperation(FoldingContext &, DataRef &&); +static Substring FoldOperation(FoldingContext &, Substring &&); +static ComplexPart FoldOperation(FoldingContext &, ComplexPart &&); template Expr> FoldOperation( FoldingContext &context, FunctionRef> &&); @@ -75,27 +84,37 @@ template Expr FoldOperation(FoldingContext &, Designator &&); template Expr> FoldOperation( FoldingContext &, TypeParamInquiry &&); +static Expr FoldOperation( + FoldingContext &context, ImpliedDoIndex &&); template Expr FoldOperation(FoldingContext &, ArrayConstructor &&); -Expr FoldOperation(FoldingContext &, StructureConstructor &&); +static Expr FoldOperation( + FoldingContext &, StructureConstructor &&); // Overloads, instantiations, and specializations of FoldOperation(). -BaseObject FoldOperation(FoldingContext &, BaseObject &&object) { - return std::move(object); -} - Component FoldOperation(FoldingContext &context, Component &&component) { return {FoldOperation(context, std::move(component.base())), component.GetLastSymbol()}; } -Triplet FoldOperation(FoldingContext &context, Triplet &&triplet) { - return {Fold(context, triplet.lower()), Fold(context, triplet.upper()), - Fold(context, common::Clone(triplet.stride()))}; +Triplet FoldOperation( + FoldingContext &context, Triplet &&triplet, const Symbol &symbol, int dim) { + MaybeExtentExpr lower{triplet.lower()}; + if (!lower.has_value()) { + lower = GetLowerBound(context, symbol, dim); + } + MaybeExtentExpr upper{triplet.upper()}; + if (!upper.has_value()) { + upper = GetUpperBound( + context, common::Clone(lower), GetExtent(context, symbol, dim)); + } + return {Fold(context, std::move(lower)), Fold(context, std::move(upper)), + Fold(context, triplet.stride())}; } -Subscript FoldOperation(FoldingContext &context, Subscript &&subscript) { +Subscript FoldOperation(FoldingContext &context, Subscript &&subscript, + const Symbol &symbol, int dim) { return std::visit( common::visitors{ [&](IndirectSubscriptIntegerExpr &&expr) { @@ -103,15 +122,18 @@ Subscript FoldOperation(FoldingContext &context, Subscript &&subscript) { return Subscript(std::move(expr)); }, [&](Triplet &&triplet) { - return Subscript(FoldOperation(context, std::move(triplet))); + return Subscript( + FoldOperation(context, std::move(triplet), symbol, dim)); }, }, std::move(subscript.u)); } ArrayRef FoldOperation(FoldingContext &context, ArrayRef &&arrayRef) { + const Symbol &symbol{arrayRef.GetLastSymbol()}; + int dim{0}; for (Subscript &subscript : arrayRef.subscript()) { - subscript = FoldOperation(context, std::move(subscript)); + subscript = FoldOperation(context, std::move(subscript), symbol, dim++); } return std::visit( common::visitors{ @@ -127,9 +149,11 @@ ArrayRef FoldOperation(FoldingContext &context, ArrayRef &&arrayRef) { } CoarrayRef FoldOperation(FoldingContext &context, CoarrayRef &&coarrayRef) { + const Symbol &symbol{coarrayRef.GetLastSymbol()}; std::vector subscript; + int dim{0}; for (Subscript x : coarrayRef.subscript()) { - subscript.emplace_back(FoldOperation(context, std::move(x))); + subscript.emplace_back(FoldOperation(context, std::move(x), symbol, dim++)); } std::vector> cosubscript; for (Expr x : coarrayRef.cosubscript()) { @@ -194,8 +218,8 @@ static inline Expr FoldElementalIntrinsicHelper(FoldingContext &context, static_assert( (... && IsSpecificIntrinsicType)); // TODO derived types for MERGE? static_assert(sizeof...(TA) > 0); - std::tuple *...> args{ - GetConstantValue(*funcRef.arguments()[I].value().GetExpr())...}; + std::tuple *...> args{UnwrapExpr>( + *funcRef.arguments()[I].value().UnwrapExpr())...}; if ((... && (std::get(args) != nullptr))) { // Compute the shape of the result based on shapes of arguments ConstantSubscripts shape; @@ -215,13 +239,13 @@ static inline Expr FoldElementalIntrinsicHelper(FoldingContext &context, // same. Shouldn't this be checked elsewhere so that this is also // checked for non constexpr call to elemental intrinsics function? context.messages().Say( - "arguments in elemental intrinsic function are not conformable"_err_en_US); + "Arguments in elemental intrinsic function are not conformable"_err_en_US); return Expr{std::move(funcRef)}; } } } } - CHECK(rank == static_cast(shape.size())); + CHECK(rank == GetRank(shape)); // Compute all the scalar values of the results std::vector> results; @@ -254,34 +278,35 @@ static inline Expr FoldElementalIntrinsicHelper(FoldingContext &context, } template -static Expr FoldElementalIntrinsic(FoldingContext &context, +Expr FoldElementalIntrinsic(FoldingContext &context, FunctionRef &&funcRef, ScalarFunc func) { return FoldElementalIntrinsicHelper( context, std::move(funcRef), func, std::index_sequence_for{}); } template -static Expr FoldElementalIntrinsic(FoldingContext &context, +Expr FoldElementalIntrinsic(FoldingContext &context, FunctionRef &&funcRef, ScalarFuncWithContext func) { return FoldElementalIntrinsicHelper( context, std::move(funcRef), func, std::index_sequence_for{}); } -template -static Expr *UnwrapArgument(std::optional &arg) { - if (arg.has_value()) { - if (Expr * expr{arg->GetExpr()}) { - return UnwrapExpr>(*expr); - } +static std::optional GetInt64Arg( + const std::optional &arg) { + if (const auto *intExpr{UnwrapExpr>(arg)}) { + return ToInt64(*intExpr); + } else { + return std::nullopt; } - return nullptr; } -static BOZLiteralConstant *UnwrapBozArgument( - std::optional &arg) { - if (auto *expr{UnwrapArgument(arg)}) { - return std::get_if(&expr->u); +static std::optional GetInt64ArgOr( + const std::optional &arg, std::int64_t defaultValue) { + if (!arg.has_value()) { + return defaultValue; + } else if (const auto *intExpr{UnwrapExpr>(arg)}) { + return ToInt64(*intExpr); } else { - return nullptr; + return std::nullopt; } } @@ -291,8 +316,8 @@ Expr> FoldOperation(FoldingContext &context, using T = Type; ActualArguments &args{funcRef.arguments()}; for (std::optional &arg : args) { - if (auto *expr{UnwrapArgument(arg)}) { - *expr = FoldOperation(context, std::move(*expr)); + if (auto *expr{UnwrapExpr>(arg)}) { + *expr = Fold(context, std::move(*expr)); } } if (auto *intrinsic{std::get_if(&funcRef.proc().u)}) { @@ -313,7 +338,7 @@ Expr> FoldOperation(FoldingContext &context, } else if (name == "dshiftl" || name == "dshiftr") { // convert boz for (int i{0}; i <= 1; ++i) { - if (auto *x{UnwrapBozArgument(args[i])}) { + if (auto *x{UnwrapExpr(args[i])}) { *args[i] = AsGenericExpr(Fold(context, ConvertToType(std::move(*x)))); } @@ -321,7 +346,7 @@ Expr> FoldOperation(FoldingContext &context, // Third argument can be of any kind. However, it must be smaller or equal // than BIT_SIZE. It can be converted to Int4 to simplify. using Int4 = Type; - if (auto *n{UnwrapArgument(args[2])}) { + if (auto *n{UnwrapExpr>(args[2])}) { *args[2] = AsGenericExpr(Fold(context, ConvertToType(std::move(*n)))); } @@ -335,7 +360,7 @@ Expr> FoldOperation(FoldingContext &context, fptr, i, j, static_cast(shift.ToInt64())); })); } else if (name == "exponent") { - if (auto *sx{UnwrapArgument(args[0])}) { + if (auto *sx{UnwrapExpr>(args[0])}) { return std::visit( [&funcRef, &context](const auto &x) -> Expr { using TR = typename std::decay_t::Result; @@ -349,7 +374,7 @@ Expr> FoldOperation(FoldingContext &context, } else if (name == "iand" || name == "ior" || name == "ieor") { // convert boz for (int i{0}; i <= 1; ++i) { - if (auto *x{UnwrapBozArgument(args[i])}) { + if (auto *x{UnwrapExpr(args[i])}) { *args[i] = AsGenericExpr(Fold(context, ConvertToType(std::move(*x)))); } @@ -370,7 +395,7 @@ Expr> FoldOperation(FoldingContext &context, // Second argument can be of any kind. However, it must be smaller or // equal than BIT_SIZE. It can be converted to Int4 to simplify. using Int4 = Type; - if (auto *n{UnwrapArgument(args[1])}) { + if (auto *n{UnwrapExpr>(args[1])}) { *args[1] = AsGenericExpr(Fold(context, ConvertToType(std::move(*n)))); } @@ -395,7 +420,7 @@ Expr> FoldOperation(FoldingContext &context, return std::invoke(fptr, i, static_cast(pos.ToInt64())); })); } else if (name == "int") { - if (auto *expr{args[0].value().GetExpr()}) { + if (auto *expr{args[0].value().UnwrapExpr()}) { return std::visit( [&](auto &&x) -> Expr { using From = std::decay_t; @@ -415,7 +440,7 @@ Expr> FoldOperation(FoldingContext &context, } } else if (name == "leadz" || name == "trailz" || name == "poppar" || name == "popcnt") { - if (auto *sn{UnwrapArgument(args[0])}) { + if (auto *sn{UnwrapExpr>(args[0])}) { return std::visit( [&funcRef, &context, &name](const auto &n) -> Expr { using TI = typename std::decay_t::Result; @@ -446,7 +471,7 @@ Expr> FoldOperation(FoldingContext &context, common::die("leadz argument must be integer"); } } else if (name == "len") { - if (auto *charExpr{UnwrapArgument(args[0])}) { + if (auto *charExpr{UnwrapExpr>(args[0])}) { return std::visit( [&](auto &kx) { return Fold(context, ConvertToType(kx.LEN())); }, charExpr->u); @@ -457,7 +482,7 @@ Expr> FoldOperation(FoldingContext &context, // Argument can be of any kind but value has to be smaller than bit_size. // It can be safely converted to Int4 to simplify. using Int4 = Type; - if (auto *n{UnwrapArgument(args[0])}) { + if (auto *n{UnwrapExpr>(args[0])}) { *args[0] = AsGenericExpr(Fold(context, ConvertToType(std::move(*n)))); } @@ -469,7 +494,7 @@ Expr> FoldOperation(FoldingContext &context, } else if (name == "merge_bits") { // convert boz for (int i{0}; i <= 2; ++i) { - if (auto *x{UnwrapBozArgument(args[i])}) { + if (auto *x{UnwrapExpr(args[i])}) { *args[i] = AsGenericExpr(Fold(context, ConvertToType(std::move(*x)))); } @@ -479,6 +504,18 @@ Expr> FoldOperation(FoldingContext &context, } else if (name == "rank") { // TODO assumed-rank dummy argument return Expr{args[0].value().Rank()}; + } else if (name == "selected_int_kind") { + if (auto p{GetInt64Arg(args[0])}) { + return Expr{SelectedIntKind(*p)}; + } + } else if (name == "selected_real_kind") { + if (auto p{GetInt64ArgOr(args[0], 0)}) { + if (auto r{GetInt64ArgOr(args[1], 0)}) { + if (auto radix{GetInt64ArgOr(args[2], 2)}) { + return Expr{SelectedRealKind(*p, *r, *radix)}; + } + } + } } else if (name == "shape") { if (auto shape{GetShape(context, args[0].value())}) { if (auto shapeExpr{AsExtentArrayExpr(*shape)}) { @@ -488,18 +525,16 @@ Expr> FoldOperation(FoldingContext &context, } else if (name == "size") { if (auto shape{GetShape(context, args[0].value())}) { if (auto &dimArg{args[1]}) { // DIM= is present, get one extent - if (auto *expr{dimArg->GetExpr()}) { - if (auto dim{ToInt64(*expr)}) { - std::int64_t rank = shape->size(); - if (*dim >= 1 && *dim <= rank) { - if (auto &extent{shape->at(*dim - 1)}) { - return Fold(context, ConvertToType(std::move(*extent))); - } - } else { - context.messages().Say( - "size(array,dim=%jd) dimension is out of range for rank-%d array"_en_US, - static_cast(*dim), static_cast(rank)); + if (auto dim{GetInt64Arg(args[1])}) { + int rank = GetRank(*shape); + if (*dim >= 1 && *dim <= rank) { + if (auto &extent{shape->at(*dim - 1)}) { + return Fold(context, ConvertToType(std::move(*extent))); } + } else { + context.messages().Say( + "size(array,dim=%jd) dimension is out of range for rank-%d array"_en_US, + static_cast(*dim), static_cast(rank)); } } } else if (auto extents{ @@ -518,7 +553,7 @@ Expr> FoldOperation(FoldingContext &context, // findloc, floor, iachar, iall, iany, iparity, ibits, ichar, image_status, // index, ishftc, lbound, len_trim, matmul, max, maxloc, maxval, merge, min, // minloc, minval, mod, modulo, nint, not, pack, product, reduce, reshape, - // scan, selected_char_kind, selected_int_kind, selected_real_kind, + // scan, selected_char_kind, // sign, spread, sum, transfer, transpose, ubound, unpack, verify } return Expr{std::move(funcRef)}; @@ -562,8 +597,8 @@ Expr> FoldOperation(FoldingContext &context, ActualArguments &args{funcRef.arguments()}; for (std::optional &arg : args) { if (arg.has_value()) { - if (auto *expr{arg->GetExpr()}) { - *expr = FoldOperation(context, std::move(*expr)); + if (auto *expr{arg->UnwrapExpr()}) { + *expr = Fold(context, std::move(*expr)); } } } @@ -605,7 +640,7 @@ Expr> FoldOperation(FoldingContext &context, if (args.size() == 2) { // elemental // runtime functions use int arg using Int4 = Type; - if (auto *n{UnwrapArgument(args[0])}) { + if (auto *n{UnwrapExpr>(args[0])}) { *args[0] = AsGenericExpr(Fold(context, ConvertToType(std::move(*n)))); } @@ -622,10 +657,10 @@ Expr> FoldOperation(FoldingContext &context, } } else if (name == "abs") { // Argument can be complex or real - if (auto *x{UnwrapArgument(args[0])}) { + if (auto *x{UnwrapExpr>(args[0])}) { return FoldElementalIntrinsic( context, std::move(funcRef), &Scalar::ABS); - } else if (auto *z{UnwrapArgument(args[0])}) { + } else if (auto *z{UnwrapExpr>(args[0])}) { if (auto callable{ context.hostIntrinsicsLibrary() .GetHostProcedureWrapper("abs")}) { @@ -643,7 +678,7 @@ Expr> FoldOperation(FoldingContext &context, context, std::move(funcRef), &Scalar::AIMAG); } else if (name == "aint") { // Convert argument to the requested kind before calling aint - if (auto *x{UnwrapArgument(args[0])}) { + if (auto *x{UnwrapExpr>(args[0])}) { *args[0] = AsGenericExpr(Fold(context, ConvertToType(std::move(*x)))); } @@ -657,8 +692,8 @@ Expr> FoldOperation(FoldingContext &context, return y.value; })); } else if (name == "dprod") { - if (auto *x{UnwrapArgument(args[0])}) { - if (auto *y{UnwrapArgument(args[1])}) { + if (auto *x{UnwrapExpr>(args[0])}) { + if (auto *y{UnwrapExpr>(args[1])}) { return Fold(context, Expr{Multiply{ConvertToType(std::move(*x)), ConvertToType(std::move(*y))}}); @@ -668,7 +703,7 @@ Expr> FoldOperation(FoldingContext &context, } else if (name == "epsilon") { return Expr{Constant{Scalar::EPSILON()}}; } else if (name == "real") { - if (auto *expr{args[0].value().GetExpr()}) { + if (auto *expr{args[0].value().UnwrapExpr()}) { return ToReal(context, std::move(*expr)); } } @@ -688,8 +723,8 @@ Expr> FoldOperation(FoldingContext &context, ActualArguments &args{funcRef.arguments()}; for (std::optional &arg : args) { if (arg.has_value()) { - if (auto *expr{arg->GetExpr()}) { - *expr = FoldOperation(context, std::move(*expr)); + if (auto *expr{arg->UnwrapExpr()}) { + *expr = Fold(context, std::move(*expr)); } } } @@ -712,7 +747,7 @@ Expr> FoldOperation(FoldingContext &context, context, std::move(funcRef), &Scalar::CONJG); } else if (name == "cmplx") { if (args.size() == 2) { - if (auto *x{UnwrapArgument(args[0])}) { + if (auto *x{UnwrapExpr>(args[0])}) { return Fold(context, ConvertToType(std::move(*x))); } else { common::die("x must be complex in cmplx(x[, kind])"); @@ -720,9 +755,9 @@ Expr> FoldOperation(FoldingContext &context, } else { CHECK(args.size() == 3); using Part = typename T::Part; - Expr re{std::move(*args[0].value().GetExpr())}; + Expr re{std::move(*args[0].value().UnwrapExpr())}; Expr im{args[1].has_value() - ? std::move(*args[1].value().GetExpr()) + ? std::move(*args[1].value().UnwrapExpr()) : AsGenericExpr(Constant{Scalar{}})}; return Fold(context, Expr{ @@ -743,24 +778,51 @@ Expr> FoldOperation(FoldingContext &context, ActualArguments &args{funcRef.arguments()}; for (std::optional &arg : args) { if (arg.has_value()) { - if (auto *expr{arg->GetExpr()}) { - *expr = FoldOperation(context, std::move(*expr)); + if (auto *expr{arg->UnwrapExpr()}) { + *expr = Fold(context, std::move(*expr)); } } } if (auto *intrinsic{std::get_if(&funcRef.proc().u)}) { std::string name{intrinsic->name}; - if (name == "bge" || name == "bgt" || name == "ble" || name == "blt") { + if (name == "all") { + if (!args[1].has_value()) { // TODO: ALL(x,DIM=d) + if (const auto *constant{UnwrapConstantValue(args[0])}) { + bool result{true}; + for (const auto &element : constant->values()) { + if (!element.IsTrue()) { + result = false; + break; + } + } + return Expr{result}; + } + } + } else if (name == "any") { + if (!args[1].has_value()) { // TODO: ANY(x,DIM=d) + if (const auto *constant{UnwrapConstantValue(args[0])}) { + bool result{false}; + for (const auto &element : constant->values()) { + if (element.IsTrue()) { + result = true; + break; + } + } + return Expr{result}; + } + } + } else if (name == "bge" || name == "bgt" || name == "ble" || + name == "blt") { using LargestInt = Type; static_assert(std::is_same_v, BOZLiteralConstant>); // Arguments do not have to be of the same integer type. Convert all // arguments to the biggest integer type before comparing them to // simplify. for (int i{0}; i <= 1; ++i) { - if (auto *x{UnwrapArgument(args[i])}) { + if (auto *x{UnwrapExpr>(args[i])}) { *args[i] = AsGenericExpr( Fold(context, ConvertToType(std::move(*x)))); - } else if (auto *x{UnwrapBozArgument(args[i])}) { + } else if (auto *x{UnwrapExpr(args[i])}) { *args[i] = AsGenericExpr(Constant{std::move(*x)}); } } @@ -783,7 +845,7 @@ Expr> FoldOperation(FoldingContext &context, return Scalar{std::invoke(fptr, i, j)}; })); } - // TODO: all, any, btest, cshift, dot_product, eoshift, is_iostat_end, + // TODO: btest, cshift, dot_product, eoshift, is_iostat_end, // is_iostat_eor, lge, lgt, lle, llt, logical, matmul, merge, out_of_range, // pack, parity, reduce, reshape, spread, transfer, transpose, unpack } @@ -792,20 +854,23 @@ Expr> FoldOperation(FoldingContext &context, // Get the value of a PARAMETER template -static std::optional> GetParameterValue( +std::optional> GetParameterValue( FoldingContext &context, const Symbol *symbol) { CHECK(symbol != nullptr); if (symbol->attrs().test(semantics::Attr::PARAMETER)) { if (const auto *object{ symbol->detailsIf()}) { + if (object->initWasValidated()) { + const auto *constant{UnwrapConstantValue(object->init())}; + CHECK(constant != nullptr); + return Expr{*constant}; + } if (const auto &init{object->init()}) { - if (const auto *constant{UnwrapExpr>(*init)}) { - return Expr{*constant}; - } if (auto dyType{DynamicType::From(*symbol)}) { - auto converted{ConvertToType(*dyType, common::Clone(*init))}; semantics::ObjectEntityDetails *mutableObject{ const_cast(object)}; + auto converted{ + ConvertToType(*dyType, std::move(mutableObject->init().value()))}; // Reset expression now to prevent infinite loops if the init // expression depends on symbol itself. mutableObject->set_init(std::nullopt); @@ -813,9 +878,34 @@ static std::optional> GetParameterValue( *converted = Fold(context, std::move(*converted)); auto *unwrapped{UnwrapExpr>(*converted)}; CHECK(unwrapped != nullptr); - if (auto constant{GetScalarConstantValue(*unwrapped)}) { + if (auto *constant{UnwrapConstantValue(*unwrapped)}) { + if (constant->Rank() == 0 && symbol->Rank() > 0) { + // scalar expansion + if (auto symShape{GetShape(context, *symbol)}) { + if (auto extents{AsConstantExtents(*symShape)}) { + *constant = constant->Reshape(std::move(*extents)); + CHECK(constant->Rank() == symbol->Rank()); + } + } + } mutableObject->set_init(AsGenericExpr(Expr{*constant})); - return std::move(*unwrapped); + if (auto constShape{GetShape(context, *constant)}) { + if (auto symShape{GetShape(context, *symbol)}) { + if (CheckConformance(context.messages(), *constShape, + *symShape, "initialization expression", + "PARAMETER")) { + mutableObject->set_initWasValidated(); + return std::move(*unwrapped); + } + } else { + context.messages().Say(symbol->name(), + "Could not determine the shape of the PARAMETER"_err_en_US); + } + } else { + context.messages().Say(symbol->name(), + "Could not determine the shape of the initialization expression"_err_en_US); + } + mutableObject->set_init(std::nullopt); } else { std::stringstream ss; unwrapped->AsFortran(ss); @@ -839,9 +929,9 @@ static std::optional> GetParameterValue( // Apply subscripts to a constant array -std::optional> GetConstantSubscript( - FoldingContext &context, Subscript &ss) { - ss = FoldOperation(context, std::move(ss)); +static std::optional> GetConstantSubscript( + FoldingContext &context, Subscript &ss, const Symbol &symbol, int dim) { + ss = FoldOperation(context, std::move(ss), symbol, dim); return std::visit( common::visitors{ [](IndirectSubscriptIntegerExpr &expr) @@ -886,8 +976,8 @@ std::optional> ApplySubscripts(parser::ContextualMessages &messages, const Constant &array, const std::vector> &subscripts) { const auto &shape{array.shape()}; - std::size_t rank{shape.size()}; - CHECK(rank == subscripts.size()); + int rank{GetRank(shape)}; + CHECK(rank == static_cast(subscripts.size())); std::size_t elements{1}; ConstantSubscripts resultShape; for (const auto &ss : subscripts) { @@ -901,12 +991,12 @@ std::optional> ApplySubscripts(parser::ContextualMessages &messages, std::vector> values; while (elements-- > 0) { bool increment{true}; - std::size_t k{0}; - for (std::size_t j{0}; j < rank; ++j) { + int k{0}; + for (int j{0}; j < rank; ++j) { if (subscripts[j].Rank() == 0) { at[j] = subscripts[j].GetScalarValue().value().ToInt64(); } else { - CHECK(k < resultShape.size()); + CHECK(k < GetRank(resultShape)); tmp[0] = ssAt[j] + 1; at[j] = subscripts[j].At(tmp).ToInt64(); if (increment) { @@ -927,7 +1017,7 @@ std::optional> ApplySubscripts(parser::ContextualMessages &messages, } values.emplace_back(array.At(at)); CHECK(!increment || elements == 0); - CHECK(k == resultShape.size()); + CHECK(k == GetRank(resultShape)); } if constexpr (T::category == TypeCategory::Character) { return Constant{array.LEN(), std::move(values), std::move(resultShape)}; @@ -940,20 +1030,23 @@ std::optional> ApplySubscripts(parser::ContextualMessages &messages, } template -static std::optional> ApplyConstantSubscripts( +std::optional> ApplyConstantSubscripts( FoldingContext &context, ArrayRef &aRef) { + const Symbol &symbol{aRef.GetLastSymbol()}; std::vector> subscripts; + int dim{0}; for (Subscript &ss : aRef.subscript()) { - if (auto constant{GetConstantSubscript(context, ss)}) { + if (auto constant{GetConstantSubscript(context, ss, symbol, dim++)}) { subscripts.emplace_back(std::move(*constant)); } else { return std::nullopt; } } + // TODO pmk generalize to component base too if (const Symbol *const *symbol{std::get_if(&aRef.base())}) { if (auto value{GetParameterValue(context, *symbol)}) { Expr folded{Fold(context, std::move(*value))}; - if (const auto *array{GetConstantValue(folded)}) { + if (const auto *array{UnwrapConstantValue(folded)}) { if (auto result{ ApplySubscripts(context.messages(), *array, subscripts)}) { return result; @@ -965,6 +1058,28 @@ static std::optional> ApplyConstantSubscripts( } template +std::optional> GetConstantComponent( + FoldingContext &context, Component &component) { + // TODO pmk generalize to array ref and component bases too + if (const Symbol *const *symbol{ + std::get_if(&component.base().u)}) { + if (auto value{GetParameterValue(context, *symbol)}) { + Expr folded{Fold(context, std::move(*value))}; + if (const auto *structure{UnwrapConstantValue(folded)}) { + if (auto scalar{structure->GetScalarValue()}) { + if (auto *expr{scalar->Find(&component.GetLastSymbol())}) { + if (const auto *value{UnwrapConstantValue(*expr)}) { + return *value; + } + } + } + } + } + } + return std::nullopt; +} + +template Expr FoldOperation(FoldingContext &context, Designator &&designator) { if constexpr (T::category == TypeCategory::Character) { if (auto *substring{common::Unwrap(designator.u)}) { @@ -990,11 +1105,19 @@ Expr FoldOperation(FoldingContext &context, Designator &&designator) { } }, [&](ArrayRef &&aRef) { + aRef = FoldOperation(context, std::move(aRef)); if (auto c{ApplyConstantSubscripts(context, aRef)}) { return Expr{std::move(*c)}; } else { - return Expr{ - Designator{FoldOperation(context, std::move(aRef))}}; + return Expr{Designator{std::move(aRef)}}; + } + }, + [&](Component &&component) { + component = FoldOperation(context, std::move(component)); + if (auto c{GetConstantComponent(context, component)}) { + return Expr{std::move(*c)}; + } else { + return Expr{Designator{std::move(component)}}; } }, [&](auto &&x) { @@ -1020,6 +1143,7 @@ public: explicit ArrayConstructorFolder(const FoldingContext &c) : context_{c} {} Expr FoldArray(ArrayConstructor &&array) { + // Calls FoldArray(const ArrayConstructorValues &) below if (FoldArray(array)) { auto n{static_cast(elements_.size())}; if constexpr (std::is_same_v) { @@ -1042,11 +1166,11 @@ public: private: bool FoldArray(const common::CopyableIndirection> &expr) { Expr folded{Fold(context_, common::Clone(expr.value()))}; - if (const auto *c{GetConstantValue(folded)}) { + if (const auto *c{UnwrapConstantValue(folded)}) { // Copy elements in Fortran array element order ConstantSubscripts shape{c->shape()}; int rank{c->Rank()}; - ConstantSubscripts index(shape.size(), 1); + ConstantSubscripts index(GetRank(shape), 1); for (std::size_t n{c->size()}; n-- > 0;) { elements_.emplace_back(c->At(index)); for (int d{0}; d < rank; ++d) { @@ -1109,9 +1233,7 @@ private: template Expr FoldOperation(FoldingContext &context, ArrayConstructor &&array) { - ArrayConstructorFolder folder{context}; - Expr result{folder.FoldArray(std::move(array))}; - return result; + return ArrayConstructorFolder{context}.FoldArray(std::move(array)); } Expr FoldOperation( @@ -1180,7 +1302,7 @@ bool ArrayConstructorIsFlat(const ArrayConstructorValues &values) { template std::optional> AsFlatArrayConstructor(const Expr &expr) { - if (const auto *c{GetConstantValue(expr)}) { + if (const auto *c{UnwrapConstantValue(expr)}) { ArrayConstructor result{expr}; if (c->size() > 0) { ConstantSubscripts at{InitialSubscripts(c->shape())}; @@ -1200,8 +1322,9 @@ std::optional> AsFlatArrayConstructor(const Expr &expr) { } template -std::optional>> AsFlatArrayConstructor( - const Expr> &expr) { +std::enable_if_t>>> +AsFlatArrayConstructor(const Expr> &expr) { return std::visit( [&](const auto &kindExpr) -> std::optional>> { if (auto flattened{AsFlatArrayConstructor(kindExpr)}) { @@ -1222,8 +1345,8 @@ Expr FromArrayConstructor(FoldingContext &context, ArrayConstructor &&values, std::optional &&shape) { Expr result{Fold(context, Expr{std::move(values)})}; if (shape.has_value()) { - if (auto *constant{GetConstantValue(result)}) { - constant->shape() = std::move(*shape); + if (auto *constant{UnwrapConstantValue(result)}) { + return Expr{constant->Reshape(std::move(*shape))}; } } return result; @@ -1249,8 +1372,7 @@ Expr MapOperation(FoldingContext &context, auto &aConst{std::get>(kindExpr.u)}; for (auto &acValue : aConst) { auto &scalar{std::get>(acValue.u)}; - result.Push( - FoldOperation(context, f(Expr{std::move(scalar)}))); + result.Push(Fold(context, f(Expr{std::move(scalar)}))); } }, std::move(values.u)); @@ -1258,7 +1380,7 @@ Expr MapOperation(FoldingContext &context, auto &aConst{std::get>(values.u)}; for (auto &acValue : aConst) { auto &scalar{std::get>(acValue.u)}; - result.Push(FoldOperation(context, f(std::move(scalar)))); + result.Push(Fold(context, f(std::move(scalar)))); } } return FromArrayConstructor( @@ -1283,7 +1405,7 @@ Expr MapOperation(FoldingContext &context, CHECK(rightIter != rightArrConst.end()); auto &leftScalar{std::get>(leftValue.u)}; auto &rightScalar{std::get>(rightIter->u)}; - result.Push(FoldOperation(context, + result.Push(Fold(context, f(std::move(leftScalar), Expr{std::move(rightScalar)}))); ++rightIter; } @@ -1296,8 +1418,8 @@ Expr MapOperation(FoldingContext &context, CHECK(rightIter != rightArrConst.end()); auto &leftScalar{std::get>(leftValue.u)}; auto &rightScalar{std::get>(rightIter->u)}; - result.Push(FoldOperation( - context, f(std::move(leftScalar), std::move(rightScalar)))); + result.Push( + Fold(context, f(std::move(leftScalar), std::move(rightScalar)))); ++rightIter; } } @@ -1315,8 +1437,8 @@ Expr MapOperation(FoldingContext &context, auto &leftArrConst{std::get>(leftValues.u)}; for (auto &leftValue : leftArrConst) { auto &leftScalar{std::get>(leftValue.u)}; - result.Push(FoldOperation( - context, f(std::move(leftScalar), Expr{rightScalar}))); + result.Push( + Fold(context, f(std::move(leftScalar), Expr{rightScalar}))); } return FromArrayConstructor( context, std::move(result), AsConstantExtents(shape)); @@ -1336,7 +1458,7 @@ Expr MapOperation(FoldingContext &context, auto &rightArrConst{std::get>(kindExpr.u)}; for (auto &rightValue : rightArrConst) { auto &rightScalar{std::get>(rightValue.u)}; - result.Push(FoldOperation(context, + result.Push(Fold(context, f(Expr{leftScalar}, Expr{std::move(rightScalar)}))); } @@ -1346,8 +1468,8 @@ Expr MapOperation(FoldingContext &context, auto &rightArrConst{std::get>(rightValues.u)}; for (auto &rightValue : rightArrConst) { auto &rightScalar{std::get>(rightValue.u)}; - result.Push(FoldOperation( - context, f(Expr{leftScalar}, std::move(rightScalar)))); + result.Push( + Fold(context, f(Expr{leftScalar}, std::move(rightScalar)))); } } return FromArrayConstructor( diff --git a/flang/lib/evaluate/fold.h b/flang/lib/evaluate/fold.h index 0ea56bf..72dbe47 100644 --- a/flang/lib/evaluate/fold.h +++ b/flang/lib/evaluate/fold.h @@ -30,7 +30,7 @@ namespace Fortran::evaluate { using namespace Fortran::parser::literals; // Fold() rewrites an expression and returns it. When the rewritten expression -// is a constant, GetConstantValue() and GetScalarConstantValue() below will +// is a constant, UnwrapConstantValue() and GetScalarConstantValue() below will // be able to extract it. // Note the rvalue reference argument: the rewrites are performed in place // for efficiency. @@ -48,29 +48,18 @@ std::optional> Fold( } } -// GetConstantValue() isolates the known constant value of -// an expression, if it has one. The value can be parenthesized. +// UnwrapConstantValue() isolates the known constant value of +// an expression, if it has one. It returns a pointer, which is +// const-qualified when the expression is so. The value can be +// parenthesized. template -const Constant *GetConstantValue(const EXPR &expr) { - if (const auto *c{UnwrapExpr>(expr)}) { - return c; - } else { - if constexpr (!std::is_same_v) { - if (auto *parens{UnwrapExpr>(expr)}) { - return GetConstantValue(parens->left()); - } - } - return nullptr; - } -} - -template Constant *GetConstantValue(EXPR &expr) { +auto UnwrapConstantValue(EXPR &expr) -> common::Constify, EXPR> * { if (auto *c{UnwrapExpr>(expr)}) { return c; } else { if constexpr (!std::is_same_v) { if (auto *parens{UnwrapExpr>(expr)}) { - return GetConstantValue(parens->left()); + return UnwrapConstantValue(parens->left()); } } return nullptr; @@ -81,7 +70,7 @@ template Constant *GetConstantValue(EXPR &expr) { // an expression, if it has one. The value can be parenthesized. template auto GetScalarConstantValue(const EXPR &expr) -> std::optional> { - if (const Constant *constant{GetConstantValue(expr)}) { + if (const Constant *constant{UnwrapConstantValue(expr)}) { return constant->GetScalarValue(); } else { return std::nullopt; diff --git a/flang/lib/evaluate/formatting.cc b/flang/lib/evaluate/formatting.cc index 5eb0c29..ed63691 100644 --- a/flang/lib/evaluate/formatting.cc +++ b/flang/lib/evaluate/formatting.cc @@ -24,7 +24,7 @@ namespace Fortran::evaluate { static void ShapeAsFortran(std::ostream &o, const ConstantSubscripts &shape) { - if (shape.size() > 1) { + if (GetRank(shape) > 1) { o << ",shape="; char ch{'['}; for (auto dim : shape) { @@ -112,7 +112,7 @@ std::ostream &ActualArgument::AsFortran(std::ostream &o) const { if (isAlternateReturn) { o << '*'; } - if (const auto *expr{GetExpr()}) { + if (const auto *expr{UnwrapExpr()}) { return expr->AsFortran(o); } else { return std::get(u_).AsFortran(o); diff --git a/flang/lib/evaluate/intrinsics.cc b/flang/lib/evaluate/intrinsics.cc index bd242c9..320ff33 100644 --- a/flang/lib/evaluate/intrinsics.cc +++ b/flang/lib/evaluate/intrinsics.cc @@ -301,7 +301,7 @@ static const IntrinsicInterface genericIntrinsicFunction[]{ {"cshift", {{"array", SameType, Rank::array}, {"shift", AnyInt, Rank::dimRemoved}, OptionalDIM}, - SameType, Rank::array}, + SameType, Rank::conformable}, {"dble", {{"a", AnyNumeric, Rank::elementalOrBOZ}}, DoublePrecision}, {"dim", {{"x", SameIntOrReal}, {"y", SameIntOrReal}}, SameIntOrReal}, {"dot_product", @@ -333,12 +333,12 @@ static const IntrinsicInterface genericIntrinsicFunction[]{ {"boundary", SameIntrinsic, Rank::dimRemoved, Optionality::optional}, OptionalDIM}, - SameIntrinsic, Rank::array}, + SameIntrinsic, Rank::conformable}, {"eoshift", {{"array", SameDerivedType, Rank::array}, {"shift", AnyInt, Rank::dimRemoved}, {"boundary", SameDerivedType, Rank::dimRemoved}, OptionalDIM}, - SameDerivedType, Rank::array}, + SameDerivedType, Rank::conformable}, {"erf", {{"x", SameReal}}, SameReal}, {"erfc", {{"x", SameReal}}, SameReal}, {"erfc_scaled", {{"x", SameReal}}, SameReal}, @@ -401,13 +401,13 @@ static const IntrinsicInterface genericIntrinsicFunction[]{ {"is_iostat_eor", {{"i", AnyInt}}, DefaultLogical}, {"kind", {{"x", AnyIntrinsic}}, DefaultInt}, {"lbound", - {{"array", Anything, Rank::anyOrAssumedRank}, SubscriptDefaultKIND}, - KINDInt, Rank::vector}, - {"lbound", {{"array", Anything, Rank::anyOrAssumedRank}, {"dim", {IntType, KindCode::dimArg}, Rank::scalar}, SubscriptDefaultKIND}, KINDInt, Rank::scalar}, + {"lbound", + {{"array", Anything, Rank::anyOrAssumedRank}, SubscriptDefaultKIND}, + KINDInt, Rank::vector}, {"leadz", {{"i", AnyInt}}, DefaultInt}, {"len", {{"string", AnyChar}, SubscriptDefaultKIND}, KINDInt}, {"len_trim", {{"string", AnyChar}, SubscriptDefaultKIND}, KINDInt}, @@ -590,13 +590,13 @@ static const IntrinsicInterface genericIntrinsicFunction[]{ {"transpose", {{"matrix", SameType, Rank::matrix}}, SameType, Rank::matrix}, {"trim", {{"string", SameChar, Rank::scalar}}, SameChar, Rank::scalar}, {"ubound", - {{"array", Anything, Rank::anyOrAssumedRank}, SubscriptDefaultKIND}, - KINDInt, Rank::vector}, - {"ubound", {{"array", Anything, Rank::anyOrAssumedRank}, {"dim", {IntType, KindCode::dimArg}, Rank::scalar}, SubscriptDefaultKIND}, KINDInt, Rank::scalar}, + {"ubound", + {{"array", Anything, Rank::anyOrAssumedRank}, SubscriptDefaultKIND}, + KINDInt, Rank::vector}, {"unpack", {{"vector", SameType, Rank::vector}, {"mask", AnyLogical, Rank::array}, {"field", SameType, Rank::conformable}}, @@ -890,7 +890,7 @@ std::optional IntrinsicInterface::Match( std::optional type{arg->GetType()}; if (!type.has_value()) { CHECK(arg->Rank() == 0); - const Expr *expr{arg->GetExpr()}; + const Expr *expr{arg->UnwrapExpr()}; CHECK(expr != nullptr); if (std::holds_alternative(expr->u)) { if (d.typePattern.kindCode == KindCode::typeless || @@ -1111,7 +1111,7 @@ std::optional IntrinsicInterface::Match( CHECK(kindDummyArg != nullptr); CHECK(result.categorySet == CategorySet{*category}); if (kindArg != nullptr) { - if (auto *expr{kindArg->GetExpr()}) { + if (auto *expr{kindArg->UnwrapExpr()}) { CHECK(expr->Rank() == 0); if (auto code{ToInt64(*expr)}) { if (IsValidKindOfIntrinsicType(*category, *code)) { @@ -1215,7 +1215,7 @@ std::optional IntrinsicInterface::Match( for (std::size_t j{0}; j < dummies; ++j) { const IntrinsicDummyArgument &d{dummy[std::min(j, dummyArgPatterns - 1)]}; if (const auto &arg{rearranged[j]}) { - const Expr *expr{arg->GetExpr()}; + const Expr *expr{arg->UnwrapExpr()}; CHECK(expr != nullptr); std::optional typeAndShape; if (auto type{expr->GetType()}) { @@ -1318,7 +1318,7 @@ SpecificCall IntrinsicProcTable::Implementation::HandleNull( context.messages().Say("Unknown argument '%s' to NULL()"_err_en_US, arguments[0]->keyword->ToString()); } else { - if (Expr * mold{arguments[0]->GetExpr()}) { + if (Expr * mold{arguments[0]->UnwrapExpr()}) { if (IsAllocatableOrPointer(*mold)) { characteristics::DummyArguments args; std::optional fResult; @@ -1423,7 +1423,7 @@ std::optional IntrinsicProcTable::Implementation::Probe( if (call.name == "present") { bool ok{false}; if (const auto &arg{specificCall->arguments[0]}) { - if (const auto *expr{arg->GetExpr()}) { + if (const auto *expr{arg->UnwrapExpr()}) { if (const Symbol * symbol{IsWholeSymbolDataRef(*expr)}) { ok = symbol->attrs().test(semantics::Attr::OPTIONAL); } diff --git a/flang/lib/evaluate/shape.cc b/flang/lib/evaluate/shape.cc index 3df551e..f5bd03f 100644 --- a/flang/lib/evaluate/shape.cc +++ b/flang/lib/evaluate/shape.cc @@ -23,13 +23,43 @@ namespace Fortran::evaluate { +bool IsImpliedShape(const Symbol &symbol) { + if (const auto *details{symbol.detailsIf()}) { + if (symbol.attrs().test(semantics::Attr::PARAMETER) && + details->init().has_value()) { + for (const semantics::ShapeSpec &ss : details->shape()) { + if (!ss.ubound().isDeferred()) { + // ss.isDeferred() can't be used because the lower bounds are + // implicitly set to 1 in the symbol table. + return false; + } + } + return !details->shape().empty(); + } + } + return false; +} + +bool IsExplicitShape(const Symbol &symbol) { + if (const auto *details{symbol.detailsIf()}) { + for (const semantics::ShapeSpec &ss : details->shape()) { + if (!ss.isExplicit()) { + return false; + } + } + return true; // even if scalar + } else { + return false; + } +} + Shape AsShape(const Constant &arrayConstant) { CHECK(arrayConstant.Rank() == 1); Shape result; std::size_t dimensions{arrayConstant.size()}; for (std::size_t j{0}; j < dimensions; ++j) { Scalar extent{arrayConstant.values().at(j)}; - result.emplace_back(MaybeExtent{ExtentExpr{extent}}); + result.emplace_back(MaybeExtentExpr{ExtentExpr{extent}}); } return result; } @@ -37,7 +67,7 @@ Shape AsShape(const Constant &arrayConstant) { std::optional AsShape(FoldingContext &context, ExtentExpr &&arrayExpr) { // Flatten any array expression into an array constructor if possible. arrayExpr = Fold(context, std::move(arrayExpr)); - if (const auto *constArray{GetConstantValue(arrayExpr)}) { + if (const auto *constArray{UnwrapConstantValue(arrayExpr)}) { return AsShape(*constArray); } if (auto *constructor{UnwrapExpr>(arrayExpr)}) { @@ -72,13 +102,22 @@ std::optional> AsConstantShape(const Shape &shape) { if (auto shapeArray{AsExtentArrayExpr(shape)}) { FoldingContext noFoldingContext; auto folded{Fold(noFoldingContext, std::move(*shapeArray))}; - if (auto *p{GetConstantValue(folded)}) { + if (auto *p{UnwrapConstantValue(folded)}) { return std::move(*p); } } return std::nullopt; } +Constant AsConstantShape(const ConstantSubscripts &shape) { + using IntType = Scalar; + std::vector result; + for (auto dim : shape) { + result.emplace_back(dim); + } + return {std::move(result), ConstantSubscripts{GetRank(shape)}}; +} + ConstantSubscripts AsConstantExtents(const Constant &shape) { ConstantSubscripts result; for (const auto &extent : shape.values()) { @@ -119,13 +158,13 @@ ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper, common::Clone(lower), common::Clone(upper), common::Clone(stride)); } -MaybeExtent CountTrips( - MaybeExtent &&lower, MaybeExtent &&upper, MaybeExtent &&stride) { +MaybeExtentExpr CountTrips(MaybeExtentExpr &&lower, MaybeExtentExpr &&upper, + MaybeExtentExpr &&stride) { return common::MapOptional( ComputeTripCount, std::move(lower), std::move(upper), std::move(stride)); } -MaybeExtent GetSize(Shape &&shape) { +MaybeExtentExpr GetSize(Shape &&shape) { ExtentExpr extent{1}; for (auto &&dim : std::move(shape)) { if (dim.has_value()) { @@ -146,14 +185,14 @@ bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) { return Visitor{0}.Traverse(expr); } -MaybeExtent GetShapeHelper::GetLowerBound( - const Symbol &symbol, const Component *component, int dimension) { +MaybeExtentExpr GetLowerBound(FoldingContext &context, const Symbol &symbol, + int dimension, const Component *component) { if (const auto *details{symbol.detailsIf()}) { int j{0}; for (const auto &shapeSpec : details->shape()) { if (j++ == dimension) { if (const auto &bound{shapeSpec.lbound().GetExplicit()}) { - return *bound; + return Fold(context, common::Clone(*bound)); } else if (component != nullptr) { return ExtentExpr{DescriptorInquiry{ *component, DescriptorInquiry::Field::LowerBound, dimension}}; @@ -167,41 +206,27 @@ MaybeExtent GetShapeHelper::GetLowerBound( return std::nullopt; } -static bool IsImpliedShape(const Symbol &symbol) { +MaybeExtentExpr GetExtent(FoldingContext &context, const Symbol &symbol, + int dimension, const Component *component) { + CHECK(dimension >= 0); if (const auto *details{symbol.detailsIf()}) { - if (symbol.attrs().test(semantics::Attr::PARAMETER) && - details->init().has_value()) { - for (const semantics::ShapeSpec &ss : details->shape()) { - if (ss.isExplicit()) { - return false; - } - } - return true; + if (IsImpliedShape(symbol)) { + Shape shape{GetShape(context, symbol).value()}; + return std::move(shape.at(dimension)); } - } - return false; -} - -MaybeExtent GetShapeHelper::GetExtent( - const Symbol &symbol, const Component *component, int dimension) { - if (const auto *details{symbol.detailsIf()}) { int j{0}; for (const auto &shapeSpec : details->shape()) { if (j++ == dimension) { if (shapeSpec.isExplicit()) { if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) { - FoldingContext noFoldingContext; if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) { - return Fold(noFoldingContext, + return Fold(context, common::Clone(ubound.value()) - common::Clone(lbound.value()) + ExtentExpr{1}); } else { - return Fold(noFoldingContext, common::Clone(ubound.value())); + return Fold(context, common::Clone(ubound.value())); } } - } else if (IsImpliedShape(symbol)) { - Shape shape{GetShape(symbol).value()}; - return std::move(shape.at(dimension)); } else if (details->IsAssumedSize() && j == symbol.Rank()) { return std::nullopt; } else if (component != nullptr) { @@ -217,26 +242,26 @@ MaybeExtent GetShapeHelper::GetExtent( return std::nullopt; } -MaybeExtent GetShapeHelper::GetExtent(const Subscript &subscript, - const Symbol &symbol, const Component *component, int dimension) { +MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript, + const Symbol &symbol, int dimension, const Component *component) { return std::visit( common::visitors{ - [&](const Triplet &triplet) -> MaybeExtent { - MaybeExtent upper{triplet.upper()}; + [&](const Triplet &triplet) -> MaybeExtentExpr { + MaybeExtentExpr upper{triplet.upper()}; if (!upper.has_value()) { - upper = GetExtent(symbol, component, dimension); + upper = GetExtent(context, symbol, dimension, component); } - MaybeExtent lower{triplet.lower()}; + MaybeExtentExpr lower{triplet.lower()}; if (!lower.has_value()) { - lower = GetLowerBound(symbol, component, dimension); + lower = GetLowerBound(context, symbol, dimension, component); } return CountTrips(std::move(lower), std::move(upper), - MaybeExtent{triplet.stride()}); + MaybeExtentExpr{triplet.stride()}); }, - [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtent { - if (auto shape{GetShape(subs.value())}) { - if (shape->size() > 0) { - CHECK(shape->size() == 1); // vector-valued subscript + [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr { + if (auto shape{GetShape(context, subs.value())}) { + if (GetRank(*shape) > 0) { + CHECK(GetRank(*shape) == 1); // vector-valued subscript return std::move(shape->at(0)); } } @@ -246,6 +271,16 @@ MaybeExtent GetShapeHelper::GetExtent(const Subscript &subscript, subscript.u); } +MaybeExtentExpr GetUpperBound(FoldingContext &context, MaybeExtentExpr &&lower, + MaybeExtentExpr &&extent) { + if (lower.has_value() && extent.has_value()) { + return Fold( + context, std::move(*extent) - std::move(*lower) + ExtentExpr{1}); + } else { + return std::nullopt; + } +} + std::optional GetShapeHelper::GetShape( const Symbol &symbol, const Component *component) { if (const auto *details{symbol.detailsIf()}) { @@ -255,7 +290,7 @@ std::optional GetShapeHelper::GetShape( Shape result; int n{static_cast(details->shape().size())}; for (int dimension{0}; dimension < n; ++dimension) { - result.emplace_back(GetExtent(symbol, component, dimension++)); + result.emplace_back(GetExtent(context_, symbol, dimension, component)); } return result; } @@ -300,7 +335,7 @@ std::optional GetShapeHelper::GetShape(const ArrayRef &arrayRef) { int dimension{0}; for (const Subscript &ss : arrayRef.subscript()) { if (ss.Rank() > 0) { - shape.emplace_back(GetExtent(ss, symbol, component, dimension)); + shape.emplace_back(GetExtent(context_, ss, symbol, dimension, component)); } ++dimension; } @@ -319,7 +354,7 @@ std::optional GetShapeHelper::GetShape(const CoarrayRef &coarrayRef) { int dimension{0}; for (const Subscript &ss : coarrayRef.subscript()) { if (ss.Rank() > 0) { - shape.emplace_back(GetExtent(ss, symbol, component, dimension)); + shape.emplace_back(GetExtent(context_, ss, symbol, dimension, component)); } ++dimension; } @@ -347,7 +382,7 @@ std::optional GetShapeHelper::GetShape(const ComplexPart &part) { } std::optional GetShapeHelper::GetShape(const ActualArgument &arg) { - if (const auto *expr{arg.GetExpr()}) { + if (const auto *expr{arg.UnwrapExpr()}) { return GetShape(*expr); } else { const Symbol *aType{arg.GetAssumedTypeDummy()}; @@ -379,13 +414,13 @@ std::optional GetShapeHelper::GetShape(const ProcedureRef &call) { std::get_if(&call.proc().u)}) { if (intrinsic->name == "shape" || intrinsic->name == "lbound" || intrinsic->name == "ubound") { - const auto *expr{call.arguments().front().value().GetExpr()}; + const auto *expr{call.arguments().front().value().UnwrapExpr()}; CHECK(expr != nullptr); - return Shape{MaybeExtent{ExtentExpr{expr->Rank()}}}; + return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}}; } else if (intrinsic->name == "reshape") { if (call.arguments().size() >= 2 && call.arguments().at(1).has_value()) { // SHAPE(RESHAPE(array,shape)) -> shape - const auto *shapeExpr{call.arguments().at(1).value().GetExpr()}; + const auto *shapeExpr{call.arguments().at(1).value().UnwrapExpr()}; CHECK(shapeExpr != nullptr); Expr shape{std::get>(shapeExpr->u)}; return AsShape(context_, ConvertToType(std::move(shape))); @@ -425,8 +460,8 @@ std::optional GetShapeHelper::GetShape(const NullPointer &) { bool CheckConformance(parser::ContextualMessages &messages, const Shape &left, const Shape &right, const char *leftDesc, const char *rightDesc) { if (!left.empty() && !right.empty()) { - int n{static_cast(left.size())}; - int rn{static_cast(right.size())}; + int n{GetRank(left)}; + int rn{GetRank(right)}; if (n != rn) { messages.Say("Rank of %s is %d, but %s has rank %d"_err_en_US, leftDesc, n, rightDesc, rn); diff --git a/flang/lib/evaluate/shape.h b/flang/lib/evaluate/shape.h index a7fa80a..9a49f82 100644 --- a/flang/lib/evaluate/shape.h +++ b/flang/lib/evaluate/shape.h @@ -35,27 +35,46 @@ class FoldingContext; using ExtentType = SubscriptInteger; using ExtentExpr = Expr; -using MaybeExtent = std::optional; -using Shape = std::vector; +using MaybeExtentExpr = std::optional; +using Shape = std::vector; + +bool IsImpliedShape(const Symbol &); +bool IsExplicitShape(const Symbol &); // Conversions between various representations of shapes. -Shape AsShape(const Constant &arrayConstant); -std::optional AsShape(FoldingContext &, ExtentExpr &&arrayExpr); +Shape AsShape(const Constant &); +std::optional AsShape(FoldingContext &, ExtentExpr &&); + std::optional AsExtentArrayExpr(const Shape &); + std::optional> AsConstantShape(const Shape &); +Constant AsConstantShape(const ConstantSubscripts &); + ConstantSubscripts AsConstantExtents(const Constant &); std::optional AsConstantExtents(const Shape &); +inline int GetRank(const Shape &s) { return static_cast(s.size()); } + +// The dimension here is zero-based, unlike DIM= arguments to many intrinsics. +MaybeExtentExpr GetLowerBound(FoldingContext &, const Symbol &, int dimension, + const Component * = nullptr); +MaybeExtentExpr GetExtent(FoldingContext &, const Symbol &, int dimension, + const Component * = nullptr); +MaybeExtentExpr GetExtent(FoldingContext &, const Subscript &, const Symbol &, + int dimension, const Component * = nullptr); +MaybeExtentExpr GetUpperBound( + FoldingContext &, MaybeExtentExpr &&lower, MaybeExtentExpr &&extent); + // Compute an element count for a triplet or trip count for a DO. ExtentExpr CountTrips( ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride); ExtentExpr CountTrips( const ExtentExpr &lower, const ExtentExpr &upper, const ExtentExpr &stride); -MaybeExtent CountTrips( - MaybeExtent &&lower, MaybeExtent &&upper, MaybeExtent &&stride); +MaybeExtentExpr CountTrips( + MaybeExtentExpr &&lower, MaybeExtentExpr &&upper, MaybeExtentExpr &&stride); // Computes SIZE() == PRODUCT(shape) -MaybeExtent GetSize(Shape &&); +MaybeExtentExpr GetSize(Shape &&); // Utility predicate: does an expression reference any implied DO index? bool ContainsAnyImpliedDoIndex(const ExtentExpr &); @@ -127,7 +146,7 @@ public: template std::optional GetShape(const ArrayConstructor &aconst) { - return Shape{GetExtent(aconst)}; + return Shape{GetArrayConstructorExtent(aconst)}; } template @@ -151,10 +170,11 @@ public: private: template - MaybeExtent GetExtent(const ArrayConstructorValue &value) { + MaybeExtentExpr GetArrayConstructorValueExtent( + const ArrayConstructorValue &value) { return std::visit( common::visitors{ - [&](const Expr &x) -> MaybeExtent { + [&](const Expr &x) -> MaybeExtentExpr { if (std::optional xShape{GetShape(x)}) { // Array values in array constructors get linearized. return GetSize(std::move(*xShape)); @@ -162,13 +182,13 @@ private: return std::nullopt; } }, - [&](const ImpliedDo &ido) -> MaybeExtent { + [&](const ImpliedDo &ido) -> MaybeExtentExpr { // Don't be heroic and try to figure out triangular implied DO // nests. if (!ContainsAnyImpliedDoIndex(ido.lower()) && !ContainsAnyImpliedDoIndex(ido.upper()) && !ContainsAnyImpliedDoIndex(ido.stride())) { - if (auto nValues{GetExtent(ido.values())}) { + if (auto nValues{GetArrayConstructorExtent(ido.values())}) { return std::move(*nValues) * CountTrips(ido.lower(), ido.upper(), ido.stride()); } @@ -180,10 +200,11 @@ private: } template - MaybeExtent GetExtent(const ArrayConstructorValues &values) { + MaybeExtentExpr GetArrayConstructorExtent( + const ArrayConstructorValues &values) { ExtentExpr result{0}; for (const auto &value : values) { - if (MaybeExtent n{GetExtent(value)}) { + if (MaybeExtentExpr n{GetArrayConstructorValueExtent(value)}) { result = std::move(result) + std::move(*n); } else { return std::nullopt; @@ -192,12 +213,6 @@ private: return result; } - // The dimension here is zero-based, unlike DIM= intrinsic arguments. - MaybeExtent GetLowerBound(const Symbol &, const Component *, int dimension); - MaybeExtent GetExtent(const Symbol &, const Component *, int dimension); - MaybeExtent GetExtent( - const Subscript &, const Symbol &, const Component *, int dimension); - FoldingContext &context_; }; diff --git a/flang/lib/evaluate/tools.cc b/flang/lib/evaluate/tools.cc index f88f306..bdedebc 100644 --- a/flang/lib/evaluate/tools.cc +++ b/flang/lib/evaluate/tools.cc @@ -614,7 +614,7 @@ bool IsAssumedRank(const semantics::Symbol &symbol) { } bool IsAssumedRank(const ActualArgument &arg) { - if (const auto *expr{arg.GetExpr()}) { + if (const auto *expr{arg.UnwrapExpr()}) { return IsAssumedRank(*expr); } else { const semantics::Symbol *assumedTypeDummy{arg.GetAssumedTypeDummy()}; diff --git a/flang/lib/evaluate/tools.h b/flang/lib/evaluate/tools.h index cf10aa2..9fcace3 100644 --- a/flang/lib/evaluate/tools.h +++ b/flang/lib/evaluate/tools.h @@ -144,22 +144,26 @@ template constexpr bool IsNumericCategoryExpr() { } // Specializing extractor. If an Expr wraps some type of object, perhaps -// in several layers, return a pointer to it; otherwise null. +// in several layers, return a pointer to it; otherwise null. Also works +// with ActualArgument. template auto UnwrapExpr(B &x) -> common::Constify * { using Ty = std::decay_t; if constexpr (std::is_same_v) { return &x; - } else if constexpr (common::HasMember) { - return nullptr; - } else if constexpr (std::is_same_v>>) { - return common::Unwrap(x.u); - } else if constexpr (std::is_same_v> || - std::is_same_v::category>>>) { + } else if constexpr (std::is_same_v) { + if (auto *expr{x.UnwrapExpr()}) { + return UnwrapExpr(*expr); + } + } else if constexpr (std::is_same_v>) { return std::visit([](auto &x) { return UnwrapExpr(x); }, x.u); - } else { - return nullptr; + } else if constexpr (!common::HasMember) { + if constexpr (std::is_same_v>> || + std::is_same_v::category>>>) { + return std::visit([](auto &x) { return UnwrapExpr(x); }, x.u); + } } + return nullptr; } template diff --git a/flang/lib/evaluate/type.cc b/flang/lib/evaluate/type.cc index 4578385..db6c6f9 100644 --- a/flang/lib/evaluate/type.cc +++ b/flang/lib/evaluate/type.cc @@ -16,6 +16,7 @@ #include "expression.h" #include "fold.h" #include "../common/idioms.h" +#include "../common/template.h" #include "../semantics/scope.h" #include "../semantics/symbol.h" #include "../semantics/tools.h" @@ -219,4 +220,79 @@ bool SomeKind::operator==( const SomeKind &that) const { return PointeeComparison(derivedTypeSpec_, that.derivedTypeSpec_); } + +static constexpr double LogBaseTenOfTwo{0.301029995664}; + +class SelectedIntKindVisitor { +public: + explicit SelectedIntKindVisitor(std::int64_t p) : precision_{p} {} + using Result = std::optional; + using Types = IntegerTypes; + template Result Test() const { + if ((Scalar::bits - 1) * LogBaseTenOfTwo > precision_) { + return T::kind; + } else { + return std::nullopt; + } + } + +private: + std::int64_t precision_; +}; + +int SelectedIntKind(std::int64_t precision) { + if (auto kind{common::SearchTypes(SelectedIntKindVisitor{precision})}) { + return *kind; + } else { + return -1; + } +} + +class SelectedRealKindVisitor { +public: + explicit SelectedRealKindVisitor(std::int64_t p, std::int64_t r) + : precision_{p}, range_{r} {} + using Result = std::optional; + using Types = RealTypes; + template Result Test() const { + if ((Scalar::precision - 1) * LogBaseTenOfTwo > precision_ && + (Scalar::exponentBias - 1) * LogBaseTenOfTwo > range_) { + return {T::kind}; + } else { + return std::nullopt; + } + } + +private: + std::int64_t precision_, range_; +}; + +int SelectedRealKind( + std::int64_t precision, std::int64_t range, std::int64_t radix) { + if (radix != 2) { + return -5; + } + if (auto kind{ + common::SearchTypes(SelectedRealKindVisitor{precision, range})}) { + return *kind; + } + // No kind has both sufficient precision and sufficient range. + // The negative return value encodes whether any kinds exist that + // could satisfy either constraint independently. + bool pOK{common::SearchTypes(SelectedRealKindVisitor{precision, 0})}; + bool rOK{common::SearchTypes(SelectedRealKindVisitor{0, range})}; + if (pOK) { + if (rOK) { + return -4; + } else { + return -2; + } + } else { + if (rOK) { + return -1; + } else { + return -3; + } + } +} } diff --git a/flang/lib/evaluate/type.h b/flang/lib/evaluate/type.h index c46a7c8..77bac30 100644 --- a/flang/lib/evaluate/type.h +++ b/flang/lib/evaluate/type.h @@ -60,6 +60,7 @@ using LargestReal = Type; // A predicate that is true when a kind value is a kind that could possibly // be supported for an intrinsic type category on some target instruction // set architecture. +// TODO: specialize for the actual target architecture static constexpr bool IsValidKindOfIntrinsicType( TypeCategory category, std::int64_t kind) { switch (category) { @@ -410,6 +411,10 @@ template struct TypeOfHelper { template using TypeOf = typename TypeOfHelper::type; +int SelectedIntKind(std::int64_t precision = 0); +int SelectedRealKind( + std::int64_t precision = 0, std::int64_t range = 0, std::int64_t radix = 2); + // For generating "[extern] template class", &c. boilerplate #define EXPAND_FOR_EACH_INTEGER_KIND(M, P, S) \ M(P, S, 1) M(P, S, 2) M(P, S, 4) M(P, S, 8) M(P, S, 16) diff --git a/flang/lib/semantics/expression.cc b/flang/lib/semantics/expression.cc index c5dd39a..d90e455 100644 --- a/flang/lib/semantics/expression.cc +++ b/flang/lib/semantics/expression.cc @@ -30,7 +30,7 @@ #include #include -// #define DUMP_ON_FAILURE 1 +#define DUMP_ON_FAILURE 1 // pmk // #define CRASH_ON_FAILURE #if DUMP_ON_FAILURE #include "../parser/dump-parse-tree.h" @@ -170,6 +170,7 @@ MaybeExpr ExpressionAnalyzer::CompleteSubscripts(ArrayRef &&ref) { if (subscripts != symbolRank) { Say("Reference to rank-%d object '%s' has %d subscripts"_err_en_US, symbolRank, symbol.name(), subscripts); + return std::nullopt; } else if (subscripts == 0) { // nothing to check } else if (Component * component{std::get_if(&ref.base())}) { @@ -183,6 +184,7 @@ MaybeExpr ExpressionAnalyzer::CompleteSubscripts(ArrayRef &&ref) { Say("Subscripts of component '%s' of rank-%d derived type " "array have rank %d but must all be scalar"_err_en_US, symbol.name(), baseRank, subscriptRank); + return std::nullopt; } } } else if (const auto *details{ @@ -193,6 +195,7 @@ MaybeExpr ExpressionAnalyzer::CompleteSubscripts(ArrayRef &&ref) { Say("Assumed-size array '%s' must have explicit final " "subscript upper bound value"_err_en_US, symbol.name()); + return std::nullopt; } } } @@ -612,8 +615,6 @@ MaybeExpr ExpressionAnalyzer::Analyze(const parser::Name &n) { // A bare reference to a derived type parameter (within a parameterized // derived type definition) return AsMaybeExpr(MakeBareTypeParamInquiry(&ultimate)); - } else if (MaybeExpr result{Designate(DataRef{ultimate})}) { - return result; } else { return Designate(DataRef{*n.symbol}); } diff --git a/flang/lib/semantics/mod-file.cc b/flang/lib/semantics/mod-file.cc index 50980e6..dbe9d9d 100644 --- a/flang/lib/semantics/mod-file.cc +++ b/flang/lib/semantics/mod-file.cc @@ -30,14 +30,12 @@ namespace Fortran::semantics { using namespace parser::literals; -// The extension used for module files. -static constexpr auto extension{".mod"}; // The initial characters of a file that identify it as a .mod file. static constexpr auto magic{"!mod$ v1 sum:"}; static const SourceName *GetSubmoduleParent(const parser::Program &); -static std::string ModFilePath( - const std::string &, const SourceName &, const std::string &); +static std::string ModFilePath(const std::string &dir, const SourceName &, + const std::string &ancestor, const std::string &suffix); static std::vector CollectSymbols(const Scope &); static void PutEntity(std::ostream &, const Symbol &); static void PutObjectEntity(std::ostream &, const Symbol &); @@ -126,8 +124,8 @@ void ModFileWriter::WriteOne(const Scope &scope) { void ModFileWriter::Write(const Symbol &symbol) { auto *ancestor{symbol.get().ancestor()}; auto ancestorName{ancestor ? ancestor->name().ToString() : ""s}; - auto path{ - ModFilePath(context_.moduleDirectory(), symbol.name(), ancestorName)}; + auto path{ModFilePath(context_.moduleDirectory(), symbol.name(), ancestorName, + context_.moduleFileSuffix())}; PutSymbols(*symbol.scope()); if (!WriteFile(path, GetAsString(symbol))) { context_.Say(symbol.name(), "Error writing %s: %s"_err_en_US, path, @@ -723,7 +721,8 @@ std::optional ModFileReader::FindModFile( const SourceName &name, const std::string &ancestor) { parser::Messages attachments; for (auto &dir : context_.searchDirectories()) { - std::string path{ModFilePath(dir, name, ancestor)}; + std::string path{ + ModFilePath(dir, name, ancestor, context_.moduleFileSuffix())}; std::ifstream ifstream{path}; if (!ifstream.good()) { attachments.Say(name, "%s: %s"_en_US, path, std::strerror(errno)); @@ -764,7 +763,7 @@ static const SourceName *GetSubmoduleParent(const parser::Program &program) { // Construct the path to a module file. ancestorName not empty means submodule. static std::string ModFilePath(const std::string &dir, const SourceName &name, - const std::string &ancestorName) { + const std::string &ancestorName, const std::string &suffix) { std::stringstream path; if (dir != "."s) { path << dir << '/'; @@ -772,7 +771,7 @@ static std::string ModFilePath(const std::string &dir, const SourceName &name, if (!ancestorName.empty()) { PutLower(path, ancestorName) << '-'; } - PutLower(path, name.ToString()) << extension; + PutLower(path, name.ToString()) << suffix; return path.str(); } diff --git a/flang/lib/semantics/resolve-names.cc b/flang/lib/semantics/resolve-names.cc index 68ba1cd..fbe2881 100644 --- a/flang/lib/semantics/resolve-names.cc +++ b/flang/lib/semantics/resolve-names.cc @@ -1565,9 +1565,10 @@ void ScopeHandler::PushScope(Scope &scope) { } } void ScopeHandler::PopScope() { + // Entities that are not yet classified as objects or procedures are now + // assumed to be objects. for (auto &pair : currScope()) { - auto &symbol{*pair.second}; - ConvertToObjectEntity(symbol); // if not a proc by now, it is an object + ConvertToObjectEntity(*pair.second); } SetScope(currScope_->parent()); } @@ -4252,10 +4253,13 @@ const parser::Name *DeclarationVisitor::ResolveVariable( // If implicit types are allowed, ensure name is in the symbol table. // Otherwise, report an error if it hasn't been declared. const parser::Name *DeclarationVisitor::ResolveName(const parser::Name &name) { - if (FindSymbol(name)) { + if (Symbol * symbol{FindSymbol(name)}) { if (CheckUseError(name)) { return nullptr; // reported an error } + if (symbol->IsDummy()) { + ApplyImplicitRules(*symbol); + } return &name; } if (isImplicitNoneType()) { @@ -4643,7 +4647,20 @@ void ResolveNamesVisitor::ResolveSpecificationParts(ProgramTree &node) { for (auto &child : node.children()) { ResolveSpecificationParts(child); } - PopScope(); + // Subtlety: PopScope() is not called here because we want to defer + // conversions of uncategorized entities into objects until after + // we have traversed the executable part of the subprogram. + // Function results, however, are converted now so that they can + // be used in executable parts. + if (Symbol * symbol{currScope().symbol()}) { + if (auto *details{symbol->detailsIf()}) { + if (details->isFunction()) { + Symbol &result{const_cast(details->result())}; + ConvertToObjectEntity(result); + } + } + } + SetScope(currScope().parent()); } // Add SubprogramNameDetails symbols for contained subprograms @@ -4686,6 +4703,7 @@ void ResolveNamesVisitor::ResolveExecutionParts(const ProgramTree &node) { if (const auto *exec{node.exec()}) { Walk(*exec); } + PopScope(); // converts unclassified entities into objects for (const auto &child : node.children()) { ResolveExecutionParts(child); } diff --git a/flang/lib/semantics/semantics.h b/flang/lib/semantics/semantics.h index b4ddf47..fe2232e 100644 --- a/flang/lib/semantics/semantics.h +++ b/flang/lib/semantics/semantics.h @@ -53,6 +53,7 @@ public: return searchDirectories_; } const std::string &moduleDirectory() const { return moduleDirectory_; } + const std::string &moduleFileSuffix() const { return moduleFileSuffix_; } bool warnOnNonstandardUsage() const { return warnOnNonstandardUsage_; } bool warningsAreErrors() const { return warningsAreErrors_; } const evaluate::IntrinsicProcTable &intrinsics() const { return intrinsics_; } @@ -72,6 +73,10 @@ public: moduleDirectory_ = x; return *this; } + SemanticsContext &set_moduleFileSuffix(const std::string &x) { + moduleFileSuffix_ = x; + return *this; + } SemanticsContext &set_warnOnNonstandardUsage(bool x) { warnOnNonstandardUsage_ = x; return *this; @@ -113,6 +118,7 @@ private: const parser::CharBlock *location_{nullptr}; std::vector searchDirectories_; std::string moduleDirectory_{"."s}; + std::string moduleFileSuffix_{".mod"}; bool warnOnNonstandardUsage_{false}; bool warningsAreErrors_{false}; const evaluate::IntrinsicProcTable intrinsics_; diff --git a/flang/lib/semantics/symbol.h b/flang/lib/semantics/symbol.h index f9d3206..bd5f61e 100644 --- a/flang/lib/semantics/symbol.h +++ b/flang/lib/semantics/symbol.h @@ -148,6 +148,8 @@ public: MaybeExpr &init() { return init_; } const MaybeExpr &init() const { return init_; } void set_init(MaybeExpr &&expr) { init_ = std::move(expr); } + bool initWasValidated() const { return initWasValidated_; } + void set_initWasValidated(bool yes = true) { initWasValidated_ = yes; } ArraySpec &shape() { return shape_; } const ArraySpec &shape() const { return shape_; } ArraySpec &coshape() { return coshape_; } @@ -179,6 +181,7 @@ public: private: MaybeExpr init_; + bool initWasValidated_{false}; ArraySpec shape_; ArraySpec coshape_; const Symbol *commonBlock_{nullptr}; // common block this object is in diff --git a/flang/lib/semantics/tools.cc b/flang/lib/semantics/tools.cc index ef6d5d7..7f38da2 100644 --- a/flang/lib/semantics/tools.cc +++ b/flang/lib/semantics/tools.cc @@ -175,6 +175,7 @@ bool IsProcedure(const Symbol &symbol) { [](const SubprogramNameDetails &) { return true; }, [](const ProcEntityDetails &) { return true; }, [](const GenericDetails &) { return true; }, + [](const ProcBindingDetails &) { return true; }, [](const UseDetails &x) { return IsProcedure(x.symbol()); }, [](const auto &) { return false; }, }, diff --git a/flang/test/semantics/CMakeLists.txt b/flang/test/semantics/CMakeLists.txt index 108efc6..01bf061 100644 --- a/flang/test/semantics/CMakeLists.txt +++ b/flang/test/semantics/CMakeLists.txt @@ -173,6 +173,7 @@ set(MODFILE_TESTS modfile23.f90 modfile24.f90 modfile25.f90 + modfile26.f90 ) set(LABEL_TESTS diff --git a/flang/test/semantics/modfile26.f90 b/flang/test/semantics/modfile26.f90 new file mode 100644 index 0000000..9bf5196 --- /dev/null +++ b/flang/test/semantics/modfile26.f90 @@ -0,0 +1,62 @@ +! Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. +! +! Licensed under the Apache License, Version 2.0 (the "License"); +! you may not use this file except in compliance with the License. +! You may obtain a copy of the License at +! +! http://www.apache.org/licenses/LICENSE-2.0 +! +! Unless required by applicable law or agreed to in writing, software +! distributed under the License is distributed on an "AS IS" BASIS, +! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +! See the License for the specific language governing permissions and +! limitations under the License. + +! SELECTED_INT_KIND and SELECTED_REAL_KIND + +module m1 + ! INTEGER(KIND=1) handles 0 <= P < 3 + ! INTEGER(KIND=2) handles 3 <= P < 5 + ! INTEGER(KIND=4) handles 5 <= P < 10 + ! INTEGER(KIND=8) handles 10 <= P < 19 + ! INTEGER(KIND=16) handles 19 <= P < 38 + integer, parameter :: intpvals(:) = [0, 2, 3, 4, 5, 9, 10, 18, 19, 38, 39] + integer, parameter :: intpkinds(:) = & + [(selected_int_kind(intpvals(j)),j=1,size(intpvals))] + logical, parameter :: ipcheck = & + all([1, 1, 2, 2, 4, 4, 8, 8, 16, 16, -1] == intpkinds) + ! REAL(KIND=2) handles 0 <= P < 4 (if available) + ! REAL(KIND=4) handles 4 <= P < 7 + ! REAL(KIND=8) handles 7 <= P < 16 + ! REAL(KIND=10) handles 16 <= P < 19 (if available; ifort is KIND=16) + ! REAL(KIND=16) handles 19 <= P < 34 (32 with Power double/double) + integer, parameter :: realpvals(:) = [0, 3, 4, 6, 7, 15, 16, 18, 19, 33, 34] + integer, parameter :: realpkinds(:) = & + [(selected_real_kind(realpvals(j),0),j=1,size(realpvals))] + logical, parameter :: realpcheck = & + all([2, 2, 4, 4, 8, 8, 10, 10, 16, 16, -1] == realpkinds) + ! REAL(KIND=2) handles 0 <= R < 5 (if available) + ! REAL(KIND=3) handles 5 <= R < 38 (if available, same range as KIND=4) + ! REAL(KIND=4) handles 5 <= R < 38 (if no KIND=3) + ! REAL(KIND=8) handles 38 <= R < 308 + ! REAL(KIND=10) handles 308 <= R < 4932 (if available; ifort is KIND=16) + ! REAL(KIND=16) handles 4932 <= R < 9864 (except Power double/double) + integer, parameter :: realrvals(:) = & + [0, 4, 5, 37, 38, 307, 308, 4931, 4932, 9863, 9864] + integer, parameter :: realrkinds(:) = & + [(selected_real_kind(0,realrvals(j)),j=1,size(realrvals))] + logical, parameter :: realrcheck = & + all([2, 2, 3, 3, 8, 8, 10, 10, 16, 16, -2] == realrkinds) +end module m1 +!Expect: m1.mod +!module m1 +!integer(4),parameter::intpvals(1_8:)=[Integer(4)::0_4,2_4,3_4,4_4,5_4,9_4,10_4,18_4,19_4,38_4,39_4] +!integer(4),parameter::intpkinds(1_8:)=[Integer(4)::1_4,1_4,2_4,2_4,4_4,4_4,8_4,8_4,16_4,16_4,-1_4] +!logical(4),parameter::ipcheck=.true._4 +!integer(4),parameter::realpvals(1_8:)=[Integer(4)::0_4,3_4,4_4,6_4,7_4,15_4,16_4,18_4,19_4,33_4,34_4] +!integer(4),parameter::realpkinds(1_8:)=[Integer(4)::2_4,2_4,4_4,4_4,8_4,8_4,10_4,10_4,16_4,16_4,-1_4] +!logical(4),parameter::realpcheck=.true._4 +!integer(4),parameter::realrvals(1_8:)=[Integer(4)::0_4,4_4,5_4,37_4,38_4,307_4,308_4,4931_4,4932_4,9863_4,9864_4] +!integer(4),parameter::realrkinds(1_8:)=[Integer(4)::2_4,2_4,3_4,3_4,8_4,8_4,10_4,10_4,16_4,16_4,-2_4] +!logical(4),parameter::realrcheck=.true._4 +!end diff --git a/flang/tools/f18/f18.cc b/flang/tools/f18/f18.cc index 20713f7..d419bc1 100644 --- a/flang/tools/f18/f18.cc +++ b/flang/tools/f18/f18.cc @@ -86,6 +86,7 @@ struct DriverOptions { std::string outputPath; // -o path std::vector searchDirectories{"."s}; // -I dir std::string moduleDirectory{"."s}; // -module dir + std::string moduleFileSuffix{".mod"}; // -moduleSuffix suff bool forcedForm{false}; // -Mfixed or -Mfree appeared bool warnOnNonstandardUsage{false}; // -Mstandard bool warningsAreErrors{false}; // -Werror @@ -452,6 +453,12 @@ int main(int argc, char *const argv[]) { defaultKinds.set_defaultIntegerKind(8); } else if (arg == "-fno-large-arrays") { defaultKinds.set_subscriptIntegerKind(4); + } else if (arg == "-module") { + driver.moduleDirectory = args.front(); + args.pop_front(); + } else if (arg == "-module-suffix") { + driver.moduleFileSuffix = args.front(); + args.pop_front(); } else if (arg == "-help" || arg == "--help" || arg == "-?") { std::cerr << "f18 options:\n" @@ -465,6 +472,7 @@ int main(int argc, char *const argv[]) { << " -Werror treat warnings as errors\n" << " -ed enable fixed form D lines\n" << " -E prescan & preprocess only\n" + << " -module dir module output directory (default .)\n" << " -fparse-only parse only, no output except messages\n" << " -funparse parse & reformat only, no code " "generation\n" @@ -495,10 +503,6 @@ int main(int argc, char *const argv[]) { args.pop_front(); } else if (arg.substr(0, 2) == "-I") { driver.searchDirectories.push_back(arg.substr(2)); - } else if (arg == "-module") { - driver.moduleDirectory = args.front(); - driver.pgf90Args.push_back(driver.moduleDirectory); - args.pop_front(); } else if (arg == "-Mx,125,4") { // PGI "all Kanji" mode options.encoding = Fortran::parser::Encoding::EUC_JP; } @@ -513,14 +517,14 @@ int main(int argc, char *const argv[]) { Fortran::parser::LanguageFeature::BackslashEscapes)) { driver.pgf90Args.push_back("-Mbackslash"); } - if (options.features.IsEnabled( - Fortran::parser::LanguageFeature::OpenMP)) { + if (options.features.IsEnabled(Fortran::parser::LanguageFeature::OpenMP)) { driver.pgf90Args.push_back("-mp"); } Fortran::semantics::SemanticsContext semanticsContext{ defaultKinds, options.features}; semanticsContext.set_moduleDirectory(driver.moduleDirectory) + .set_moduleFileSuffix(driver.moduleFileSuffix) .set_searchDirectories(driver.searchDirectories) .set_warnOnNonstandardUsage(driver.warnOnNonstandardUsage) .set_warningsAreErrors(driver.warningsAreErrors);