#include "scope.h"
#include "semantics.h"
#include "symbol.h"
+#include "../evaluate/traversal.h"
#include "../parser/parsing.h"
#include <algorithm>
#include <cerrno>
static std::string GetHeader(const std::string &);
static std::size_t GetFileSize(const std::string &);
+// Collect symbols needed for a subprogram interface
+class SubprogramSymbolCollector {
+public:
+ using SymbolSet = std::set<const Symbol *>;
+
+ SubprogramSymbolCollector(const Symbol &symbol)
+ : symbol_{symbol}, scope_{*symbol.scope()} {}
+ SymbolList symbols() const { return need_; }
+ SymbolSet imports() const { return imports_; }
+ void Collect();
+
+private:
+ const Symbol &symbol_;
+ const Scope &scope_;
+ bool isInterface_{false};
+ SymbolList need_; // symbols that are needed
+ SymbolSet needSet_; // symbols already in need_
+ SymbolSet useSet_; // use-associations that might be needed
+ SymbolSet imports_; // imports from host that are needed
+
+ void DoSymbol(const Symbol &);
+ void DoType(const DeclTypeSpec *);
+ void DoBound(const Bound &);
+ void DoParamValue(const ParamValue &);
+
+ using SymbolVector = std::vector<const Symbol *>;
+ struct SymbolVisitor : public virtual evaluate::VisitorBase<SymbolVector> {
+ explicit SymbolVisitor(int) {}
+ void Handle(const Symbol *symbol) { result().push_back(symbol); }
+ };
+
+ template<typename T> void DoExpr(evaluate::Expr<T> expr) {
+ evaluate::Visitor<SymbolVector, SymbolVisitor> visitor{0};
+ for (const Symbol *symbol : visitor.Traverse(expr)) {
+ CHECK(symbol && "bad symbol from Traverse");
+ DoSymbol(*symbol);
+ }
+ }
+
+};
+
bool ModFileWriter::WriteAll() {
WriteAll(context_.globalScope());
return !context_.AnyFatalError();
PutSymbol(typeBindings, symbol);
}
if (auto str{typeBindings.str()}; !str.empty()) {
+ CHECK(scope.kind() == Scope::Kind::DerivedType);
decls_ << "contains\n" << str;
}
}
PutLower(os, symbol) << '(';
int n = 0;
for (const auto &dummy : details.dummyArgs()) {
- if (n++ > 0) os << ',';
+ if (n++ > 0) {
+ os << ',';
+ }
PutLower(os, *dummy);
}
os << ')';
if (result.name() != symbol.name()) {
PutLower(os << " result(", result) << ')';
}
- os << '\n';
- PutEntity(os, details.result());
- } else {
- os << '\n';
}
- for (const auto &dummy : details.dummyArgs()) {
- PutEntity(os, *dummy);
+ os << '\n';
+
+ // walk symbols, collect ones needed
+ ModFileWriter writer{context_};
+ std::stringstream typeBindings;
+ SubprogramSymbolCollector collector{symbol};
+ collector.Collect();
+ for (const Symbol *need : collector.symbols()) {
+ writer.PutSymbol(typeBindings, need);
}
+ CHECK(typeBindings.str().empty());
+ os << writer.uses_.str();
+ for (const Symbol *import : collector.imports()) {
+ decls_ << "import::" << import->name().ToString() << "\n";
+ }
+ os << writer.decls_.str();
os << "end\n";
if (isInterface) {
os << "end interface\n";
std::vector<const Symbol *> CollectSymbols(const Scope &scope) {
std::set<const Symbol *> symbols; // to prevent duplicates
std::vector<const Symbol *> sorted;
- sorted.reserve(scope.size());
+ std::vector<const Symbol *> namelist;
+ std::vector<const Symbol *> common;
+ sorted.reserve(scope.size() + scope.commonBlocks().size());
for (const auto &pair : scope) {
auto *symbol{pair.second};
if (!symbol->test(Symbol::Flag::ParentComp)) {
if (symbols.insert(symbol).second) {
- sorted.push_back(symbol);
+ if (symbol->has<NamelistDetails>()) {
+ namelist.push_back(symbol);
+ } else {
+ sorted.push_back(symbol);
+ }
}
}
}
for (const auto &pair : scope.commonBlocks()) {
- auto *symbol{pair.second};
+ const Symbol *symbol{pair.second};
+ SourceName name{pair.first};
if (symbols.insert(symbol).second) {
- sorted.push_back(symbol);
+ common.push_back(symbol);
}
}
- std::sort(sorted.begin(), sorted.end(), [](const Symbol *x, const Symbol *y) {
- bool xIsNml{x->has<NamelistDetails>()};
- bool yIsNml{y->has<NamelistDetails>()};
- if (xIsNml != yIsNml) {
- return xIsNml < yIsNml;
- } else {
- return x->name().begin() < y->name().begin();
- }
- });
+ // sort normal symbols, then namelists, then common blocks:
+ auto compareByOrder = [](const Symbol *x, const Symbol *y) {
+ CHECK(x != nullptr);
+ return x->name().begin() < y->name().begin();
+ };
+ auto cursor{sorted.begin()};
+ std::sort(cursor, sorted.end(), compareByOrder);
+ cursor = sorted.insert(sorted.end(), namelist.begin(), namelist.end());
+ std::sort(cursor, sorted.end(), compareByOrder);
+ cursor = sorted.insert(sorted.end(), common.begin(), common.end());
+ std::sort(cursor, sorted.end(), compareByOrder);
return sorted;
}
PutLower(path, name.ToString()) << extension;
return path.str();
}
+
+void SubprogramSymbolCollector::Collect() {
+ const auto &details{symbol_.get<SubprogramDetails>()};
+ isInterface_ = details.isInterface();
+ if (details.isFunction()) {
+ DoSymbol(details.result());
+ }
+ for (const Symbol *dummyArg : details.dummyArgs()) {
+ CHECK(dummyArg);
+ DoSymbol(*dummyArg);
+ }
+ for (const auto &pair : scope_) {
+ const Symbol *symbol{pair.second};
+ if (const auto *useDetails{symbol->detailsIf<UseDetails>()}) {
+ if (useSet_.count(&useDetails->symbol()) > 0) {
+ need_.push_back(symbol);
+ }
+ }
+ }
+}
+
+// Do symbols this one depends on; then add to need_
+void SubprogramSymbolCollector::DoSymbol(const Symbol &symbol) {
+ const auto &scope{symbol.owner()};
+ if (scope != scope_ && scope.kind() != Scope::Kind::DerivedType) {
+ if (scope != scope_.parent()) {
+ useSet_.insert(&symbol);
+ } else if (isInterface_) {
+ imports_.insert(&symbol);
+ }
+ return;
+ }
+ if (!needSet_.insert(&symbol).second) {
+ return; // already done
+ }
+ std::visit(
+ common::visitors{
+ [this](const ObjectEntityDetails &details) {
+ for (const ShapeSpec &spec : details.shape()) {
+ DoBound(spec.lbound());
+ DoBound(spec.ubound());
+ }
+ if (const Symbol * commonBlock{details.commonBlock()}) {
+ DoSymbol(*commonBlock);
+ }
+ },
+ [this](const CommonBlockDetails &details) {
+ for (const Symbol *object : details.objects()) {
+ DoSymbol(*object);
+ }
+ },
+ [](const auto &) {},
+ },
+ symbol.details());
+ if (!symbol.has<UseDetails>()) {
+ DoType(symbol.GetType());
+ }
+ if (scope.kind() != Scope::Kind::DerivedType) {
+ need_.push_back(&symbol);
+ }
+}
+
+void SubprogramSymbolCollector::DoType(const DeclTypeSpec *type) {
+ if (!type) {
+ return;
+ }
+ switch (type->category()) {
+ case DeclTypeSpec::Numeric:
+ case DeclTypeSpec::Logical: break; // nothing to do
+ case DeclTypeSpec::Character:
+ DoParamValue(type->characterTypeSpec().length());
+ break;
+ default:
+ if (const DerivedTypeSpec * derived{type->AsDerived()}) {
+ const auto &typeSymbol{derived->typeSymbol()};
+ if (const DerivedTypeSpec * extends{typeSymbol.GetParentTypeSpec()}) {
+ DoSymbol(extends->typeSymbol());
+ }
+ for (const auto pair : derived->parameters()) {
+ DoParamValue(pair.second);
+ }
+ for (const auto pair : *typeSymbol.scope()) {
+ const auto &comp{*pair.second};
+ DoSymbol(comp);
+ }
+ DoSymbol(typeSymbol);
+ }
+ }
+}
+
+void SubprogramSymbolCollector::DoBound(const Bound &bound) {
+ if (const MaybeSubscriptIntExpr & expr{bound.GetExplicit()}) {
+ DoExpr(*expr);
+ }
+}
+void SubprogramSymbolCollector::DoParamValue(const ParamValue ¶mValue) {
+ if (const auto &expr{paramValue.GetExplicit()}) {
+ DoExpr(*expr);
+ }
+}
+
}