template<typename... Ts>
constexpr bool AreTypesDistinct{AreTypesDistinctHelper<Ts...>::value()};
+template<typename A, typename... Ts> struct AreSameTypeHelper {
+ using type = A;
+ static constexpr bool value() {
+ if constexpr (sizeof...(Ts) == 0) {
+ return true;
+ } else {
+ using Rest = AreSameTypeHelper<Ts...>;
+ return std::is_same_v<type, typename Rest::type> && Rest::value();
+ }
+ }
+};
+
+template<typename... Ts>
+constexpr bool AreSameType{AreSameTypeHelper<Ts...>::value()};
+
template<typename> struct TupleToVariantHelper;
template<typename... Ts> struct TupleToVariantHelper<std::tuple<Ts...>> {
static_assert(AreTypesDistinct<Ts...> ||
return FoldElementalIntrinsic<T, T, T, T>(
context, std::move(funcRef), &Scalar<T>::MERGE_BITS);
} else if (name == "rank") {
- // TODO pmk: get rank
+ // TODO assumed-rank dummy argument
+ return Expr<T>{args[0].value().Rank()};
} else if (name == "shape") {
- // TODO pmk: call GetShape on argument, massage result
+ if (auto shape{GetShape(args[0].value())}) {
+ if (auto shapeExpr{AsShapeArrayExpr(*shape)}) {
+ return Fold(context, ConvertToType<T>(std::move(*shapeExpr)));
+ }
+ }
} else if (name == "size") {
- // TODO pmk: call GetShape on argument, extract or compute result
+ if (auto shape{GetShape(args[0].value())}) {
+ // TODO pmk: extract or compute result
+ }
}
// TODO:
// ceiling, count, cshift, dot_product, eoshift,
class IsConstantExprVisitor : public virtual VisitorBase<bool> {
public:
+ using Result = bool;
explicit IsConstantExprVisitor(int) { result() = true; }
template<int KIND> void Handle(const TypeParamInquiry<KIND> &inq) {
};
bool IsConstantExpr(const Expr<SomeType> &expr) {
- return Visitor<bool, IsConstantExprVisitor>{0}.Traverse(expr);
+ return Visitor<IsConstantExprVisitor>{0}.Traverse(expr);
}
std::optional<std::int64_t> ToInt64(const Expr<SomeInteger> &expr) {
{"product",
{{"array", SameNumeric, Rank::array}, OptionalDIM, OptionalMASK},
SameNumeric, Rank::dimReduced},
- // TODO pmk: "rank"
+ {"rank", {{"a", Anything, Rank::anyOrAssumedRank}}, DefaultInt},
{"real", {{"a", AnyNumeric, Rank::elementalOrBOZ}, DefaultingKIND},
KINDReal},
{"reduce",
CHECK(!shapeArgSize.has_value());
if (rank == 1) {
if (auto shape{GetShape(*arg)}) {
- CHECK(shape->size() == 1);
- if (auto value{ToInt64(shape->at(0))}) {
- shapeArgSize = *value;
- argOk = *value >= 0;
+ if (auto constShape{AsConstantShape(*shape)}) {
+ shapeArgSize = (**constShape).ToInt64();
+ CHECK(shapeArgSize >= 0);
+ argOk = true;
}
}
}
return result;
}
-std::optional<Constant<ExtentType>> AsConstantShape(const Shape &shape) {
- std::vector<Scalar<ExtentType>> extents;
+std::optional<ExtentExpr> AsShapeArrayExpr(const Shape &shape) {
+ ArrayConstructorValues<ExtentType> values;
for (const auto &dim : shape) {
if (dim.has_value()) {
- if (const auto cdim{UnwrapExpr<Constant<ExtentType>>(*dim)}) {
- extents.emplace_back(**cdim);
- continue;
- }
+ values.Push(common::Clone(*dim));
+ } else {
+ return std::nullopt;
+ }
+ }
+ return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
+}
+
+std::optional<Constant<ExtentType>> AsConstantShape(const Shape &shape) {
+ if (auto shapeArray{AsShapeArrayExpr(shape)}) {
+ FoldingContext noFoldingContext;
+ auto folded{Fold(noFoldingContext, std::move(*shapeArray))};
+ if (auto *p{UnwrapExpr<Constant<ExtentType>>(folded)}) {
+ return std::move(*p);
}
- return std::nullopt;
}
- std::vector<std::int64_t> rshape{static_cast<std::int64_t>(shape.size())};
- return Constant<ExtentType>{std::move(extents), std::move(rshape)};
+ return std::nullopt;
}
static ExtentExpr ComputeTripCount(
return extent;
}
-static MaybeExtent GetLowerBound(const semantics::Symbol &symbol,
- const Component *component, int dimension) {
+static MaybeExtent GetLowerBound(
+ const Symbol &symbol, const Component *component, int dimension) {
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
int j{0};
for (const auto &shapeSpec : details->shape()) {
return std::nullopt;
}
-static MaybeExtent GetExtent(const semantics::Symbol &symbol,
- const Component *component, int dimension) {
+static MaybeExtent GetExtent(
+ const Symbol &symbol, const Component *component, int dimension) {
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
int j{0};
for (const auto &shapeSpec : details->shape()) {
bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
struct MyVisitor : public virtual VisitorBase<bool> {
+ using Result = bool;
explicit MyVisitor(int) { result() = false; }
void Handle(const ImpliedDoIndex &) { Return(true); }
};
- return Visitor<bool, MyVisitor>{0}.Traverse(expr);
+ return Visitor<MyVisitor>{0}.Traverse(expr);
}
std::optional<Shape> GetShape(
- const semantics::Symbol &symbol, const Component *component) {
+ const Symbol &symbol, const Component *component) {
if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
Shape result;
int n = details->shape().size();
}
}
+std::optional<Shape> GetShape(const Symbol *symbol) {
+ if (symbol != nullptr) {
+ return GetShape(*symbol);
+ } else {
+ return std::nullopt;
+ }
+}
+
std::optional<Shape> GetShape(const BaseObject &object) {
if (const Symbol * symbol{object.symbol()}) {
return GetShape(*symbol);
}
std::optional<Shape> GetShape(const DataRef &dataRef) {
- return std::visit([](const auto &x) { return GetShape(x); }, dataRef.u);
+ return GetShape(dataRef.u);
}
std::optional<Shape> GetShape(const Substring &substring) {
return std::nullopt;
}
+std::optional<Shape> GetShape(const Relational<SomeType> &relation) {
+ return GetShape(relation.u);
+}
+
std::optional<Shape> GetShape(const StructureConstructor &) {
return Shape{}; // always scalar
}
+std::optional<Shape> GetShape(const ImpliedDoIndex &) {
+ return Shape{}; // always scalar
+}
+
+std::optional<Shape> GetShape(const DescriptorInquiry &) {
+ return Shape{}; // always scalar
+}
+
std::optional<Shape> GetShape(const BOZLiteralConstant &) {
return Shape{}; // always scalar
}
// Convert a constant shape to the expression form, and vice versa.
Shape AsGeneralShape(const Constant<ExtentType> &);
+std::optional<ExtentExpr> AsShapeArrayExpr(const Shape &); // array constructor
std::optional<Constant<ExtentType>> AsConstantShape(const Shape &);
-// Compute a trip count for a triplet or implied DO.
+// Compute an element count for a triplet or trip count for a DO.
ExtentExpr CountTrips(
ExtentExpr &&lower, ExtentExpr &&upper, ExtentExpr &&stride);
ExtentExpr CountTrips(
// Computes SIZE() == PRODUCT(shape)
MaybeExtent GetSize(Shape &&);
-template<typename A> std::optional<Shape> GetShape(const A &) {
- return std::nullopt; // default case TODO pmk remove
-}
-
// Forward declarations
template<typename... A>
std::optional<Shape> GetShape(const std::variant<A...> &);
return GetShape(expr.u);
}
-std::optional<Shape> GetShape(
- const semantics::Symbol &, const Component * = nullptr);
+std::optional<Shape> GetShape(const Symbol &, const Component * = nullptr);
+std::optional<Shape> GetShape(const Symbol *);
std::optional<Shape> GetShape(const BaseObject &);
std::optional<Shape> GetShape(const Component &);
std::optional<Shape> GetShape(const ArrayRef &);
std::optional<Shape> GetShape(const ComplexPart &);
std::optional<Shape> GetShape(const ActualArgument &);
std::optional<Shape> GetShape(const ProcedureRef &);
+std::optional<Shape> GetShape(const ImpliedDoIndex &);
+std::optional<Shape> GetShape(const Relational<SomeType> &);
std::optional<Shape> GetShape(const StructureConstructor &);
+std::optional<Shape> GetShape(const DescriptorInquiry &);
std::optional<Shape> GetShape(const BOZLiteralConstant &);
std::optional<Shape> GetShape(const NullPointer &);
template<typename D, typename R, typename... O>
std::optional<Shape> GetShape(const Operation<D, R, O...> &operation) {
- if constexpr (operation.operands > 1) {
+ if constexpr (sizeof...(O) > 1) {
if (operation.right().Rank() > 0) {
return GetShape(operation.right());
}
// To use for non-mutating visitation, define one or more client visitation
// classes of the form:
// class MyVisitor : public virtual VisitorBase<RESULT> {
+// using Result = RESULT;
// explicit MyVisitor(ARGTYPE); // single-argument constructor
// void Handle(const T1 &); // callback for type T1 objects
// void Pre(const T2 &); // callback before visiting T2
// void Post(const T2 &); // callback after visiting T2
// ...
// };
-// RESULT should have some default-constructible type.
-// Then instantiate and construct a Visitor and its embedded MyVisitor via:
-// Visitor<RESULT, MyVisitor, ...> v{value}; // value is ARGTYPE &&
+// RESULT should have some default-constructible type, and it must be
+// the same type in all of the visitors that you combine in the next step.
+//
+// Then instantiate and construct a Visitor and its embedded visitors via:
+// Visitor<MyVisitor, ...> v{value...}; // value is/are ARGTYPE &&
// and call:
// RESULT result{v.Traverse(topLevelExpr)};
// Within the callback routines (Handle, Pre, Post), one may call
// argument types are rvalues and the non-void result types match
// the arguments:
// class MyRewriter : public virtual RewriterBase<RESULT> {
+// using Result = RESULT;
// explicit MyRewriter(ARGTYPE); // single-argument constructor
// T1 Handle(T1 &&); // rewriting callback for type T1 objects
// void Pre(T2 &); // in-place mutating callback before visiting T2
Result result_;
};
-template<typename RESULT, typename... A>
-class Visitor : public virtual VisitorBase<RESULT>, public A... {
+template<typename A, typename... B> struct VisitorResultTypeHelper {
+ using type = typename A::Result;
+ static_assert(common::AreSameType<type, typename B::Result...>);
+};
+template<typename... A>
+using VisitorResultType = typename VisitorResultTypeHelper<A...>::type;
+
+template<typename... A>
+class Visitor : public virtual VisitorBase<VisitorResultType<A...>>,
+ public A... {
public:
- using Result = RESULT;
+ using Result = VisitorResultType<A...>;
using Base = VisitorBase<Result>;
using Base::Handle, Base::Pre, Base::Post;
using A::Handle..., A::Pre..., A::Post...;
// Use the new expression traversal framework if possible, for testing.
if (expression.typedExpr) {
struct CollectSymbols : public virtual evaluate::VisitorBase<CS> {
+ using Result = CS;
explicit CollectSymbols(int) {}
void Handle(const Symbol *symbol) { result().push_back(symbol); }
};
- return evaluate::Visitor<CS, CollectSymbols>{0}.Traverse(
- *expression.typedExpr);
+ return evaluate::Visitor<CollectSymbols>{0}.Traverse(*expression.typedExpr);
} else {
GatherSymbols gatherSymbols;
parser::Walk(expression, gatherSymbols);
using SymbolVector = std::vector<const Symbol *>;
struct SymbolVisitor : public virtual evaluate::VisitorBase<SymbolVector> {
+ using Result = SymbolVector;
explicit SymbolVisitor(int) {}
void Handle(const Symbol *symbol) { result().push_back(symbol); }
};
template<typename T> void DoExpr(evaluate::Expr<T> expr) {
- evaluate::Visitor<SymbolVector, SymbolVisitor> visitor{0};
+ evaluate::Visitor<SymbolVisitor> visitor{0};
for (const Symbol *symbol : visitor.Traverse(expr)) {
CHECK(symbol && "bad symbol from Traverse");
DoSymbol(*symbol);