// TODO pmk: get rank
} else if (name == "shape") {
// TODO pmk: call GetShape on argument, massage result
+ } else if (name == "size") {
+ // TODO pmk: call GetShape on argument, extract or compute result
}
// TODO:
// ceiling, count, cshift, dot_product, eoshift,
// index, ishftc, lbound, len_trim, matmul, max, maxloc, maxval, merge, min,
// minloc, minval, mod, modulo, nint, not, pack, product, reduce, reshape,
// scan, selected_char_kind, selected_int_kind, selected_real_kind,
- // sign, size, spread, sum, transfer, transpose, ubound, unpack, verify
+ // sign, spread, sum, transfer, transpose, ubound, unpack, verify
}
return Expr<T>{std::move(funcRef)};
}
#include "shape.h"
#include "fold.h"
#include "tools.h"
+#include "traversal.h"
#include "../common/idioms.h"
+#include "../common/template.h"
#include "../semantics/symbol.h"
namespace Fortran::evaluate {
-static Extent GetLowerBound(const semantics::Symbol &symbol,
+Shape AsGeneralShape(const Constant<ExtentType> &constShape) {
+ CHECK(constShape.Rank() == 1);
+ Shape result;
+ std::size_t dimensions{constShape.size()};
+ for (std::size_t j{0}; j < dimensions; ++j) {
+ Scalar<ExtentType> extent{constShape.values().at(j)};
+ result.emplace_back(MaybeExtent{ExtentExpr{extent}});
+ }
+ return result;
+}
+
+std::optional<Constant<ExtentType>> AsConstantShape(const Shape &shape) {
+ std::vector<Scalar<ExtentType>> extents;
+ for (const auto &dim : shape) {
+ if (dim.has_value()) {
+ if (const auto cdim{UnwrapExpr<Constant<ExtentType>>(*dim)}) {
+ extents.emplace_back(**cdim);
+ continue;
+ }
+ }
+ return std::nullopt;
+ }
+ std::vector<std::int64_t> rshape{static_cast<std::int64_t>(shape.size())};
+ return Constant<ExtentType>{std::move(extents), std::move(rshape)};
+}
+
+static ExtentExpr ComputeTripCount(
+ ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
+ ExtentExpr strideCopy{common::Clone(stride)};
+ ExtentExpr span{
+ (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
+ std::move(stride)};
+ ExtentExpr extent{
+ Extremum<ExtentType>{std::move(span), ExtentExpr{0}, Ordering::Greater}};
+ FoldingContext noFoldingContext;
+ return Fold(noFoldingContext, std::move(extent));
+}
+
+ExtentExpr CountTrips(
+ ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride) {
+ return ComputeTripCount(
+ std::move(lower), std::move(upper), std::move(stride));
+}
+
+ExtentExpr CountTrips(const ExtentExpr &lower, const ExtentExpr &upper,
+ const ExtentExpr &stride) {
+ return ComputeTripCount(
+ common::Clone(lower), common::Clone(upper), common::Clone(stride));
+}
+
+MaybeExtent CountTrips(
+ MaybeExtent &&lower, MaybeExtent &&upper, MaybeExtent &&stride) {
+ return common::MapOptional(
+ ComputeTripCount, std::move(lower), std::move(upper), std::move(stride));
+}
+
+MaybeExtent GetSize(Shape &&shape) {
+ ExtentExpr extent{1};
+ for (auto &&dim : std::move(shape)) {
+ if (dim.has_value()) {
+ extent = std::move(extent) * std::move(*dim);
+ } else {
+ return std::nullopt;
+ }
+ }
+ return extent;
+}
+
+static MaybeExtent GetLowerBound(const semantics::Symbol &symbol,
const Component *component, int dimension) {
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
int j{0};
if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
return *bound;
} else if (component != nullptr) {
- return Expr<SubscriptInteger>{DescriptorInquiry{
+ return ExtentExpr{DescriptorInquiry{
*component, DescriptorInquiry::Field::LowerBound, dimension}};
} else {
- return Expr<SubscriptInteger>{DescriptorInquiry{
+ return ExtentExpr{DescriptorInquiry{
symbol, DescriptorInquiry::Field::LowerBound, dimension}};
}
}
return std::nullopt;
}
-static Extent GetExtent(const semantics::Symbol &symbol,
+static MaybeExtent GetExtent(const semantics::Symbol &symbol,
const Component *component, int dimension) {
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
int j{0};
FoldingContext noFoldingContext;
return Fold(noFoldingContext,
common::Clone(ubound.value()) - common::Clone(lbound.value()) +
- Expr<SubscriptInteger>{1});
+ ExtentExpr{1});
}
}
if (component != nullptr) {
- return Expr<SubscriptInteger>{DescriptorInquiry{
+ return ExtentExpr{DescriptorInquiry{
*component, DescriptorInquiry::Field::Extent, dimension}};
} else {
- return Expr<SubscriptInteger>{DescriptorInquiry{
+ return ExtentExpr{DescriptorInquiry{
&symbol, DescriptorInquiry::Field::Extent, dimension}};
}
}
return std::nullopt;
}
-static Extent GetExtent(const Subscript &subscript, const Symbol &symbol,
+static MaybeExtent GetExtent(const Subscript &subscript, const Symbol &symbol,
const Component *component, int dimension) {
return std::visit(
common::visitors{
- [&](const Triplet &triplet) -> Extent {
- Extent upper{triplet.upper()};
+ [&](const Triplet &triplet) -> MaybeExtent {
+ MaybeExtent upper{triplet.upper()};
if (!upper.has_value()) {
upper = GetExtent(symbol, component, dimension);
}
- if (upper.has_value()) {
- Extent lower{triplet.lower()};
- if (!lower.has_value()) {
- lower = GetLowerBound(symbol, component, dimension);
- }
- if (lower.has_value()) {
- auto span{
- (std::move(*upper) - std::move(*lower) + triplet.stride()) /
- triplet.stride()};
- Expr<SubscriptInteger> extent{
- Extremum<SubscriptInteger>{std::move(span),
- Expr<SubscriptInteger>{0}, Ordering::Greater}};
- FoldingContext noFoldingContext;
- return Fold(noFoldingContext, std::move(extent));
- }
+ MaybeExtent lower{triplet.lower()};
+ if (!lower.has_value()) {
+ lower = GetLowerBound(symbol, component, dimension);
}
- return std::nullopt;
+ return CountTrips(std::move(lower), std::move(upper),
+ MaybeExtent{triplet.stride()});
},
- [](const IndirectSubscriptIntegerExpr &subs) -> Extent {
+ [](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtent {
if (auto shape{GetShape(subs.value())}) {
if (shape->size() > 0) {
CHECK(shape->size() == 1); // vector-valued subscript
subscript.u);
}
+bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
+ struct MyVisitor : public virtual VisitorBase<bool> {
+ explicit MyVisitor(int) { result() = false; }
+ void Handle(const ImpliedDoIndex &) { Return(true); }
+ };
+ return Visitor<bool, MyVisitor>{0}.Traverse(expr);
+}
+
std::optional<Shape> GetShape(
const semantics::Symbol &symbol, const Component *component) {
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
std::get_if<SpecificIntrinsic>(&call.proc().u)}) {
if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
intrinsic->name == "ubound") {
- return Shape{Extent{Expr<SubscriptInteger>{
- call.arguments().front().value().value().Rank()}}};
+ return Shape{MaybeExtent{
+ ExtentExpr{call.arguments().front().value().value().Rank()}}};
}
// TODO: shapes of other non-elemental intrinsic results
// esp. reshape, where shape is value of second argument
#define FORTRAN_EVALUATE_SHAPE_H_
#include "expression.h"
+#include "tools.h"
#include "type.h"
#include "../common/indirection.h"
#include <optional>
namespace Fortran::evaluate {
-using Extent = std::optional<Expr<SubscriptInteger>>;
-using Shape = std::vector<Extent>;
+using ExtentType = SubscriptInteger;
+using ExtentExpr = Expr<ExtentType>;
+using MaybeExtent = std::optional<ExtentExpr>;
+using Shape = std::vector<MaybeExtent>;
+
+// Convert a constant shape to the expression form, and vice versa.
+Shape AsGeneralShape(const Constant<ExtentType> &);
+std::optional<Constant<ExtentType>> AsConstantShape(const Shape &);
+
+// Compute a trip count for a triplet or implied DO.
+ExtentExpr CountTrips(
+ ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride);
+ExtentExpr CountTrips(
+ const ExtentExpr &lower, const ExtentExpr &upper, const ExtentExpr &stride);
+MaybeExtent CountTrips(
+ MaybeExtent &&lower, MaybeExtent &&upper, MaybeExtent &&stride);
+
+// Computes SIZE() == PRODUCT(shape)
+MaybeExtent GetSize(Shape &&);
template<typename A> std::optional<Shape> GetShape(const A &) {
- return std::nullopt; // default case
+ return std::nullopt; // default case TODO pmk remove
}
// Forward declarations
std::optional<Shape> GetShape(const BOZLiteralConstant &);
std::optional<Shape> GetShape(const NullPointer &);
+template<typename T> std::optional<Shape> GetShape(const Constant<T> &c) {
+ Constant<ExtentType> shape{c.SHAPE()};
+ return AsGeneralShape(shape);
+}
+
template<typename T>
std::optional<Shape> GetShape(const Designator<T> &designator) {
return GetShape(designator.u);
template<int KIND>
std::optional<Shape> GetShape(const TypeParamInquiry<KIND> &) {
- return Shape{}; // always scalar
+ return Shape{}; // always scalar, even when applied to an array
+}
+
+// Utility predicate: does an expression reference any implied DO index?
+bool ContainsAnyImpliedDoIndex(const ExtentExpr &);
+
+template<typename T> MaybeExtent GetExtent(const ArrayConstructorValues<T> &);
+
+template<typename T>
+MaybeExtent GetExtent(const ArrayConstructorValue<T> &value) {
+ return std::visit(
+ common::visitors{
+ [](const common::CopyableIndirection<Expr<T>> &x) -> MaybeExtent {
+ if (std::optional<Shape> xShape{GetShape(x)}) {
+ // Array values in array constructors get linearized.
+ return GetSize(std::move(*xShape));
+ }
+ return std::nullopt;
+ },
+ [](const ImpliedDo<T> &ido) -> MaybeExtent {
+ // Don't be heroic and try to figure out triangular implied DO
+ // nests.
+ if (!ContainsAnyImpliedDoIndex(ido.lower()) &&
+ !ContainsAnyImpliedDoIndex(ido.upper()) &&
+ !ContainsAnyImpliedDoIndex(ido.stride())) {
+ if (auto nValues{GetExtent(ido.values())}) {
+ return std::move(*nValues) *
+ CountTrips(ido.lower(), ido.upper(), ido.stride());
+ }
+ }
+ return std::nullopt;
+ },
+ },
+ value.u);
+}
+
+template<typename T>
+MaybeExtent GetExtent(const ArrayConstructorValues<T> &values) {
+ ExtentExpr result{0};
+ for (const auto &value : values.values()) {
+ if (MaybeExtent n{GetExtent(value)}) {
+ result = std::move(result) + std::move(*n);
+ } else {
+ return std::nullopt;
+ }
+ }
+ return result;
}
template<typename T>
-std::optional<Shape> GetShape(const ArrayConstructorValues<T> &aconst) {
- return std::nullopt; // TODO pmk much more here!!
+std::optional<Shape> GetShape(const ArrayConstructor<T> &aconst) {
+ return Shape{GetExtent(aconst)};
}
template<typename... A>