Clang: fix AST representation of expanded template arguments.
authorMatheus Izvekov <mizvekov@gmail.com>
Sat, 18 Jun 2022 02:21:59 +0000 (04:21 +0200)
committerMatheus Izvekov <mizvekov@gmail.com>
Tue, 9 Aug 2022 12:26:16 +0000 (14:26 +0200)
Extend clang's SubstTemplateTypeParm to represent the pack substitution index.

Fixes PR56099.

Signed-off-by: Matheus Izvekov <mizvekov@gmail.com>
Differential Revision: https://reviews.llvm.org/D128113

15 files changed:
clang/include/clang/AST/ASTContext.h
clang/include/clang/AST/JSONNodeDumper.h
clang/include/clang/AST/TextNodeDumper.h
clang/include/clang/AST/Type.h
clang/include/clang/AST/TypeProperties.td
clang/lib/AST/ASTContext.cpp
clang/lib/AST/ASTImporter.cpp
clang/lib/AST/ASTStructuralEquivalence.cpp
clang/lib/AST/JSONNodeDumper.cpp
clang/lib/AST/TextNodeDumper.cpp
clang/lib/AST/Type.cpp
clang/lib/Sema/SemaTemplateInstantiate.cpp
clang/lib/Sema/TreeTransform.h
clang/test/AST/ast-dump-template-decls.cpp
clang/unittests/AST/ASTImporterTest.cpp

index d27b978..a01a6aa 100644 (file)
@@ -1614,7 +1614,8 @@ public:
                                    QualType Wrapped);
 
   QualType getSubstTemplateTypeParmType(const TemplateTypeParmType *Replaced,
-                                        QualType Replacement) const;
+                                        QualType Replacement,
+                                        Optional<unsigned> PackIndex) const;
   QualType getSubstTemplateTypeParmPackType(
                                           const TemplateTypeParmType *Replaced,
                                             const TemplateArgument &ArgPack);
index a5575d7..3597903 100644 (file)
@@ -220,6 +220,7 @@ public:
   void VisitUnaryTransformType(const UnaryTransformType *UTT);
   void VisitTagType(const TagType *TT);
   void VisitTemplateTypeParmType(const TemplateTypeParmType *TTPT);
+  void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *STTPT);
   void VisitAutoType(const AutoType *AT);
   void VisitTemplateSpecializationType(const TemplateSpecializationType *TST);
   void VisitInjectedClassNameType(const InjectedClassNameType *ICNT);
index 2e4bcdd..e6853b1 100644 (file)
@@ -317,6 +317,7 @@ public:
   void VisitUnaryTransformType(const UnaryTransformType *T);
   void VisitTagType(const TagType *T);
   void VisitTemplateTypeParmType(const TemplateTypeParmType *T);
+  void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *T);
   void VisitAutoType(const AutoType *T);
   void VisitDeducedTemplateSpecializationType(
       const DeducedTemplateSpecializationType *T);
index 1a8cf27..ad9835e 100644 (file)
@@ -1790,6 +1790,18 @@ protected:
     unsigned NumArgs;
   };
 
