isAlternateReturn == that.isAlternateReturn && value() == that.value();
}
-std::optional<int> ActualArgument::VectorSize() const {
- if (Rank() != 1) {
- return std::nullopt;
- }
- // TODO: get shape vector of value, return its length
- return std::nullopt;
-}
-
bool SpecificIntrinsic::operator==(const SpecificIntrinsic &that) const {
return name == that.name && type == that.type && rank == that.rank &&
attrs == that.attrs;
int Rank() const;
bool operator==(const ActualArgument &) const;
std::ostream &AsFortran(std::ostream &) const;
- std::optional<int> VectorSize() const;
std::optional<parser::CharBlock> keyword;
bool isAlternateReturn{false}; // when true, "value" is a label number
class FoldingContext {
public:
+ FoldingContext() = default;
explicit FoldingContext(const parser::ContextualMessages &m,
Rounding round = defaultRounding, bool flush = false)
: messages_{m}, rounding_{round}, flushSubnormalsToZero_{flush} {}
#include "host.h"
#include "int-power.h"
#include "intrinsics-library-templates.h"
+#include "shape.h"
#include "tools.h"
#include "traversal.h"
#include "type.h"
}
return FoldElementalIntrinsic<T, T, T, T>(
context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
+ } else if (name == "rank") {
+ // TODO pmk: get rank
+ } else if (name == "shape") {
+ // TODO pmk: call GetShape on argument, massage result
}
// TODO:
// ceiling, count, cshift, dot_product, eoshift,
// findloc, floor, iachar, iall, iany, iparity, ibits, ichar, image_status,
// index, ishftc, lbound, len_trim, matmul, max, maxloc, maxval, merge, min,
// minloc, minval, mod, modulo, nint, not, pack, product, reduce, reshape,
- // scan, selected_char_kind, selected_int_kind, selected_real_kind, shape,
+ // scan, selected_char_kind, selected_int_kind, selected_real_kind,
// sign, size, spread, sum, transfer, transpose, ubound, unpack, verify
}
return Expr<T>{std::move(funcRef)};
#include "intrinsics.h"
#include "expression.h"
#include "fold.h"
+#include "shape.h"
#include "tools.h"
#include "type.h"
#include "../common/Fortran.h"
{"product",
{{"array", SameNumeric, Rank::array}, OptionalDIM, OptionalMASK},
SameNumeric, Rank::dimReduced},
+ // TODO pmk: "rank"
{"real", {{"a", AnyNumeric, Rank::elementalOrBOZ}, DefaultingKIND},
KINDReal},
{"reduce",
// COSHAPE
// TODO: Object characteristic inquiry functions
// ALLOCATED, ASSOCIATED, EXTENDS_TYPE_OF, IS_CONTIGUOUS,
-// PRESENT, RANK, SAME_TYPE, STORAGE_SIZE
+// PRESENT, SAME_TYPE, STORAGE_SIZE
// TODO: Type inquiry intrinsic functions - these return constants
// BIT_SIZE, DIGITS, EPSILON, HUGE, KIND, MAXEXPONENT, MINEXPONENT,
// NEW_LINE, PRECISION, RADIX, RANGE, TINY
// Check the ranks of the arguments against the intrinsic's interface.
const ActualArgument *arrayArg{nullptr};
const ActualArgument *knownArg{nullptr};
- const ActualArgument *shapeArg{nullptr};
+ std::optional<int> shapeArgSize;
int elementalRank{0};
for (std::size_t j{0}; j < dummies; ++j) {
const IntrinsicDummyArgument &d{dummy[std::min(j, dummyArgPatterns - 1)]};
case Rank::scalar: argOk = rank == 0; break;
case Rank::vector: argOk = rank == 1; break;
case Rank::shape:
- CHECK(shapeArg == nullptr);
- shapeArg = arg;
- argOk = rank == 1 && arg->VectorSize().has_value();
+ CHECK(!shapeArgSize.has_value());
+ if (rank == 1) {
+ if (auto shape{GetShape(*arg)}) {
+ CHECK(shape->size() == 1);
+ if (auto value{ToInt64(shape->at(0))}) {
+ shapeArgSize = *value;
+ argOk = *value >= 0;
+ }
+ }
+ }
+ if (!argOk) {
+ messages.Say(
+ "'shape=' argument must be a vector of known size"_err_en_US);
+ return std::nullopt;
+ }
break;
case Rank::matrix: argOk = rank == 2; break;
case Rank::array:
resultRank = knownArg->Rank() + 1;
break;
case Rank::shaped:
- CHECK(shapeArg != nullptr);
- resultRank = shapeArg->VectorSize().value();
+ CHECK(shapeArgSize.has_value());
+ resultRank = *shapeArgSize;
break;
case Rank::elementalOrBOZ:
case Rank::shape:
#include "../semantics/symbol.h"
namespace Fortran::evaluate {
-std::optional<Shape> GetShape(
- const semantics::Symbol &symbol, const Component *component) {
+
+static Extent GetLowerBound(const semantics::Symbol &symbol,
+ const Component *component, int dimension) {
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
- Shape result;
- int dimension{1};
+ int j{0};
for (const auto &shapeSpec : details->shape()) {
- if (shapeSpec.isExplicit()) {
- result.emplace_back(
- common::Clone(shapeSpec.ubound().GetExplicit().value()) -
- common::Clone(shapeSpec.lbound().GetExplicit().value()) +
- Expr<SubscriptInteger>{1});
- } else if (component != nullptr) {
- result.emplace_back(Expr<SubscriptInteger>{DescriptorInquiry{
- *component, DescriptorInquiry::Field::Extent, dimension}});
- } else {
- result.emplace_back(Expr<SubscriptInteger>{DescriptorInquiry{
- symbol, DescriptorInquiry::Field::Extent, dimension}});
+ if (j++ == dimension) {
+ if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
+ return *bound;
+ } else if (component != nullptr) {
+ return Expr<SubscriptInteger>{DescriptorInquiry{
+ *component, DescriptorInquiry::Field::LowerBound, dimension}};
+ } else {
+ return Expr<SubscriptInteger>{DescriptorInquiry{
+ symbol, DescriptorInquiry::Field::LowerBound, dimension}};
+ }
}
- ++dimension;
}
- return result;
- } else {
- return std::nullopt;
}
+ return std::nullopt;
}
-std::optional<Shape> GetShape(const Component &component) {
- const Symbol &symbol{component.GetLastSymbol()};
- if (symbol.Rank() > 0) {
- return GetShape(symbol, &component);
- } else {
- return GetShape(component.base());
+static Extent GetExtent(const semantics::Symbol &symbol,
+ const Component *component, int dimension) {
+ if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
+ int j{0};
+ for (const auto &shapeSpec : details->shape()) {
+ if (j++ == dimension) {
+ if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) {
+ if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
+ FoldingContext noFoldingContext;
+ return Fold(noFoldingContext,
+ common::Clone(ubound.value()) - common::Clone(lbound.value()) +
+ Expr<SubscriptInteger>{1});
+ }
+ }
+ if (component != nullptr) {
+ return Expr<SubscriptInteger>{DescriptorInquiry{
+ *component, DescriptorInquiry::Field::Extent, dimension}};
+ } else {
+ return Expr<SubscriptInteger>{DescriptorInquiry{
+ &symbol, DescriptorInquiry::Field::Extent, dimension}};
+ }
+ }
+ }
}
+ return std::nullopt;
}
-static Extent GetExtent(const Subscript &subscript) {
+
+static Extent GetExtent(const Subscript &subscript, const Symbol &symbol,
+ const Component *component, int dimension) {
return std::visit(
common::visitors{
- [](const Triplet &triplet) -> Extent {
- if (auto lower{triplet.lower()}) {
- if (auto lowerValue{ToInt64(*lower)}) {
- if (auto upper{triplet.upper()}) {
- if (auto upperValue{ToInt64(*upper)}) {
- if (auto strideValue{ToInt64(triplet.stride())}) {
- if (*strideValue != 0) {
- std::int64_t extent{
- (*upperValue - *lowerValue + *strideValue) /
- *strideValue};
- return Expr<SubscriptInteger>{extent > 0 ? extent : 0};
- }
- }
- }
- }
+ [&](const Triplet &triplet) -> Extent {
+ Extent upper{triplet.upper()};
+ if (!upper.has_value()) {
+ upper = GetExtent(symbol, component, dimension);
+ }
+ if (upper.has_value()) {
+ Extent lower{triplet.lower()};
+ if (!lower.has_value()) {
+ lower = GetLowerBound(symbol, component, dimension);
+ }
+ if (lower.has_value()) {
+ auto span{
+ (std::move(*upper) - std::move(*lower) + triplet.stride()) /
+ triplet.stride()};
+ Expr<SubscriptInteger> extent{
+ Extremum<SubscriptInteger>{std::move(span),
+ Expr<SubscriptInteger>{0}, Ordering::Greater}};
+ FoldingContext noFoldingContext;
+ return Fold(noFoldingContext, std::move(extent));
}
}
return std::nullopt;
},
[](const IndirectSubscriptIntegerExpr &subs) -> Extent {
if (auto shape{GetShape(subs.value())}) {
- if (shape->size() == 1) {
+ if (shape->size() > 0) {
+ CHECK(shape->size() == 1); // vector-valued subscript
return std::move(shape->at(0));
}
}
},
subscript.u);
}
+
+std::optional<Shape> GetShape(
+ const semantics::Symbol &symbol, const Component *component) {
+ if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
+ Shape result;
+ int n = details->shape().size();
+ for (int dimension{0}; dimension < n; ++dimension) {
+ result.emplace_back(GetExtent(symbol, component, dimension++));
+ }
+ return result;
+ } else {
+ return std::nullopt;
+ }
+}
+
+std::optional<Shape> GetShape(const BaseObject &object) {
+ if (const Symbol * symbol{object.symbol()}) {
+ return GetShape(*symbol);
+ } else {
+ return Shape{};
+ }
+}
+
+std::optional<Shape> GetShape(const Component &component) {
+ const Symbol &symbol{component.GetLastSymbol()};
+ if (symbol.Rank() > 0) {
+ return GetShape(symbol, &component);
+ } else {
+ return GetShape(component.base());
+ }
+}
+
std::optional<Shape> GetShape(const ArrayRef &arrayRef) {
Shape shape;
+ const Symbol &symbol{arrayRef.GetLastSymbol()};
+ const Component *component{std::get_if<Component>(&arrayRef.base())};
+ int dimension{0};
for (const Subscript &ss : arrayRef.subscript()) {
if (ss.Rank() > 0) {
- shape.emplace_back(GetExtent(ss));
+ shape.emplace_back(GetExtent(ss, symbol, component, dimension));
}
+ ++dimension;
}
if (shape.empty()) {
return GetShape(arrayRef.base());
return shape;
}
}
+
std::optional<Shape> GetShape(const CoarrayRef &coarrayRef) {
Shape shape;
+ SymbolOrComponent base{coarrayRef.GetBaseSymbolOrComponent()};
+ const Symbol &symbol{coarrayRef.GetLastSymbol()};
+ const Component *component{std::get_if<Component>(&base)};
+ int dimension{0};
for (const Subscript &ss : coarrayRef.subscript()) {
if (ss.Rank() > 0) {
- shape.emplace_back(GetExtent(ss));
+ shape.emplace_back(GetExtent(ss, symbol, component, dimension));
}
+ ++dimension;
}
if (shape.empty()) {
return GetShape(coarrayRef.GetLastSymbol());
return shape;
}
}
+
std::optional<Shape> GetShape(const DataRef &dataRef) {
return std::visit([](const auto &x) { return GetShape(x); }, dataRef.u);
}
+
std::optional<Shape> GetShape(const Substring &substring) {
if (const auto *dataRef{substring.GetParentIf<DataRef>()}) {
return GetShape(*dataRef);
return std::nullopt;
}
}
+
std::optional<Shape> GetShape(const ComplexPart &part) {
return GetShape(part.complex());
}
+
+std::optional<Shape> GetShape(const ActualArgument &arg) {
+ return GetShape(arg.value());
+}
+
+std::optional<Shape> GetShape(const ProcedureRef &call) {
+ if (call.Rank() == 0) {
+ return Shape{};
+ } else if (call.IsElemental()) {
+ for (const auto &arg : call.arguments()) {
+ if (arg.has_value() && arg->Rank() > 0) {
+ return GetShape(*arg);
+ }
+ }
+ } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
+ return GetShape(*symbol);
+ } else if (const auto *intrinsic{
+ std::get_if<SpecificIntrinsic>(&call.proc().u)}) {
+ if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
+ intrinsic->name == "ubound") {
+ return Shape{Extent{Expr<SubscriptInteger>{
+ call.arguments().front().value().value().Rank()}}};
+ }
+ // TODO: shapes of other non-elemental intrinsic results
+ // esp. reshape, where shape is value of second argument
+ }
+ return std::nullopt;
+}
+
+std::optional<Shape> GetShape(const StructureConstructor &) {
+ return Shape{}; // always scalar
+}
+
+std::optional<Shape> GetShape(const BOZLiteralConstant &) {
+ return Shape{}; // always scalar
+}
+
+std::optional<Shape> GetShape(const NullPointer &) {
+ return {}; // not an object
+}
+
}
using Shape = std::vector<Extent>;
template<typename A> std::optional<Shape> GetShape(const A &) {
- return std::nullopt;
+ return std::nullopt; // default case
}
-template<typename T> std::optional<Shape> GetShape(const Expr<T> &);
-
+// Forward declarations
+template<typename... A>
+std::optional<Shape> GetShape(const std::variant<A...> &);
template<typename A, bool COPY>
-std::optional<Shape> GetShape(const common::Indirection<A, COPY> &p) {
- return GetShape(p.value());
-}
-
-template<typename A> std::optional<Shape> GetShape(const std::optional<A> &x) {
- if (x.has_value()) {
- return GetShape(*x);
- } else {
- return std::nullopt;
- }
-}
+std::optional<Shape> GetShape(const common::Indirection<A, COPY> &);
+template<typename A> std::optional<Shape> GetShape(const std::optional<A> &);
-template<typename... A>
-std::optional<Shape> GetShape(const std::variant<A...> &u) {
- return std::visit([](const auto &x) { return GetShape(x); }, u);
+template<typename T> std::optional<Shape> GetShape(const Expr<T> &expr) {
+ return GetShape(expr.u);
}
std::optional<Shape> GetShape(
const semantics::Symbol &, const Component * = nullptr);
-std::optional<Shape> GetShape(const DataRef &);
-std::optional<Shape> GetShape(const ComplexPart &);
-std::optional<Shape> GetShape(const Substring &);
+std::optional<Shape> GetShape(const BaseObject &);
std::optional<Shape> GetShape(const Component &);
std::optional<Shape> GetShape(const ArrayRef &);
std::optional<Shape> GetShape(const CoarrayRef &);
+std::optional<Shape> GetShape(const DataRef &);
+std::optional<Shape> GetShape(const Substring &);
+std::optional<Shape> GetShape(const ComplexPart &);
+std::optional<Shape> GetShape(const ActualArgument &);
+std::optional<Shape> GetShape(const ProcedureRef &);
+std::optional<Shape> GetShape(const StructureConstructor &);
+std::optional<Shape> GetShape(const BOZLiteralConstant &);
+std::optional<Shape> GetShape(const NullPointer &);
template<typename T>
std::optional<Shape> GetShape(const Designator<T> &designator) {
- return std::visit([](const auto &x) { return GetShape(x); }, designator.u);
+ return GetShape(designator.u);
}
-template<typename T> std::optional<Shape> GetShape(const Expr<T> &expr) {
- return std::visit(
- common::visitors{
- [](const BOZLiteralConstant &) { return Shape{}; },
- [](const NullPointer &) { return std::nullopt; },
- [](const auto &x) { return GetShape(x); },
- },
- expr.u);
+template<typename T>
+std::optional<Shape> GetShape(const Variable<T> &variable) {
+ return GetShape(variable.u);
+}
+
+template<typename D, typename R, typename... O>
+std::optional<Shape> GetShape(const Operation<D, R, O...> &operation) {
+ if constexpr (operation.operands > 1) {
+ if (operation.right().Rank() > 0) {
+ return GetShape(operation.right());
+ }
+ }
+ return GetShape(operation.left());
+}
+
+template<int KIND>
+std::optional<Shape> GetShape(const TypeParamInquiry<KIND> &) {
+ return Shape{}; // always scalar
+}
+
+template<typename T>
+std::optional<Shape> GetShape(const ArrayConstructorValues<T> &aconst) {
+ return std::nullopt; // TODO pmk much more here!!
+}
+
+template<typename... A>
+std::optional<Shape> GetShape(const std::variant<A...> &u) {
+ return std::visit([](const auto &x) { return GetShape(x); }, u);
+}
+
+template<typename A, bool COPY>
+std::optional<Shape> GetShape(const common::Indirection<A, COPY> &p) {
+ return GetShape(p.value());
+}
+
+template<typename A> std::optional<Shape> GetShape(const std::optional<A> &x) {
+ if (x.has_value()) {
+ return GetShape(*x);
+ } else {
+ return std::nullopt;
+ }
}
}
#endif // FORTRAN_EVALUATE_SHAPE_H_
return std::nullopt;
}
-const Expr<SubscriptInteger> &Triplet::stride() const {
- return stride_.value();
-}
+Expr<SubscriptInteger> Triplet::stride() const { return stride_.value(); }
bool Triplet::IsStrideOne() const {
if (auto stride{ToInt64(stride_.value())}) {
}
return std::visit(
common::visitors{
- [=](const Symbol *s) { return s->Rank(); },
+ [=](const Symbol *s) { return 0; },
[=](const Component &c) { return c.Rank(); },
},
base_);
}
int CoarrayRef::Rank() const {
- int rank{0};
- for (const auto &expr : subscript_) {
- rank += expr.Rank();
- }
- if (rank > 0) {
+ if (!subscript_.empty()) {
+ int rank{0};
+ for (const auto &expr : subscript_) {
+ rank += expr.Rank();
+ }
return rank;
} else {
return base_.back()->Rank();
}
}
+SymbolOrComponent CoarrayRef::GetBaseSymbolOrComponent() const {
+ SymbolOrComponent base{base_.front()};
+ int j{0};
+ for (const Symbol *symbol : base_) {
+ if (j == 0) { // X - already captured the symbol above
+ } else if (j == 1) { // X%Y
+ base = Component{DataRef{std::get<const Symbol *>(base)}, *symbol};
+ } else { // X%Y%Z or more
+ base = Component{DataRef{std::move(std::get<Component>(base))}, *symbol};
+ }
+ ++j;
+ }
+ return base;
+}
+
// Equality testing
bool BaseObject::operator==(const BaseObject &that) const {
std::optional<Expr<SubscriptInteger>> &&);
std::optional<Expr<SubscriptInteger>> lower() const;
std::optional<Expr<SubscriptInteger>> upper() const;
- const Expr<SubscriptInteger> &stride() const;
+ Expr<SubscriptInteger> stride() const;
bool operator==(const Triplet &) const;
bool IsStrideOne() const;
std::ostream &AsFortran(std::ostream &) const;
int Rank() const;
const Symbol &GetFirstSymbol() const;
const Symbol &GetLastSymbol() const;
+ SymbolOrComponent GetBaseSymbolOrComponent() const;
Expr<SubscriptInteger> LEN() const;
bool operator==(const CoarrayRef &) const;
std::ostream &AsFortran(std::ostream &) const;
private:
SymbolOrComponent base_{nullptr};
Field field_;
- int dimension_{0};
+ int dimension_{0}; // zero-based
};
#define INSTANTIATE_VARIABLE_TEMPLATES \
class ContextualMessages {
public:
+ ContextualMessages() = default;
ContextualMessages(CharBlock at, Messages *m) : at_{at}, messages_{m} {}
ContextualMessages(const ContextualMessages &that)
: at_{that.at_}, messages_{that.messages_} {}