[flang] Add a way to check and dereference a pointer
authorTim Keith <tkeith@nvidia.com>
Mon, 29 Jul 2019 16:12:52 +0000 (09:12 -0700)
committerTim Keith <tkeith@nvidia.com>
Mon, 29 Jul 2019 16:12:52 +0000 (09:12 -0700)
It is common to get a pointer, check it is not null, and dereference it.
Sometimes that requires a named temporary just to be able to do the check.

The macro `DEREF(p)` provides this capability: it asserts that `p` is not null
and returns `*p`. This is analagous to `.value()` on an `std::optional`.

We might want to add a way to disable `CHECK` and the check in `DEREF` together.

This change also includes some examples of making use of `DEREF`.

Original-commit: flang-compiler/f18@d7aa90e55ac80c7f2460ab7b0cb6d1ef0c068938
Reviewed-on: https://github.com/flang-compiler/f18/pull/608

flang/lib/common/idioms.h
flang/lib/semantics/assignment.cc
flang/lib/semantics/check-allocate.cc
flang/lib/semantics/expression.cc
flang/lib/semantics/mod-file.cc
flang/lib/semantics/resolve-names.cc
flang/lib/semantics/scope.cc
flang/lib/semantics/scope.h
flang/lib/semantics/unparse-with-symbols.cc

index 8ad2ac3..619bbd9 100644 (file)
@@ -130,6 +130,16 @@ template<typename A> struct ListItemCount {
         static_cast<int>(e), #__VA_ARGS__); \
   }
 
+// Check that a pointer is non-null and dereference it
+#define DEREF(p) Fortran::common::Deref(p, __FILE__, __LINE__)
+
+template<typename T> T &Deref(T *p, const char *file, int line) {
+  if (p == nullptr) {
+    Fortran::common::die("nullptr dereference at %s(%d)", file, line);
+  }
+  return *p;
+}
+
 // Given a const reference to a value, return a copy of the value.
 template<typename A> A Clone(const A &x) { return x; }
 