+  class SubstTemplateTypeParmTypeBitfields {
+    friend class SubstTemplateTypeParmType;
+
+    unsigned : NumTypeBits;
+
+    /// Represents the index within a pack if this represents a substitution
+    /// from a pack expansion.
+    /// Positive non-zero number represents the index + 1.
+    /// Zero means this is not substituted from an expansion.
+    unsigned PackIndex;
+  };
+
   class SubstTemplateTypeParmPackTypeBitfields {
     friend class SubstTemplateTypeParmPackType;
 
@@ -1872,6 +1884,7 @@ protected:
     ElaboratedTypeBitfields ElaboratedTypeBits;
     VectorTypeBitfields VectorTypeBits;
     SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits;
+    SubstTemplateTypeParmTypeBitfields SubstTemplateTypeParmTypeBits;
     TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits;
     DependentTemplateSpecializationTypeBitfields
       DependentTemplateSpecializationTypeBits;
@@ -4974,9 +4987,12 @@ class SubstTemplateTypeParmType : public Type, public llvm::FoldingSetNode {
   // The original type parameter.
   const TemplateTypeParmType *Replaced;
 
-  SubstTemplateTypeParmType(const TemplateTypeParmType *Param, QualType Canon)
+  SubstTemplateTypeParmType(const TemplateTypeParmType *Param, QualType Canon,
+                            Optional<unsigned> PackIndex)
       : Type(SubstTemplateTypeParm, Canon, Canon->getDependence()),
-        Replaced(Param) {}
+        Replaced(Param) {
+    SubstTemplateTypeParmTypeBits.PackIndex = PackIndex ? *PackIndex + 1 : 0;
+  }
 
 public:
   /// Gets the template parameter that was substituted for.
@@ -4990,18 +5006,25 @@ public:
     return getCanonicalTypeInternal();
   }
 
+  Optional<unsigned> getPackIndex() const {
+    if (SubstTemplateTypeParmTypeBits.PackIndex == 0)
+      return None;
+    return SubstTemplateTypeParmTypeBits.PackIndex - 1;
+  }
+
   bool isSugared() const { return true; }
   QualType desugar() const { return getReplacementType(); }
 
   void Profile(llvm::FoldingSetNodeID &ID) {
-    Profile(ID, getReplacedParameter(), getReplacementType());
+    Profile(ID, getReplacedParameter(), getReplacementType(), getPackIndex());
   }
 
   static void Profile(llvm::FoldingSetNodeID &ID,
                       const TemplateTypeParmType *Replaced,
-                      QualType Replacement) {
+                      QualType Replacement, Optional<unsigned> PackIndex) {
     ID.AddPointer(Replaced);
     ID.AddPointer(Replacement.getAsOpaquePtr());
+    ID.AddInteger(PackIndex ? *PackIndex - 1 : 0);
   }
 
   static bool classof(const Type *T) {
index c46b0bd..aa8bb79 100644 (file)
@@ -734,12 +734,15 @@ let Class = SubstTemplateTypeParmType in {
   def : Property<"replacementType", QualType> {
     let Read = [{ node->getReplacementType() }];
   }
+  def : Property<"PackIndex", Optional<UInt32>> {
+    let Read = [{ node->getPackIndex() }];
+  }
 
   def : Creator<[{
     // The call to getCanonicalType here existed in ASTReader.cpp, too.
     return ctx.getSubstTemplateTypeParmType(
         cast<TemplateTypeParmType>(replacedParameter),
-        ctx.getCanonicalType(replacementType));
+        ctx.getCanonicalType(replacementType), PackIndex);
   }]>;
 }
 
index 7999971..e4933fb 100644 (file)
@@ -4744,19 +4744,20 @@ QualType ASTContext::getBTFTagAttributedType(const BTFTypeTagAttr *BTFAttr,
 /// Retrieve a substitution-result type.
 QualType
 ASTContext::getSubstTemplateTypeParmType(const TemplateTypeParmType *Parm,
-                                         QualType Replacement) const {
+                                         QualType Replacement,
+                                         Optional<unsigned> PackIndex) const {
   assert(Replacement.isCanonical()
          && "replacement types must always be canonical");
 
   llvm::FoldingSetNodeID ID;
-  SubstTemplateTypeParmType::Profile(ID, Parm, Replacement);
+  SubstTemplateTypeParmType::Profile(ID, Parm, Replacement, PackIndex);
   void *InsertPos = nullptr;
   SubstTemplateTypeParmType *SubstParm
     = SubstTemplateTypeParmTypes.FindNodeOrInsertPos(ID, InsertPos);
 
   if (!SubstParm) {
     SubstParm = new (*this, TypeAlignment)
-      SubstTemplateTypeParmType(Parm, Replacement);
+        SubstTemplateTypeParmType(Parm, Replacement, PackIndex);
     Types.push_back(SubstParm);
     SubstTemplateTypeParmTypes.InsertNode(SubstParm, InsertPos);
   }
index c368a61..827185e 100644 (file)
@@ -1530,7 +1530,8 @@ ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmType(
     return ToReplacementTypeOrErr.takeError();
 
   return Importer.getToContext().getSubstTemplateTypeParmType(
-      *ReplacedOrErr, ToReplacementTypeOrErr->getCanonicalType());
+      *ReplacedOrErr, ToReplacementTypeOrErr->getCanonicalType(),
+      T->getPackIndex());
 }
 
 ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmPackType(
index 8c7a2d5..7e6055d 100644 (file)
@@ -1062,6 +1062,8 @@ static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context,
     if (!IsStructurallyEquivalent(Context, Subst1->getReplacementType(),
                                   Subst2->getReplacementType()))
       return false;
+    if (Subst1->getPackIndex() != Subst2->getPackIndex())
+      return false;
     break;
   }
 
index 87e4255..74ef4c3 100644 (file)
@@ -680,6 +680,12 @@ void JSONNodeDumper::VisitTemplateTypeParmType(
   JOS.attribute("decl", createBareDeclRef(TTPT->getDecl()));
 }
 
+void JSONNodeDumper::VisitSubstTemplateTypeParmType(
+    const SubstTemplateTypeParmType *STTPT) {
+  if (auto PackIndex = STTPT->getPackIndex())
+    JOS.attribute("pack_index", *PackIndex);
+}
+
 void JSONNodeDumper::VisitAutoType(const AutoType *AT) {
   JOS.attribute("undeduced", !AT->isDeduced());
   switch (AT->getKeyword()) {
index 22643d4..c3da5a4 100644 (file)
@@ -1568,6 +1568,12 @@ void TextNodeDumper::VisitTemplateTypeParmType(const TemplateTypeParmType *T) {
   dumpDeclRef(T->getDecl());
 }
 
+void TextNodeDumper::VisitSubstTemplateTypeParmType(
+    const SubstTemplateTypeParmType *T) {
+  if (auto PackIndex = T->getPackIndex())
+    OS << " pack_index " << *PackIndex;
+}
+
 void TextNodeDumper::VisitAutoType(const AutoType *T) {
   if (T->isDecltypeAuto())
     OS << " decltype(auto)";
index 136af19..8534e57 100644 (file)
@@ -1167,7 +1167,7 @@ public:
       return QualType(T, 0);
 
     return Ctx.getSubstTemplateTypeParmType(T->getReplacedParameter(),
-                                            replacementType);
+                                            replacementType, T->getPackIndex());
   }
 
   // FIXME: Non-trivial to implement, but important for C++
index a530d08..c4e5fb5 100644 (file)
@@ -1813,6 +1813,7 @@ TemplateInstantiator::TransformTemplateTypeParmType(TypeLocBuilder &TLB,
       return NewT;
     }
 
+    Optional<unsigned> PackIndex;
     if (T->isParameterPack()) {
       assert(Arg.getKind() == TemplateArgument::Pack &&
              "Missing argument pack");
@@ -1830,6 +1831,7 @@ TemplateInstantiator::TransformTemplateTypeParmType(TypeLocBuilder &TLB,
       }
 
       Arg = getPackSubstitutedTemplateArgument(getSema(), Arg);
+      PackIndex = getSema().ArgumentPackSubstitutionIndex;
     }
 
     assert(Arg.getKind() == TemplateArgument::Type &&
@@ -1838,8 +1840,8 @@ TemplateInstantiator::TransformTemplateTypeParmType(TypeLocBuilder &TLB,
     QualType Replacement = Arg.getAsType();
 
     // TODO: only do this uniquing once, at the start of instantiation.
-    QualType Result
-      = getSema().Context.getSubstTemplateTypeParmType(T, Replacement);
+    QualType Result = getSema().Context.getSubstTemplateTypeParmType(
+        T, Replacement, PackIndex);
     SubstTemplateTypeParmTypeLoc NewTL
       = TLB.push<SubstTemplateTypeParmTypeLoc>(Result);
     NewTL.setNameLoc(TL.getNameLoc());
@@ -1877,11 +1879,10 @@ TemplateInstantiator::TransformSubstTemplateTypeParmPackType(
 
   TemplateArgument Arg = TL.getTypePtr()->getArgumentPack();
   Arg = getPackSubstitutedTemplateArgument(getSema(), Arg);
-  QualType Result = Arg.getAsType();
 
-  Result = getSema().Context.getSubstTemplateTypeParmType(
-                                      TL.getTypePtr()->getReplacedParameter(),
-                                                          Result);
+  QualType Result = getSema().Context.getSubstTemplateTypeParmType(
+      TL.getTypePtr()->getReplacedParameter(), Arg.getAsType(),
+      getSema().ArgumentPackSubstitutionIndex);
   SubstTemplateTypeParmTypeLoc NewTL
     = TLB.push<SubstTemplateTypeParmTypeLoc>(Result);
   NewTL.setNameLoc(TL.getNameLoc());
index 14f6526..07e71a2 100644 (file)
@@ -4854,7 +4854,8 @@ QualType TreeTransform<Derived>::RebuildQualifiedType(QualType T,
         Replacement = SemaRef.Context.getQualifiedType(
             Replacement.getUnqualifiedType(), Qs);
         T = SemaRef.Context.getSubstTemplateTypeParmType(
-            SubstTypeParam->getReplacedParameter(), Replacement);
+            SubstTypeParam->getReplacedParameter(), Replacement,
+            SubstTypeParam->getPackIndex());
       } else if ((AutoTy = dyn_cast<AutoType>(T)) && AutoTy->isDeduced()) {
         // 'auto' types behave the same way as template parameters.
         QualType Deduced = AutoTy->getDeducedType();
@@ -6410,9 +6411,8 @@ QualType TreeTransform<Derived>::TransformSubstTemplateTypeParmType(
 
   // Always canonicalize the replacement type.
   Replacement = SemaRef.Context.getCanonicalType(Replacement);
-  QualType Result
-    = SemaRef.Context.getSubstTemplateTypeParmType(T->getReplacedParameter(),
-                                                   Replacement);
+  QualType Result = SemaRef.Context.getSubstTemplateTypeParmType(
+      T->getReplacedParameter(), Replacement, T->getPackIndex());
 
   // Propagate type-source information.
   SubstTemplateTypeParmTypeLoc NewTL
index 13050ee..bc0f221 100644 (file)
@@ -136,13 +136,13 @@ template <typename... Cs> struct foo {
 };
 using t1 = foo<int, short>::bind<char, float>;
 // CHECK:      TemplateSpecializationType 0x{{[^ ]*}} 'Y<char, float, int, short>' sugar Y
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'char' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'char' sugar pack_index 0
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'Bs' dependent contains_unexpanded_pack depth 0 index 0 pack
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'float' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'float' sugar pack_index 1
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'Bs' dependent contains_unexpanded_pack depth 0 index 0 pack
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'int' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'int' sugar pack_index 2
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'Bs' dependent contains_unexpanded_pack depth 0 index 0 pack
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'short' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'short' sugar pack_index 3
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'Bs' dependent contains_unexpanded_pack depth 0 index 0 pack
 
 template <typename... T> struct D {
@@ -152,13 +152,13 @@ using t2 = D<float, char>::B<int, short>;
 // CHECK:      TemplateSpecializationType 0x{{[^ ]*}} 'B<int, short>' sugar alias B
 // CHECK:      FunctionProtoType 0x{{[^ ]*}} 'int (int (*)(float, int), int (*)(char, short))' cdecl
 // CHECK:      FunctionProtoType 0x{{[^ ]*}} 'int (float, int)' cdecl
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'float' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'float' sugar pack_index 0
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'T' dependent contains_unexpanded_pack depth 0 index 0 pack
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'int' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'int' sugar pack_index 0
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'U' dependent contains_unexpanded_pack depth 0 index 0 pack
 // CHECK:      FunctionProtoType 0x{{[^ ]*}} 'int (char, short)' cdecl
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'char' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'char' sugar pack_index 1
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'T' dependent contains_unexpanded_pack depth 0 index 0 pack
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'short' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'short' sugar pack_index 1
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'U' dependent contains_unexpanded_pack depth 0 index 0 pack
 } // namespace PR56099
index d9b7883..d553eec 100644 (file)
@@ -4793,6 +4793,44 @@ TEST_P(ASTImporterOptionSpecificTestBase,
       ToD2->getDeclContext(), ToD2->getTemplateParameters()->getParam(0)));
 }
 
+TEST_P(ASTImporterOptionSpecificTestBase, ImportSubstTemplateTypeParmType) {
+  constexpr auto Code = R"(
+    template <class A1, class... A2> struct A {
+      using B = A1(A2...);
+    };
+    template struct A<void, char, float, int, short>;
+    )";
+  Decl *FromTU = getTuDecl(Code, Lang_CXX11, "input.cpp");
+  auto *FromClass = FirstDeclMatcher<ClassTemplateSpecializationDecl>().match(
+      FromTU, classTemplateSpecializationDecl());
+
+  auto testType = [&](ASTContext &Ctx, const char *Name,
+                      llvm::Optional<unsigned> PackIndex) {
+    const auto *Subst = selectFirst<SubstTemplateTypeParmType>(
+        "sttp", match(substTemplateTypeParmType(
+                          hasReplacementType(hasCanonicalType(asString(Name))))
+                          .bind("sttp"),
+                      Ctx));
+    const char *ExpectedTemplateParamName = PackIndex ? "A2" : "A1";
+    ASSERT_TRUE(Subst);
+    ASSERT_EQ(Subst->getReplacedParameter()->getIdentifier()->getName(),
+              ExpectedTemplateParamName);
+    ASSERT_EQ(Subst->getPackIndex(), PackIndex);
+  };
+  auto tests = [&](ASTContext &Ctx) {
+    testType(Ctx, "void", None);
+    testType(Ctx, "char", 0);
+    testType(Ctx, "float", 1);
+    testType(Ctx, "int", 2);
+    testType(Ctx, "short", 3);
+  };
+
+  tests(FromTU->getASTContext());
+
+  ClassTemplateSpecializationDecl *ToClass = Import(FromClass, Lang_CXX11);
+  tests(ToClass->getASTContext());
+}
+
 const AstTypeMatcher<SubstTemplateTypeParmPackType>
     substTemplateTypeParmPackType;