[clang][ASTImporter] Improve import of InjectedClassNameType.
authorBalázs Kéri <1.int32@gmail.com>
Wed, 1 Mar 2023 07:42:32 +0000 (08:42 +0100)
committerBalázs Kéri <1.int32@gmail.com>
Wed, 1 Mar 2023 08:26:10 +0000 (09:26 +0100)
During AST import multiple different InjectedClassNameType objects
could be created for a single class template. This can cause problems
and failed assertions when these types are compared and found to be
not the same (because the instance is different and there is no
canonical type).
The import of this type does not use the factory method in ASTContext,
probably because the preconditions are not fulfilled at that state.
The fix tries to make the code in ASTImporter work more like the code
in ASTContext::getInjectedClassNameType. If a type is stored at the
Decl or previous Decl object, it is reused instead of creating a new
one. This avoids crash at least a part of the cases.

Reviewed By: gamesh411, donat.nagy, vabridgers

Differential Revision: https://reviews.llvm.org/D140562

clang/lib/AST/ASTImporter.cpp
clang/unittests/AST/ASTImporterTest.cpp

index 3ca2b67..99f72d0 100644 (file)
@@ -1464,20 +1464,11 @@ ExpectedType ASTNodeImporter::VisitInjectedClassNameType(
   if (!ToDeclOrErr)
     return ToDeclOrErr.takeError();
 
-  ExpectedType ToInjTypeOrErr = import(T->getInjectedSpecializationType());
-  if (!ToInjTypeOrErr)
-    return ToInjTypeOrErr.takeError();
-
-  // FIXME: ASTContext::getInjectedClassNameType is not suitable for AST reading
-  // See comments in InjectedClassNameType definition for details
-  // return Importer.getToContext().getInjectedClassNameType(D, InjType);
-  enum {
-    TypeAlignmentInBits = 4,
-    TypeAlignment = 1 << TypeAlignmentInBits
-  };
-
-  return QualType(new (Importer.getToContext(), TypeAlignment)
-                  InjectedClassNameType(*ToDeclOrErr, *ToInjTypeOrErr), 0);
+  // The InjectedClassNameType is created in VisitRecordDecl when the
+  // T->getDecl() is imported. Here we can return the existing type.
+  const Type *Ty = (*ToDeclOrErr)->getTypeForDecl();
+  assert(Ty && isa<InjectedClassNameType>(Ty));
+  return QualType(Ty, 0);
 }
 
 ExpectedType ASTNodeImporter::VisitRecordType(const RecordType *T) {
@@ -2952,8 +2943,6 @@ ExpectedDecl ASTNodeImporter::VisitRecordDecl(RecordDecl *D) {
         // InjectedClassNameType (see Sema::CheckClassTemplate). Update the
         // previously set type to the correct value here (ToDescribed is not
         // available at record create).
-        // FIXME: The previous type is cleared but not removed from
-        // ASTContext's internal storage.
         CXXRecordDecl *Injected = nullptr;
         for (NamedDecl *Found : D2CXX->noload_lookup(Name)) {
           auto *Record = dyn_cast<CXXRecordDecl>(Found);
@@ -2963,20 +2952,34 @@ ExpectedDecl ASTNodeImporter::VisitRecordDecl(RecordDecl *D) {
           }
         }
         // Create an injected type for the whole redecl chain.
+        // The chain may contain an already existing injected type at the start,
+        // if yes this should be reused. We must ensure that only one type
+        // object exists for the injected type (including the injected record
+        // declaration), ASTContext does not check it.
         SmallVector<Decl *, 2> Redecls =
             getCanonicalForwardRedeclChain(D2CXX);
+        const Type *FrontTy =
+            cast<CXXRecordDecl>(Redecls.front())->getTypeForDecl();
+        QualType InjSpec;
+        if (auto *InjTy = FrontTy->getAs<InjectedClassNameType>())
+          InjSpec = InjTy->getInjectedSpecializationType();
+        else
+          InjSpec = ToDescribed->getInjectedClassNameSpecialization();
         for (auto *R : Redecls) {
           auto *RI = cast<CXXRecordDecl>(R);
-          RI->setTypeForDecl(nullptr);
-          // Below we create a new injected type and assign that to the
-          // canonical decl, subsequent declarations in the chain will reuse
-          // that type.
-          Importer.getToContext().getInjectedClassNameType(
-              RI, ToDescribed->getInjectedClassNameSpecialization());
+          if (R != Redecls.front() ||
+              !isa<InjectedClassNameType>(RI->getTypeForDecl()))
+            RI->setTypeForDecl(nullptr);
+          // This function tries to get the injected type from getTypeForDecl,
+          // then from the previous declaration if possible. If not, it creates
+          // a new type.
+          Importer.getToContext().getInjectedClassNameType(RI, InjSpec);
         }
-        // Set the new type for the previous injected decl too.
+        // Set the new type for the injected decl too.
         if (Injected) {
           Injected->setTypeForDecl(nullptr);
+          // This function will copy the injected type from D2CXX into Injected.
+          // The injected decl does not have a previous decl to copy from.
           Importer.getToContext().getTypeDeclType(Injected, D2CXX);
         }
       }
index 78de31e..87b9fe5 100644 (file)
@@ -8137,6 +8137,247 @@ TEST_P(ASTImporterOptionSpecificTestBase, isNewDecl) {
   EXPECT_FALSE(SharedStatePtr->isNewDecl(ToBar));
 }
 
+struct ImportInjectedClassNameType : public ASTImporterOptionSpecificTestBase {
+protected:
+  const CXXRecordDecl *findInjected(const CXXRecordDecl *Parent) {
+    for (Decl *Found : Parent->decls()) {
+      const auto *Record = dyn_cast<CXXRecordDecl>(Found);
+      if (Record && Record->isInjectedClassName())
+        return Record;
+    }
+    return nullptr;
+  }
+
+  void checkInjType(const CXXRecordDecl *D) {
+    // The whole redecl chain should have the same InjectedClassNameType
+    // instance. The injected record declaration is a separate chain, this
+    // should contain the same type too.
+    const Type *Ty = nullptr;
+    for (const Decl *ReD : D->redecls()) {
+      const auto *ReRD = cast<CXXRecordDecl>(ReD);
+      EXPECT_TRUE(ReRD->getTypeForDecl());
+      EXPECT_TRUE(!Ty || Ty == ReRD->getTypeForDecl());
+      Ty = ReRD->getTypeForDecl();
+    }
+    ASSERT_TRUE(Ty);
+    const auto *InjTy = Ty->castAs<InjectedClassNameType>();
+    EXPECT_TRUE(InjTy);
+    if (CXXRecordDecl *Def = D->getDefinition()) {
+      const CXXRecordDecl *InjRD = findInjected(Def);
+      EXPECT_TRUE(InjRD);
+      EXPECT_EQ(InjRD->getTypeForDecl(), InjTy);
+    }
+  }
+
+  void testImport(Decl *ToTU, Decl *FromTU, Decl *FromD) {
+    checkInjType(cast<CXXRecordDecl>(FromD));
+    Decl *ToD = Import(FromD, Lang_CXX11);
+    if (auto *ToRD = dyn_cast<CXXRecordDecl>(ToD))
+      checkInjType(ToRD);
+  }
+
+  const char *ToCodeA =
+      R"(
+      template <class T>
+      struct A;
+      )";
+  const char *ToCodeADef =
+      R"(
+      template <class T>
+      struct A {
+        typedef A T1;
+      };
+      )";
+  const char *ToCodeC =
+      R"(
+      template <class T>
+      struct C;
+      )";
+  const char *ToCodeCDef =
+      R"(
+      template <class T>
+      struct A {
+        typedef A T1;
+      };
+
+      template <class T1, class T2>
+      struct B {};
+
+      template<class T>
+      struct C {
+        typedef typename A<T>::T1 T1;
+        typedef B<T1, T> T2;
+        typedef B<T1, C> T3;
+      };
+      )";
+  const char *FromCode =
+      R"(
+      template <class T>
+      struct A;
+      template <class T>
+      struct A {
+        typedef A T1;
+      };
+      template <class T>
+      struct A;
+
+      template <class T1, class T2>
+      struct B {};
+
+      template <class T>
+      struct C;
+      template <class T>
+      struct C {
+        typedef typename A<T>::T1 T1;
+        typedef B<T1, T> T2;
+        typedef B<T1, C> T3;
+      };
+      template <class T>
+      struct C;
+
+      template <class T>
+      struct D {
+        void f(typename C<T>::T3 *);
+      };
+      )";
+};
+
+TEST_P(ImportInjectedClassNameType, ImportADef) {
+  Decl *ToTU = getToTuDecl(ToCodeA, Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromA = FirstDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("A"), isDefinition()));
+  testImport(ToTU, FromTU, FromA);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportAFirst) {
+  Decl *ToTU = getToTuDecl(ToCodeA, Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromA = FirstDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("A")));
+  testImport(ToTU, FromTU, FromA);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportALast) {
+  Decl *ToTU = getToTuDecl(ToCodeA, Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromA = LastDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("A")));
+  testImport(ToTU, FromTU, FromA);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportADefToDef) {
+  Decl *ToTU = getToTuDecl(ToCodeADef, Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromA = FirstDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("A"), isDefinition()));
+  testImport(ToTU, FromTU, FromA);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportAFirstToDef) {
+  Decl *ToTU = getToTuDecl(ToCodeADef, Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromA = FirstDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("A")));
+  testImport(ToTU, FromTU, FromA);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportALastToDef) {
+  Decl *ToTU = getToTuDecl(ToCodeADef, Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromA = LastDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("A")));
+  testImport(ToTU, FromTU, FromA);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportCDef) {
+  Decl *ToTU = getToTuDecl(ToCodeC, Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromC = FirstDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("C"), isDefinition()));
+  testImport(ToTU, FromTU, FromC);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportCLast) {
+  Decl *ToTU = getToTuDecl(ToCodeC, Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromC = LastDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("C")));
+  testImport(ToTU, FromTU, FromC);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportCDefToDef) {
+  Decl *ToTU = getToTuDecl(ToCodeCDef, Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromC = FirstDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("C"), isDefinition()));
+  testImport(ToTU, FromTU, FromC);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportCLastToDef) {
+  Decl *ToTU = getToTuDecl(ToCodeCDef, Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromC = LastDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("C")));
+  testImport(ToTU, FromTU, FromC);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportD) {
+  Decl *ToTU = getToTuDecl("", Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromD = FirstDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("D"), isDefinition()));
+  testImport(ToTU, FromTU, FromD);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportDToDef) {
+  Decl *ToTU = getToTuDecl(ToCodeCDef, Lang_CXX11);
+  Decl *FromTU = getTuDecl(FromCode, Lang_CXX11);
+  auto *FromD = FirstDeclMatcher<CXXRecordDecl>().match(
+      FromTU, cxxRecordDecl(hasName("D"), isDefinition()));
+  testImport(ToTU, FromTU, FromD);
+}
+
+TEST_P(ImportInjectedClassNameType, ImportTypedefType) {
+  Decl *ToTU = getToTuDecl(
+      R"(
+      template <class T>
+      struct A {
+        typedef A A1;
+        void f(A1 *);
+      };
+      )",
+      Lang_CXX11);
+  Decl *FromTU = getTuDecl(
+      R"(
+      template <class T>
+      struct A {
+        typedef A A1;
+        void f(A1 *);
+      };
+      template<class T>
+      void A<T>::f(A::A1 *) {}
+      )",
+      Lang_CXX11);
+
+  auto *FromF = FirstDeclMatcher<FunctionDecl>().match(
+      FromTU, functionDecl(hasName("f"), isDefinition()));
+  auto *ToF = Import(FromF, Lang_CXX11);
+  EXPECT_TRUE(ToF);
+  ASTContext &ToCtx = ToF->getDeclContext()->getParentASTContext();
+
+  auto *ToA1 =
+      FirstDeclMatcher<TypedefDecl>().match(ToTU, typedefDecl(hasName("A1")));
+  QualType ToInjTypedef = ToA1->getUnderlyingType().getCanonicalType();
+  QualType ToInjParmVar =
+      ToF->parameters()[0]->getType().getDesugaredType(ToCtx);
+  ToInjParmVar =
+      ToInjParmVar->getAs<PointerType>()->getPointeeType().getCanonicalType();
+  EXPECT_TRUE(isa<InjectedClassNameType>(ToInjTypedef));
+  EXPECT_TRUE(isa<InjectedClassNameType>(ToInjParmVar));
+  EXPECT_TRUE(ToCtx.hasSameType(ToInjTypedef, ToInjParmVar));
+}
+
 INSTANTIATE_TEST_SUITE_P(ParameterizedTests, ASTImporterLookupTableTest,
                          DefaultTestValuesForRunOptions);
 
@@ -8214,5 +8455,8 @@ INSTANTIATE_TEST_SUITE_P(ParameterizedTests, ImportWithExternalSource,
 INSTANTIATE_TEST_SUITE_P(ParameterizedTests, ImportAttributes,
                          DefaultTestValuesForRunOptions);
 
+INSTANTIATE_TEST_SUITE_P(ParameterizedTests, ImportInjectedClassNameType,
+                         DefaultTestValuesForRunOptions);
+
 } // end namespace ast_matchers
 } // end namespace clang