53f72b4a99a012b580eae347853b61ec1d440ef9
[platform/upstream/llvm.git] / flang / lib / Semantics / mod-file.cpp
1 //===-- lib/Semantics/mod-file.cpp ----------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mod-file.h"
10 #include "resolve-names.h"
11 #include "flang/Evaluate/tools.h"
12 #include "flang/Parser/message.h"
13 #include "flang/Parser/parsing.h"
14 #include "flang/Semantics/scope.h"
15 #include "flang/Semantics/semantics.h"
16 #include "flang/Semantics/symbol.h"
17 #include "flang/Semantics/tools.h"
18 #include <algorithm>
19 #include <cerrno>
20 #include <fstream>
21 #include <ostream>
22 #include <set>
23 #include <string_view>
24 #include <sys/file.h>
25 #include <sys/stat.h>
26 #include <sys/types.h>
27 #include <unistd.h>
28 #include <vector>
29
30 namespace Fortran::semantics {
31
32 using namespace parser::literals;
33
34 // The first line of a file that identifies it as a .mod file.
35 // The first three bytes are a Unicode byte order mark that ensures
36 // that the module file is decoded as UTF-8 even if source files
37 // are using another encoding.
38 struct ModHeader {
39   static constexpr const char bom[3 + 1]{"\xef\xbb\xbf"};
40   static constexpr int magicLen{13};
41   static constexpr int sumLen{16};
42   static constexpr const char magic[magicLen + 1]{"!mod$ v1 sum:"};
43   static constexpr char terminator{'\n'};
44   static constexpr int len{magicLen + 1 + sumLen};
45 };
46
47 static std::optional<SourceName> GetSubmoduleParent(const parser::Program &);
48 static SymbolVector CollectSymbols(const Scope &);
49 static void PutEntity(std::ostream &, const Symbol &);
50 static void PutObjectEntity(std::ostream &, const Symbol &);
51 static void PutProcEntity(std::ostream &, const Symbol &);
52 static void PutPassName(std::ostream &, const std::optional<SourceName> &);
53 static void PutTypeParam(std::ostream &, const Symbol &);
54 static void PutEntity(
55     std::ostream &, const Symbol &, std::function<void()>, Attrs);
56 static void PutInit(std::ostream &, const Symbol &, const MaybeExpr &);
57 static void PutInit(std::ostream &, const MaybeIntExpr &);
58 static void PutBound(std::ostream &, const Bound &);
59 static std::ostream &PutAttrs(std::ostream &, Attrs,
60     const MaybeExpr & = std::nullopt, std::string before = ","s,
61     std::string after = ""s);
62 static std::ostream &PutAttr(std::ostream &, Attr);
63 static std::ostream &PutType(std::ostream &, const DeclTypeSpec &);
64 static std::ostream &PutLower(std::ostream &, const std::string &);
65 static int WriteFile(const std::string &, const std::string &);
66 static bool FileContentsMatch(
67     const std::string &, const std::string &, const std::string &);
68 static std::size_t GetFileSize(const std::string &);
69 static std::string CheckSum(const std::string_view &);
70
71 // Collect symbols needed for a subprogram interface
72 class SubprogramSymbolCollector {
73 public:
74   SubprogramSymbolCollector(const Symbol &symbol)
75     : symbol_{symbol}, scope_{DEREF(symbol.scope())} {}
76   const SymbolVector &symbols() const { return need_; }
77   const std::set<SourceName> &imports() const { return imports_; }
78   void Collect();
79
80 private:
81   const Symbol &symbol_;
82   const Scope &scope_;
83   bool isInterface_{false};
84   SymbolVector need_;  // symbols that are needed
85   SymbolSet needSet_;  // symbols already in need_
86   SymbolSet useSet_;  // use-associations that might be needed
87   std::set<SourceName> imports_;  // imports from host that are needed
88
89   void DoSymbol(const Symbol &);
90   void DoSymbol(const SourceName &, const Symbol &);
91   void DoType(const DeclTypeSpec *);
92   void DoBound(const Bound &);
93   void DoParamValue(const ParamValue &);
94   bool NeedImport(const SourceName &, const Symbol &);
95
96   template<typename T> void DoExpr(evaluate::Expr<T> expr) {
97     for (const Symbol &symbol : evaluate::CollectSymbols(expr)) {
98       DoSymbol(symbol);
99     }
100   }
101 };
102
103 bool ModFileWriter::WriteAll() {
104   WriteAll(context_.globalScope());
105   return !context_.AnyFatalError();
106 }
107
108 void ModFileWriter::WriteAll(const Scope &scope) {
109   for (const auto &child : scope.children()) {
110     WriteOne(child);
111   }
112 }
113
114 void ModFileWriter::WriteOne(const Scope &scope) {
115   if (scope.kind() == Scope::Kind::Module) {
116     auto *symbol{scope.symbol()};
117     if (!symbol->test(Symbol::Flag::ModFile)) {
118       Write(*symbol);
119     }
120     WriteAll(scope);  // write out submodules
121   }
122 }
123
124 // Construct the name of a module file. Non-empty ancestorName means submodule.
125 static std::string ModFileName(const SourceName &name,
126     const std::string &ancestorName, const std::string &suffix) {
127   std::string result{name.ToString() + suffix};
128   return ancestorName.empty() ? result : ancestorName + '-' + result;
129 }
130
131 // Write the module file for symbol, which must be a module or submodule.
132 void ModFileWriter::Write(const Symbol &symbol) {
133   auto *ancestor{symbol.get<ModuleDetails>().ancestor()};
134   auto ancestorName{ancestor ? ancestor->GetName().value().ToString() : ""s};
135   auto path{context_.moduleDirectory() + '/' +
136       ModFileName(symbol.name(), ancestorName, context_.moduleFileSuffix())};
137   PutSymbols(DEREF(symbol.scope()));
138   if (int error{WriteFile(path, GetAsString(symbol))}) {
139     context_.Say(symbol.name(), "Error writing %s: %s"_err_en_US, path,
140         std::strerror(error));
141   }
142 }
143
144 // Return the entire body of the module file
145 // and clear saved uses, decls, and contains.
146 std::string ModFileWriter::GetAsString(const Symbol &symbol) {
147   std::stringstream all;
148   auto &details{symbol.get<ModuleDetails>()};
149   if (!details.isSubmodule()) {
150     all << "module " << symbol.name();
151   } else {
152     auto *parent{details.parent()->symbol()};
153     auto *ancestor{details.ancestor()->symbol()};
154     all << "submodule(" << ancestor->name();
155     if (parent != ancestor) {
156       all << ':' << parent->name();
157     }
158     all << ") " << symbol.name();
159   }
160   all << '\n' << uses_.str();
161   uses_.str(""s);
162   all << useExtraAttrs_.str();
163   useExtraAttrs_.str(""s);
164   all << decls_.str();
165   decls_.str(""s);
166   auto str{contains_.str()};
167   contains_.str(""s);
168   if (!str.empty()) {
169     all << "contains\n" << str;
170   }
171   all << "end\n";
172   return all.str();
173 }
174
175 // Put out the visible symbols from scope.
176 void ModFileWriter::PutSymbols(const Scope &scope) {
177   std::stringstream typeBindings;  // stuff after CONTAINS in derived type
178   for (const Symbol &symbol : CollectSymbols(scope)) {
179     PutSymbol(typeBindings, symbol);
180   }
181   if (auto str{typeBindings.str()}; !str.empty()) {
182     CHECK(scope.IsDerivedType());
183     decls_ << "contains\n" << str;
184   }
185 }
186
187 // Emit a symbol to decls_, except for bindings in a derived type (type-bound
188 // procedures, type-bound generics, final procedures) which go to typeBindings.
189 void ModFileWriter::PutSymbol(
190     std::stringstream &typeBindings, const Symbol &symbol) {
191   std::visit(
192       common::visitors{
193           [&](const ModuleDetails &) { /* should be current module */ },
194           [&](const DerivedTypeDetails &) { PutDerivedType(symbol); },
195           [&](const SubprogramDetails &) { PutSubprogram(symbol); },
196           [&](const GenericDetails &x) {
197             if (symbol.owner().IsDerivedType()) {
198               // generic binding
199               for (const Symbol &proc : x.specificProcs()) {
200                 typeBindings << "generic::" << symbol.name() << "=>"
201                              << proc.name() << '\n';
202               }
203             } else {
204               PutGeneric(symbol);
205               if (x.specific()) {
206                 PutSymbol(typeBindings, *x.specific());
207               }
208               if (x.derivedType()) {
209                 PutSymbol(typeBindings, *x.derivedType());
210               }
211             }
212           },
213           [&](const UseDetails &) { PutUse(symbol); },
214           [](const UseErrorDetails &) {},
215           [&](const ProcBindingDetails &x) {
216             bool deferred{symbol.attrs().test(Attr::DEFERRED)};
217             typeBindings << "procedure";
218             if (deferred) {
219               typeBindings << '(' << x.symbol().name() << ')';
220             }
221             PutPassName(typeBindings, x.passName());
222             auto attrs{symbol.attrs()};
223             if (x.passName()) {
224               attrs.reset(Attr::PASS);
225             }
226             PutAttrs(typeBindings, attrs);
227             typeBindings << "::" << symbol.name();
228             if (!deferred && x.symbol().name() != symbol.name()) {
229               typeBindings << "=>" << x.symbol().name();
230             }
231             typeBindings << '\n';
232           },
233           [&](const NamelistDetails &x) {
234             decls_ << "namelist/" << symbol.name();
235             char sep{'/'};
236             for (const Symbol &object : x.objects()) {
237               decls_ << sep << object.name();
238               sep = ',';
239             }
240             decls_ << '\n';
241           },
242           [&](const CommonBlockDetails &x) {
243             decls_ << "common/" << symbol.name();
244             char sep = '/';
245             for (const Symbol &object : x.objects()) {
246               decls_ << sep << object.name();
247               sep = ',';
248             }
249             decls_ << '\n';
250             if (symbol.attrs().test(Attr::BIND_C)) {
251               PutAttrs(decls_, symbol.attrs(), x.bindName(), ""s);
252               decls_ << "::/" << symbol.name() << "/\n";
253             }
254           },
255           [&](const FinalProcDetails &) {
256             typeBindings << "final::" << symbol.name() << '\n';
257           },
258           [](const HostAssocDetails &) {},
259           [](const MiscDetails &) {},
260           [&](const auto &) { PutEntity(decls_, symbol); },
261       },
262       symbol.details());
263 }
264
265 void ModFileWriter::PutDerivedType(const Symbol &typeSymbol) {
266   auto &details{typeSymbol.get<DerivedTypeDetails>()};
267   PutAttrs(decls_ << "type", typeSymbol.attrs());
268   if (const DerivedTypeSpec * extends{typeSymbol.GetParentTypeSpec()}) {
269     decls_ << ",extends(" << extends->name() << ')';
270   }
271   decls_ << "::" << typeSymbol.name();
272   auto &typeScope{*typeSymbol.scope()};
273   if (!details.paramNames().empty()) {
274     char sep{'('};
275     for (const auto &name : details.paramNames()) {
276       decls_ << sep << name;
277       sep = ',';
278     }
279     decls_ << ')';
280   }
281   decls_ << '\n';
282   if (details.sequence()) {
283     decls_ << "sequence\n";
284   }
285   PutSymbols(typeScope);
286   decls_ << "end type\n";
287 }
288
289 // Attributes that may be in a subprogram prefix
290 static const Attrs subprogramPrefixAttrs{Attr::ELEMENTAL, Attr::IMPURE,
291     Attr::MODULE, Attr::NON_RECURSIVE, Attr::PURE, Attr::RECURSIVE};
292
293 void ModFileWriter::PutSubprogram(const Symbol &symbol) {
294   auto attrs{symbol.attrs()};
295   auto &details{symbol.get<SubprogramDetails>()};
296   Attrs bindAttrs{};
297   if (attrs.test(Attr::BIND_C)) {
298     // bind(c) is a suffix, not prefix
299     bindAttrs.set(Attr::BIND_C, true);
300     attrs.set(Attr::BIND_C, false);
301   }
302   Attrs prefixAttrs{subprogramPrefixAttrs & attrs};
303   // emit any non-prefix attributes in an attribute statement
304   attrs &= ~subprogramPrefixAttrs;
305   std::stringstream ss;
306   PutAttrs(ss, attrs);
307   if (!ss.str().empty()) {
308     decls_ << ss.str().substr(1) << "::" << symbol.name() << '\n';
309   }
310   bool isInterface{details.isInterface()};
311   std::ostream &os{isInterface ? decls_ : contains_};
312   if (isInterface) {
313     os << "interface\n";
314   }
315   PutAttrs(os, prefixAttrs, std::nullopt, ""s, " "s);
316   os << (details.isFunction() ? "function " : "subroutine ");
317   os << symbol.name() << '(';
318   int n = 0;
319   for (const auto &dummy : details.dummyArgs()) {
320     if (n++ > 0) {
321       os << ',';
322     }
323     os << dummy->name();
324   }
325   os << ')';
326   PutAttrs(os, bindAttrs, details.bindName(), " "s, ""s);
327   if (details.isFunction()) {
328     const Symbol &result{details.result()};
329     if (result.name() != symbol.name()) {
330       os << " result(" << result.name() << ')';
331     }
332   }
333   os << '\n';
334
335   // walk symbols, collect ones needed
336   ModFileWriter writer{context_};
337   std::stringstream typeBindings;
338   SubprogramSymbolCollector collector{symbol};
339   collector.Collect();
340   for (const Symbol &need : collector.symbols()) {
341     writer.PutSymbol(typeBindings, need);
342   }
343   CHECK(typeBindings.str().empty());
344   os << writer.uses_.str();
345   for (const SourceName &import : collector.imports()) {
346     decls_ << "import::" << import << "\n";
347   }
348   os << writer.decls_.str();
349   os << "end\n";
350   if (isInterface) {
351     os << "end interface\n";
352   }
353 }
354
355 static bool IsIntrinsicOp(const Symbol &symbol) {
356   if (const auto *details{symbol.GetUltimate().detailsIf<GenericDetails>()}) {
357     return details->kind().IsIntrinsicOperator();
358   } else {
359     return false;
360   }
361 }
362
363 static std::ostream &PutGenericName(std::ostream &os, const Symbol &symbol) {
364   if (IsGenericDefinedOp(symbol)) {
365     return os << "operator(" << symbol.name() << ')';
366   } else {
367     return os << symbol.name();
368   }
369 }
370
371 void ModFileWriter::PutGeneric(const Symbol &symbol) {
372   auto &details{symbol.get<GenericDetails>()};
373   PutGenericName(decls_ << "interface ", symbol) << '\n';
374   for (const Symbol &specific : details.specificProcs()) {
375     decls_ << "procedure::" << specific.name() << '\n';
376   }
377   decls_ << "end interface\n";
378   if (symbol.attrs().test(Attr::PRIVATE)) {
379     PutGenericName(decls_ << "private::", symbol) << '\n';
380   }
381 }
382
383 void ModFileWriter::PutUse(const Symbol &symbol) {
384   auto &details{symbol.get<UseDetails>()};
385   auto &use{details.symbol()};
386   uses_ << "use " << details.module().name();
387   PutGenericName(uses_ << ",only:", symbol);
388   // Can have intrinsic op with different local-name and use-name
389   // (e.g. `operator(<)` and `operator(.lt.)`) but rename is not allowed
390   if (!IsIntrinsicOp(symbol) && use.name() != symbol.name()) {
391     PutGenericName(uses_ << "=>", use);
392   }
393   uses_ << '\n';
394   PutUseExtraAttr(Attr::VOLATILE, symbol, use);
395   PutUseExtraAttr(Attr::ASYNCHRONOUS, symbol, use);
396 }
397
398 // We have "USE local => use" in this module. If attr was added locally
399 // (i.e. on local but not on use), also write it out in the mod file.
400 void ModFileWriter::PutUseExtraAttr(
401     Attr attr, const Symbol &local, const Symbol &use) {
402   if (local.attrs().test(attr) && !use.attrs().test(attr)) {
403     PutAttr(useExtraAttrs_, attr) << "::";
404     useExtraAttrs_ << local.name() << '\n';
405   }
406 }
407
408 // Collect the symbols of this scope sorted by their original order, not name.
409 // Namelists are an exception: they are sorted after other symbols.
410 SymbolVector CollectSymbols(const Scope &scope) {
411   SymbolSet symbols;  // to prevent duplicates
412   SymbolVector sorted;
413   SymbolVector namelist;
414   SymbolVector common;
415   sorted.reserve(scope.size() + scope.commonBlocks().size());
416   for (const auto &pair : scope) {
417     const Symbol &symbol{*pair.second};
418     if (!symbol.test(Symbol::Flag::ParentComp)) {
419       if (symbols.insert(symbol).second) {
420         if (symbol.has<NamelistDetails>()) {
421           namelist.push_back(symbol);
422         } else {
423           sorted.push_back(symbol);
424         }
425       }
426     }
427   }
428   for (const auto &pair : scope.commonBlocks()) {
429     const Symbol &symbol{*pair.second};
430     if (symbols.insert(symbol).second) {
431       common.push_back(symbol);
432     }
433   }
434   // sort normal symbols, then namelists, then common blocks:
435   auto cursor{sorted.begin()};
436   std::sort(cursor, sorted.end());
437   cursor = sorted.insert(sorted.end(), namelist.begin(), namelist.end());
438   std::sort(cursor, sorted.end());
439   cursor = sorted.insert(sorted.end(), common.begin(), common.end());
440   std::sort(cursor, sorted.end());
441   return sorted;
442 }
443
444 void PutEntity(std::ostream &os, const Symbol &symbol) {
445   std::visit(
446       common::visitors{
447           [&](const ObjectEntityDetails &) { PutObjectEntity(os, symbol); },
448           [&](const ProcEntityDetails &) { PutProcEntity(os, symbol); },
449           [&](const TypeParamDetails &) { PutTypeParam(os, symbol); },
450           [&](const auto &) {
451             common::die("PutEntity: unexpected details: %s",
452                 DetailsToString(symbol.details()).c_str());
453           },
454       },
455       symbol.details());
456 }
457
458 void PutShapeSpec(std::ostream &os, const ShapeSpec &x) {
459   if (x.lbound().isAssumed()) {
460     CHECK(x.ubound().isAssumed());
461     os << "..";
462   } else {
463     if (!x.lbound().isDeferred()) {
464       PutBound(os, x.lbound());
465     }
466     os << ':';
467     if (!x.ubound().isDeferred()) {
468       PutBound(os, x.ubound());
469     }
470   }
471 }
472 void PutShape(std::ostream &os, const ArraySpec &shape, char open, char close) {
473   if (!shape.empty()) {
474     os << open;
475     bool first{true};
476     for (const auto &shapeSpec : shape) {
477       if (first) {
478         first = false;
479       } else {
480         os << ',';
481       }
482       PutShapeSpec(os, shapeSpec);
483     }
484     os << close;
485   }
486 }
487
488 void PutObjectEntity(std::ostream &os, const Symbol &symbol) {
489   auto &details{symbol.get<ObjectEntityDetails>()};
490   PutEntity(os, symbol, [&]() { PutType(os, DEREF(symbol.GetType())); },
491       symbol.attrs());
492   PutShape(os, details.shape(), '(', ')');
493   PutShape(os, details.coshape(), '[', ']');
494   PutInit(os, symbol, details.init());
495   os << '\n';
496 }
497
498 void PutProcEntity(std::ostream &os, const Symbol &symbol) {
499   if (symbol.attrs().test(Attr::INTRINSIC)) {
500     os << "intrinsic::" << symbol.name() << '\n';
501     return;
502   }
503   const auto &details{symbol.get<ProcEntityDetails>()};
504   const ProcInterface &interface{details.interface()};
505   Attrs attrs{symbol.attrs()};
506   if (details.passName()) {
507     attrs.reset(Attr::PASS);
508   }
509   PutEntity(os, symbol,
510       [&]() {
511         os << "procedure(";
512         if (interface.symbol()) {
513           os << interface.symbol()->name();
514         } else if (interface.type()) {
515           PutType(os, *interface.type());
516         }
517         os << ')';
518         PutPassName(os, details.passName());
519       },
520       attrs);
521   os << '\n';
522 }
523
524 void PutPassName(std::ostream &os, const std::optional<SourceName> &passName) {
525   if (passName) {
526     os << ",pass(" << *passName << ')';
527   }
528 }
529
530 void PutTypeParam(std::ostream &os, const Symbol &symbol) {
531   auto &details{symbol.get<TypeParamDetails>()};
532   PutEntity(os, symbol,
533       [&]() {
534         PutType(os, DEREF(symbol.GetType()));
535         PutLower(os << ',', common::EnumToString(details.attr()));
536       },
537       symbol.attrs());
538   PutInit(os, details.init());
539   os << '\n';
540 }
541
542 void PutInit(std::ostream &os, const Symbol &symbol, const MaybeExpr &init) {
543   if (init) {
544     if (symbol.attrs().test(Attr::PARAMETER) ||
545         symbol.owner().IsDerivedType()) {
546       os << (symbol.attrs().test(Attr::POINTER) ? "=>" : "=");
547       init->AsFortran(os);
548     }
549   }
550 }
551
552 void PutInit(std::ostream &os, const MaybeIntExpr &init) {
553   if (init) {
554     init->AsFortran(os << '=');
555   }
556 }
557
558 void PutBound(std::ostream &os, const Bound &x) {
559   if (x.isAssumed()) {
560     os << '*';
561   } else if (x.isDeferred()) {
562     os << ':';
563   } else {
564     x.GetExplicit()->AsFortran(os);
565   }
566 }
567
568 // Write an entity (object or procedure) declaration.
569 // writeType is called to write out the type.
570 void PutEntity(std::ostream &os, const Symbol &symbol,
571     std::function<void()> writeType, Attrs attrs) {
572   writeType();
573   MaybeExpr bindName;
574   std::visit(
575       common::visitors{
576           [&](const SubprogramDetails &x) { bindName = x.bindName(); },
577           [&](const ObjectEntityDetails &x) { bindName = x.bindName(); },
578           [&](const ProcEntityDetails &x) { bindName = x.bindName(); },
579           [&](const auto &) {},
580       },
581       symbol.details());
582   PutAttrs(os, attrs, bindName);
583   os << "::" << symbol.name();
584 }
585
586 // Put out each attribute to os, surrounded by `before` and `after` and
587 // mapped to lower case.
588 std::ostream &PutAttrs(std::ostream &os, Attrs attrs, const MaybeExpr &bindName,
589     std::string before, std::string after) {
590   attrs.set(Attr::PUBLIC, false);  // no need to write PUBLIC
591   attrs.set(Attr::EXTERNAL, false);  // no need to write EXTERNAL
592   if (bindName) {
593     bindName->AsFortran(os << before << "bind(c, name=") << ')' << after;
594     attrs.set(Attr::BIND_C, false);
595   }
596   for (std::size_t i{0}; i < Attr_enumSize; ++i) {
597     Attr attr{static_cast<Attr>(i)};
598     if (attrs.test(attr)) {
599       PutAttr(os << before, attr) << after;
600     }
601   }
602   return os;
603 }
604
605 std::ostream &PutAttr(std::ostream &os, Attr attr) {
606   return PutLower(os, AttrToString(attr));
607 }
608
609 std::ostream &PutType(std::ostream &os, const DeclTypeSpec &type) {
610   return PutLower(os, type.AsFortran());
611 }
612
613 std::ostream &PutLower(std::ostream &os, const std::string &str) {
614   for (char c : str) {
615     os << parser::ToLowerCaseLetter(c);
616   }
617   return os;
618 }
619
620 struct Temp {
621   Temp() = delete;
622   ~Temp() {
623     close(fd);
624     unlink(path.c_str());
625   }
626   int fd;
627   std::string path;
628 };
629
630 // Create a temp file in the same directory and with the same suffix as path.
631 // Return an open file descriptor and its path.
632 static Temp MkTemp(const std::string &path) {
633   auto length{path.length()};
634   auto dot{path.find_last_of("./")};
635   std::string suffix{dot < length && path[dot] == '.' ? path.substr(dot) : ""};
636   CHECK(length > suffix.length() &&
637       path.substr(length - suffix.length()) == suffix);
638   auto tempPath{path.substr(0, length - suffix.length()) + "XXXXXX" + suffix};
639   int fd{mkstemps(&tempPath[0], suffix.length())};
640   auto mask{umask(0777)};
641   umask(mask);
642   chmod(tempPath.c_str(), 0666 & ~mask);  // temp is created with mode 0600
643   return Temp{fd, tempPath};
644 }
645
646 // Write the module file at path, prepending header. If an error occurs,
647 // return errno, otherwise 0.
648 static int WriteFile(const std::string &path, const std::string &contents) {
649   auto header{std::string{ModHeader::bom} + ModHeader::magic +
650       CheckSum(contents) + ModHeader::terminator};
651   if (FileContentsMatch(path, header, contents)) {
652     return 0;
653   }
654   Temp temp{MkTemp(path)};
655   if (temp.fd < 0) {
656     return errno;
657   }
658   if (write(temp.fd, header.c_str(), header.size()) !=
659           static_cast<ssize_t>(header.size()) ||
660       write(temp.fd, contents.c_str(), contents.size()) !=
661           static_cast<ssize_t>(contents.size())) {
662     return errno;
663   }
664   if (std::rename(temp.path.c_str(), path.c_str()) == -1) {
665     return errno;
666   }
667   return 0;
668 }
669
670 // Return true if the stream matches what we would write for the mod file.
671 static bool FileContentsMatch(const std::string &path,
672     const std::string &header, const std::string &contents) {
673   std::size_t hsize{header.size()};
674   std::size_t csize{contents.size()};
675   if (GetFileSize(path) != hsize + csize) {
676     return false;
677   }
678   int fd{open(path.c_str(), O_RDONLY)};
679   if (fd < 0) {
680     return false;
681   }
682   constexpr std::size_t bufSize{4096};
683   std::string buffer(bufSize, '\0');
684   if (read(fd, &buffer[0], hsize) != static_cast<ssize_t>(hsize) ||
685       std::memcmp(&buffer[0], &header[0], hsize) != 0) {
686     close(fd);
687     return false;  // header doesn't match
688   }
689   for (auto remaining{csize};;) {
690     auto bytes{std::min(bufSize, remaining)};
691     auto got{read(fd, &buffer[0], bytes)};
692     if (got != static_cast<ssize_t>(bytes) ||
693         std::memcmp(&buffer[0], &contents[csize - remaining], bytes) != 0) {
694       close(fd);
695       return false;
696     }
697     if (bytes == 0 && remaining == 0) {
698       close(fd);
699       return true;
700     }
701     remaining -= bytes;
702   }
703 }
704
705 // Compute a simple hash of the contents of a module file and
706 // return it as a string of hex digits.
707 // This uses the Fowler-Noll-Vo hash function.
708 static std::string CheckSum(const std::string_view &contents) {
709   std::uint64_t hash{0xcbf29ce484222325ull};
710   for (char c : contents) {
711     hash ^= c & 0xff;
712     hash *= 0x100000001b3;
713   }
714   static const char *digits = "0123456789abcdef";
715   std::string result(ModHeader::sumLen, '0');
716   for (size_t i{ModHeader::sumLen}; hash != 0; hash >>= 4) {
717     result[--i] = digits[hash & 0xf];
718   }
719   return result;
720 }
721
722 static bool VerifyHeader(const char *content, std::size_t len) {
723   std::string_view sv{content, len};
724   if (sv.substr(0, ModHeader::magicLen) != ModHeader::magic) {
725     return false;
726   }
727   std::string_view expectSum{sv.substr(ModHeader::magicLen, ModHeader::sumLen)};
728   std::string actualSum{CheckSum(sv.substr(ModHeader::len))};
729   return expectSum == actualSum;
730 }
731
732 static std::size_t GetFileSize(const std::string &path) {
733   struct stat statbuf;
734   if (stat(path.c_str(), &statbuf) == 0) {
735     return static_cast<std::size_t>(statbuf.st_size);
736   } else {
737     return 0;
738   }
739 }
740
741 Scope *ModFileReader::Read(const SourceName &name, Scope *ancestor) {
742   std::string ancestorName;  // empty for module
743   if (ancestor) {
744     if (auto *scope{ancestor->FindSubmodule(name)}) {
745       return scope;
746     }
747     ancestorName = ancestor->GetName().value().ToString();
748   } else {
749     auto it{context_.globalScope().find(name)};
750     if (it != context_.globalScope().end()) {
751       return it->second->scope();
752     }
753   }
754   parser::Parsing parsing{context_.allSources()};
755   parser::Options options;
756   options.isModuleFile = true;
757   options.features.Enable(common::LanguageFeature::BackslashEscapes);
758   options.searchDirectories = context_.searchDirectories();
759   auto path{ModFileName(name, ancestorName, context_.moduleFileSuffix())};
760   const auto *sourceFile{parsing.Prescan(path, options)};
761   if (parsing.messages().AnyFatalError()) {
762     for (auto &msg : parsing.messages().messages()) {
763       std::string str{msg.ToString()};
764       Say(name, ancestorName, parser::MessageFixedText{str.c_str(), str.size()},
765           path);
766     }
767     return nullptr;
768   }
769   CHECK(sourceFile);
770   if (!VerifyHeader(sourceFile->content(), sourceFile->bytes())) {
771     Say(name, ancestorName, "File has invalid checksum: %s"_en_US,
772         sourceFile->path());
773     return nullptr;
774   }
775
776   parsing.Parse(nullptr);
777   auto &parseTree{parsing.parseTree()};
778   if (!parsing.messages().empty() || !parsing.consumedWholeFile() ||
779       !parseTree) {
780     Say(name, ancestorName, "Module file is corrupt: %s"_err_en_US,
781         sourceFile->path());
782     return nullptr;
783   }
784   Scope *parentScope;  // the scope this module/submodule goes into
785   if (!ancestor) {
786     parentScope = &context_.globalScope();
787   } else if (std::optional<SourceName> parent{GetSubmoduleParent(*parseTree)}) {
788     parentScope = Read(*parent, ancestor);
789   } else {
790     parentScope = ancestor;
791   }
792   ResolveNames(context_, *parseTree);
793   const auto &it{parentScope->find(name)};
794   if (it == parentScope->end()) {
795     return nullptr;
796   }
797   auto &modSymbol{*it->second};
798   modSymbol.set(Symbol::Flag::ModFile);
799   modSymbol.scope()->set_chars(parsing.cooked());
800   return modSymbol.scope();
801 }
802
803 parser::Message &ModFileReader::Say(const SourceName &name,
804     const std::string &ancestor, parser::MessageFixedText &&msg,
805     const std::string &arg) {
806   return context_
807       .Say(name,
808           ancestor.empty()
809               ? "Error reading module file for module '%s'"_err_en_US
810               : "Error reading module file for submodule '%s' of module '%s'"_err_en_US,
811           name, ancestor)
812       .Attach(name, std::move(msg), arg);
813 }
814
815 // program was read from a .mod file for a submodule; return the name of the
816 // submodule's parent submodule, nullptr if none.
817 static std::optional<SourceName> GetSubmoduleParent(
818     const parser::Program &program) {
819   CHECK(program.v.size() == 1);
820   auto &unit{program.v.front()};
821   auto &submod{std::get<common::Indirection<parser::Submodule>>(unit.u)};
822   auto &stmt{
823       std::get<parser::Statement<parser::SubmoduleStmt>>(submod.value().t)};
824   auto &parentId{std::get<parser::ParentIdentifier>(stmt.statement.t)};
825   if (auto &parent{std::get<std::optional<parser::Name>>(parentId.t)}) {
826     return parent->source;
827   } else {
828     return std::nullopt;
829   }
830 }
831
832 void SubprogramSymbolCollector::Collect() {
833   const auto &details{symbol_.get<SubprogramDetails>()};
834   isInterface_ = details.isInterface();
835   for (const Symbol *dummyArg : details.dummyArgs()) {
836     DoSymbol(DEREF(dummyArg));
837   }
838   if (details.isFunction()) {
839     DoSymbol(details.result());
840   }
841   for (const auto &pair : scope_) {
842     const Symbol &symbol{*pair.second};
843     if (const auto *useDetails{symbol.detailsIf<UseDetails>()}) {
844       if (useSet_.count(useDetails->symbol()) > 0) {
845         need_.push_back(symbol);
846       }
847     }
848   }
849 }
850
851 void SubprogramSymbolCollector::DoSymbol(const Symbol &symbol) {
852   DoSymbol(symbol.name(), symbol);
853 }
854
855 // Do symbols this one depends on; then add to need_
856 void SubprogramSymbolCollector::DoSymbol(
857     const SourceName &name, const Symbol &symbol) {
858   const auto &scope{symbol.owner()};
859   if (scope != scope_ && !scope.IsDerivedType()) {
860     if (scope != scope_.parent()) {
861       useSet_.insert(symbol);
862     }
863     if (NeedImport(name, symbol)) {
864       imports_.insert(name);
865     }
866     return;
867   }
868   if (!needSet_.insert(symbol).second) {
869     return;  // already done
870   }
871   std::visit(
872       common::visitors{
873           [this](const ObjectEntityDetails &details) {
874             for (const ShapeSpec &spec : details.shape()) {
875               DoBound(spec.lbound());
876               DoBound(spec.ubound());
877             }
878             for (const ShapeSpec &spec : details.coshape()) {
879               DoBound(spec.lbound());
880               DoBound(spec.ubound());
881             }
882             if (const Symbol * commonBlock{details.commonBlock()}) {
883               DoSymbol(*commonBlock);
884             }
885           },
886           [this](const CommonBlockDetails &details) {
887             for (const Symbol &object : details.objects()) {
888               DoSymbol(object);
889             }
890           },
891           [](const auto &) {},
892       },
893       symbol.details());
894   if (!symbol.has<UseDetails>()) {
895     DoType(symbol.GetType());
896   }
897   if (!scope.IsDerivedType()) {
898     need_.push_back(symbol);
899   }
900 }
901
902 void SubprogramSymbolCollector::DoType(const DeclTypeSpec *type) {
903   if (!type) {
904     return;
905   }
906   switch (type->category()) {
907   case DeclTypeSpec::Numeric:
908   case DeclTypeSpec::Logical: break;  // nothing to do
909   case DeclTypeSpec::Character:
910     DoParamValue(type->characterTypeSpec().length());
911     break;
912   default:
913     if (const DerivedTypeSpec * derived{type->AsDerived()}) {
914       const auto &typeSymbol{derived->typeSymbol()};
915       if (const DerivedTypeSpec * extends{typeSymbol.GetParentTypeSpec()}) {
916         DoSymbol(extends->name(), extends->typeSymbol());
917       }
918       for (const auto &pair : derived->parameters()) {
919         DoParamValue(pair.second);
920       }
921       for (const auto &pair : *typeSymbol.scope()) {
922         const Symbol &comp{*pair.second};
923         DoSymbol(comp);
924       }
925       DoSymbol(derived->name(), derived->typeSymbol());
926     }
927   }
928 }
929
930 void SubprogramSymbolCollector::DoBound(const Bound &bound) {
931   if (const MaybeSubscriptIntExpr & expr{bound.GetExplicit()}) {
932     DoExpr(*expr);
933   }
934 }
935 void SubprogramSymbolCollector::DoParamValue(const ParamValue &paramValue) {
936   if (const auto &expr{paramValue.GetExplicit()}) {
937     DoExpr(*expr);
938   }
939 }
940
941 // Do we need a IMPORT of this symbol into an interface block?
942 bool SubprogramSymbolCollector::NeedImport(
943     const SourceName &name, const Symbol &symbol) {
944   if (!isInterface_) {
945     return false;
946   } else if (symbol.owner() != scope_.parent()) {
947     // detect import from parent of use-associated symbol
948     // can be null in the case of a use-associated derived type's parent type
949     const auto *found{scope_.FindSymbol(name)};
950     CHECK(found || symbol.has<DerivedTypeDetails>());
951     return found && found->has<UseDetails>() && found->owner() != scope_;
952   } else {
953     return true;
954   }
955 }
956
957 }