index 24b0b67..f3a38b7 100644 (file)
@@ -455,8 +455,7 @@ void AssignmentContext::Analyze(
 
 void AssignmentContext::Analyze(
     const parser::WhereConstruct::Elsewhere &elsewhere) {
-  CHECK(where_ != nullptr);
-  MaskExpr copyCumulative{where_->cumulativeMaskExpr};
+  MaskExpr copyCumulative{DEREF(where_).cumulativeMaskExpr};
   where_->thisMaskExpr = evaluate::LogicalNegation(std::move(copyCumulative));
   for (const auto &x :
       std::get<std::list<parser::WhereBodyConstruct>>(elsewhere.t)) {
@@ -465,8 +464,7 @@ void AssignmentContext::Analyze(
 }
 
 void AssignmentContext::Analyze(const parser::ConcurrentHeader &header) {
-  CHECK(forall_ != nullptr);
-  forall_->integerKind = GetIntegerKind(
+  DEREF(forall_).integerKind = GetIntegerKind(
       std::get<std::optional<parser::IntegerTypeSpec>>(header.t));
   for (const auto &control :
       std::get<std::list<parser::ConcurrentControl>>(header.t)) {
index 6b14b1d..8ea2c33 100644 (file)
@@ -192,8 +192,7 @@ static std::optional<AllocateCheckerInfo> CheckAllocateOptions(
   }
 
   if (info.gotSrc || info.gotMold) {
-    CHECK(parserSourceExpr);
-    if (const auto *expr{GetExpr(*parserSourceExpr)}) {
+    if (const auto *expr{GetExpr(DEREF(parserSourceExpr))}) {
       info.sourceExprType = expr->GetType();
       if (!info.sourceExprType.has_value()) {
         CHECK(context.AnyFatalError());
@@ -390,13 +389,10 @@ static bool HaveCompatibleKindParameters(
     return true;
   }
   if (const IntrinsicTypeSpec * intrinsicType1{type1.AsIntrinsic()}) {
-    const IntrinsicTypeSpec *intrinsicType2{type2.AsIntrinsic()};
-    CHECK(intrinsicType2);  // Violation of type compatibility hypothesis.
-    return intrinsicType1->kind() == intrinsicType2->kind();
+    return intrinsicType1->kind() == DEREF(type2.AsIntrinsic()).kind();
   } else if (const DerivedTypeSpec * derivedType1{type1.AsDerived()}) {
-    const DerivedTypeSpec *derivedType2{type2.AsDerived()};
-    CHECK(derivedType2);  // Violation of type compatibility hypothesis.
-    return HaveCompatibleKindParameters(*derivedType1, *derivedType2);
+    return HaveCompatibleKindParameters(
+        *derivedType1, DEREF(type2.AsDerived()));
   } else {
     common::die("unexpected type1 category");
   }
index 32b4be1..c537fe0 100644 (file)
@@ -721,8 +721,7 @@ MaybeExpr ExpressionAnalyzer::Analyze(
           [&](auto &&ckExpr) -> MaybeExpr {
             using Result = ResultType<decltype(ckExpr)>;
             auto *cp{std::get_if<Constant<Result>>(&ckExpr.u)};
-            CHECK(cp != nullptr);  // the parent was parsed as a constant string
-            CHECK(cp->size() == 1);
+            CHECK(DEREF(cp).size() == 1);
             StaticDataObject::Pointer staticData{StaticDataObject::Create()};
             staticData->set_alignment(Result::kind)
                 .set_itemBytes(Result::kind)
@@ -2088,9 +2087,8 @@ MaybeExpr ExpressionAnalyzer::MakeFunctionRef(
             ProcedureRef{std::move(proc), std::move(arguments)}};
       } else {
         // Not a procedure pointer, so type and shape are known.
-        const auto *typeAndShape{result.GetTypeAndShape()};
-        CHECK(typeAndShape != nullptr);
-        return TypedWrapper<FunctionRef, ProcedureRef>(typeAndShape->type(),
+        return TypedWrapper<FunctionRef, ProcedureRef>(
+            DEREF(result.GetTypeAndShape()).type(),
             ProcedureRef{std::move(proc), std::move(arguments)});
       }
     }
index a4dd501..7a11028 100644 (file)
@@ -96,8 +96,7 @@ private:
   template<typename T> void DoExpr(evaluate::Expr<T> expr) {
     evaluate::Visitor<SymbolVisitor> visitor{0};
     for (const Symbol *symbol : visitor.Traverse(expr)) {
-      CHECK(symbol && "bad symbol from Traverse");
-      DoSymbol(*symbol);
+      DoSymbol(DEREF(symbol));
     }
   }
 };
@@ -403,8 +402,7 @@ std::vector<const Symbol *> CollectSymbols(const Scope &scope) {
   }
   // sort normal symbols, then namelists, then common blocks:
   auto compareByOrder = [](const Symbol *x, const Symbol *y) {
-    CHECK(x != nullptr);
-    return x->name().begin() < y->name().begin();
+    return DEREF(x).name().begin() < DEREF(y).name().begin();
   };
   auto cursor{sorted.begin()};
   std::sort(cursor, sorted.end(), compareByOrder);
@@ -461,11 +459,7 @@ void PutShape(std::ostream &os, const ArraySpec &shape, char open, char close) {
 
 void PutObjectEntity(std::ostream &os, const Symbol &symbol) {
   auto &details{symbol.get<ObjectEntityDetails>()};
-  PutEntity(os, symbol, [&]() {
-    auto *type{symbol.GetType()};
-    CHECK(type);
-    PutLower(os, *type);
-  });
+  PutEntity(os, symbol, [&]() { PutLower(os, DEREF(symbol.GetType())); });
   PutShape(os, details.shape(), '(', ')');
   PutShape(os, details.coshape(), '[', ']');
   PutInit(os, details.init());
@@ -500,9 +494,7 @@ void PutPassName(std::ostream &os, const SourceName *passName) {
 void PutTypeParam(std::ostream &os, const Symbol &symbol) {
   auto &details{symbol.get<TypeParamDetails>()};
   PutEntity(os, symbol, [&]() {
-    auto *type{symbol.GetType()};
-    CHECK(type);
-    PutLower(os, *type);
+    PutLower(os, DEREF(symbol.GetType()));
     PutLower(os << ',', common::EnumToString(details.attr()));
   });
   PutInit(os, details.init());
@@ -795,8 +787,7 @@ void SubprogramSymbolCollector::Collect() {
     DoSymbol(details.result());
   }
   for (const Symbol *dummyArg : details.dummyArgs()) {
-    CHECK(dummyArg);
-    DoSymbol(*dummyArg);
+    DoSymbol(DEREF(dummyArg));
   }
   for (const auto &pair : scope_) {
     const Symbol *symbol{pair.second};
index 4fbf954..e67777f 100644 (file)
@@ -1552,8 +1552,7 @@ void ScopeHandler::SayLocalMustBeVariable(
 void ScopeHandler::SayDerivedType(
     const SourceName &name, MessageFixedText &&msg, const Scope &type) {
   const Symbol *typeSymbol{type.GetSymbol()};
-  CHECK(typeSymbol != nullptr);
-  Say(name, std::move(msg), name, typeSymbol->name())
+  Say(name, std::move(msg), name, DEREF(typeSymbol).name())
       .Attach(typeSymbol->name(), "Declaration of derived type '%s'"_en_US,
           typeSymbol->name());
 }
index 9767f3c..ca0c426 100644 (file)
@@ -174,8 +174,7 @@ DeclTypeSpec &Scope::MakeDerivedType(
 void Scope::set_chars(parser::CookedSource &cooked) {
   CHECK(kind_ == Kind::Module);
   CHECK(parent_.IsGlobal() || parent_.IsModuleFile());
-  CHECK(symbol_ != nullptr);
-  CHECK(symbol_->test(Symbol::Flag::ModFile));
+  CHECK(DEREF(symbol_).test(Symbol::Flag::ModFile));
   // TODO: Preserve the CookedSource rather than acquiring its string.
   chars_ = cooked.AcquireData();
 }
index 281a5f5..9a95f65 100644 (file)
@@ -89,11 +89,7 @@ public:
   const Symbol *GetSymbol() const;
   const Scope *GetDerivedTypeParent() const;
 
-  const SourceName &name() const {
-    const Symbol *sym{GetSymbol()};
-    CHECK(sym != nullptr);
-    return sym->name();
-  }
+  const SourceName &name() const { return DEREF(GetSymbol()).name(); }
 
   /// Make a scope nested in this one
   Scope &MakeScope(Kind kind, Symbol *symbol = nullptr);
index a3155a1..5508c01 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
+// Copyright (c) 2018-2019, NVIDIA CORPORATION.  All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -75,8 +75,7 @@ void SymbolDumpVisitor::Indent(std::ostream &out, int indent) const {
 void SymbolDumpVisitor::Post(const parser::Name &name) {
   if (const auto *symbol{name.symbol}) {
     if (!symbol->has<MiscDetails>()) {
-      CHECK(currStmt_);
-      symbols_.emplace(currStmt_->begin(), symbol);
+      symbols_.emplace(DEREF(currStmt_).begin(), symbol);
     }
   }
 }