class Attr;
#define ATTR(A) class A##Attr;
#include "clang/Basic/AttrList.inc"
+class ObjCProtocolLoc;
} // end namespace clang
NKI_Attr,
#define ATTR(A) NKI_##A##Attr,
#include "clang/Basic/AttrList.inc"
+ NKI_ObjCProtocolLoc,
NKI_NumberOfKinds
};
KIND_TO_KIND_ID(Type)
KIND_TO_KIND_ID(OMPClause)
KIND_TO_KIND_ID(Attr)
+KIND_TO_KIND_ID(ObjCProtocolLoc)
KIND_TO_KIND_ID(CXXBaseSpecifier)
#define DECL(DERIVED, BASE) KIND_TO_KIND_ID(DERIVED##Decl)
#include "clang/AST/DeclNodes.inc"
/// have storage or unique pointers and thus need to be stored by value.
llvm::AlignedCharArrayUnion<const void *, TemplateArgument,
TemplateArgumentLoc, NestedNameSpecifierLoc,
- QualType, TypeLoc>
+ QualType, TypeLoc, ObjCProtocolLoc>
Storage;
};
struct DynTypedNode::BaseConverter<CXXBaseSpecifier, void>
: public PtrConverter<CXXBaseSpecifier> {};
+template <>
+struct DynTypedNode::BaseConverter<ObjCProtocolLoc, void>
+ : public ValueConverter<ObjCProtocolLoc> {};
+
// The only operation we allow on unsupported types is \c get.
// This allows to conveniently use \c DynTypedNode when having an arbitrary
// AST node that is not supported, but prevents misuse - a user cannot create
/// \returns false if the visitation was terminated early, true otherwise.
bool TraverseConceptReference(const ConceptReference &C);
+ /// Recursively visit an Objective-C protocol reference with location
+ /// information.
+ ///
+ /// \returns false if the visitation was terminated early, true otherwise.
+ bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc);
+
// ---- Methods on Attrs ----
// Visit an attribute.
DEF_TRAVERSE_TYPELOC(PackExpansionType,
{ TRY_TO(TraverseTypeLoc(TL.getPatternLoc())); })
-DEF_TRAVERSE_TYPELOC(ObjCTypeParamType, {})
+DEF_TRAVERSE_TYPELOC(ObjCTypeParamType, {
+ for (unsigned I = 0, N = TL.getNumProtocols(); I != N; ++I) {
+ ObjCProtocolLoc ProtocolLoc(TL.getProtocol(I), TL.getProtocolLoc(I));
+ TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+ }
+})
DEF_TRAVERSE_TYPELOC(ObjCInterfaceType, {})
TRY_TO(TraverseTypeLoc(TL.getBaseLoc()));
for (unsigned i = 0, n = TL.getNumTypeArgs(); i != n; ++i)
TRY_TO(TraverseTypeLoc(TL.getTypeArgTInfo(i)->getTypeLoc()));
+ for (unsigned I = 0, N = TL.getNumProtocols(); I != N; ++I) {
+ ObjCProtocolLoc ProtocolLoc(TL.getProtocol(I), TL.getProtocolLoc(I));
+ TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+ }
})
DEF_TRAVERSE_TYPELOC(ObjCObjectPointerType,
DEF_TRAVERSE_DECL(ObjCCompatibleAliasDecl, {// FIXME: implement
})
-DEF_TRAVERSE_DECL(ObjCCategoryDecl, {// FIXME: implement
+DEF_TRAVERSE_DECL(ObjCCategoryDecl, {
if (ObjCTypeParamList *typeParamList = D->getTypeParamList()) {
for (auto typeParam : *typeParamList) {
TRY_TO(TraverseObjCTypeParamDecl(typeParam));
}
}
+ for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) {
+ ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It));
+ TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+ }
})
DEF_TRAVERSE_DECL(ObjCCategoryImplDecl, {// FIXME: implement
DEF_TRAVERSE_DECL(ObjCImplementationDecl, {// FIXME: implement
})
-DEF_TRAVERSE_DECL(ObjCInterfaceDecl, {// FIXME: implement
+DEF_TRAVERSE_DECL(ObjCInterfaceDecl, {
if (ObjCTypeParamList *typeParamList = D->getTypeParamListAsWritten()) {
for (auto typeParam : *typeParamList) {
TRY_TO(TraverseObjCTypeParamDecl(typeParam));
if (TypeSourceInfo *superTInfo = D->getSuperClassTInfo()) {
TRY_TO(TraverseTypeLoc(superTInfo->getTypeLoc()));
}
+ if (D->isThisDeclarationADefinition()) {
+ for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) {
+ ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It));
+ TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+ }
+ }
})
-DEF_TRAVERSE_DECL(ObjCProtocolDecl, {// FIXME: implement
- })
+DEF_TRAVERSE_DECL(ObjCProtocolDecl, {
+ if (D->isThisDeclarationADefinition()) {
+ for (auto It : llvm::zip(D->protocols(), D->protocol_locs())) {
+ ObjCProtocolLoc ProtocolLoc(std::get<0>(It), std::get<1>(It));
+ TRY_TO(TraverseObjCProtocolLoc(ProtocolLoc));
+ }
+ }
+})
DEF_TRAVERSE_DECL(ObjCMethodDecl, {
if (D->getReturnTypeSourceInfo()) {
return true;
}
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::TraverseObjCProtocolLoc(
+ ObjCProtocolLoc ProtocolLoc) {
+ return true;
+}
+
// If shouldVisitImplicitCode() returns false, this method traverses only the
// syntactic form of InitListExpr.
// If shouldVisitImplicitCode() return true, this method is called once for
: public InheritingConcreteTypeLoc<TypeSpecTypeLoc, DependentBitIntTypeLoc,
DependentBitIntType> {};
+class ObjCProtocolLoc {
+ ObjCProtocolDecl *Protocol = nullptr;
+ SourceLocation Loc = SourceLocation();
+
+public:
+ ObjCProtocolLoc(ObjCProtocolDecl *protocol, SourceLocation loc)
+ : Protocol(protocol), Loc(loc) {}
+ ObjCProtocolDecl *getProtocol() const { return Protocol; }
+ SourceLocation getLocation() const { return Loc; }
+
+ /// The source range is just the protocol name.
+ SourceRange getSourceRange() const LLVM_READONLY {
+ return SourceRange(Loc, Loc);
+ }
+};
+
} // namespace clang
#endif // LLVM_CLANG_AST_TYPELOC_H
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/DeclCXX.h"
+#include "clang/AST/DeclObjC.h"
#include "clang/AST/NestedNameSpecifier.h"
#include "clang/AST/OpenMPClause.h"
#include "clang/AST/TypeLoc.h"
{NKI_None, "Attr"},
#define ATTR(A) {NKI_Attr, #A "Attr"},
#include "clang/Basic/AttrList.inc"
+ {NKI_None, "ObjCProtocolLoc"},
};
bool ASTNodeKind::isBaseOf(ASTNodeKind Other, unsigned *Distance) const {
QualType(T, 0).print(OS, PP);
else if (const Attr *A = get<Attr>())
A->printPretty(OS, PP);
+ else if (const ObjCProtocolLoc *P = get<ObjCProtocolLoc>())
+ P->getProtocol()->print(OS, PP);
else
OS << "Unable to print values of type " << NodeKind.asStringRef() << "\n";
}
return CBS->getSourceRange();
if (const auto *A = get<Attr>())
return A->getRange();
+ if (const ObjCProtocolLoc *P = get<ObjCProtocolLoc>())
+ return P->getSourceRange();
return SourceRange();
}
DynTypedNode createDynTypedNode(const NestedNameSpecifierLoc &Node) {
return DynTypedNode::create(Node);
}
+template <> DynTypedNode createDynTypedNode(const ObjCProtocolLoc &Node) {
+ return DynTypedNode::create(Node);
+}
/// @}
/// A \c RecursiveASTVisitor that builds a map from nodes to their
}
}
+ template <typename T> static bool isNull(T Node) { return !Node; }
+ static bool isNull(ObjCProtocolLoc Node) { return false; }
+
template <typename T, typename MapNodeTy, typename BaseTraverseFn,
typename MapTy>
bool TraverseNode(T Node, MapNodeTy MapNode, BaseTraverseFn BaseTraverse,
MapTy *Parents) {
- if (!Node)
+ if (isNull(Node))
return true;
addParent(MapNode, Parents);
ParentStack.push_back(createDynTypedNode(Node));
AttrNode, AttrNode, [&] { return VisitorBase::TraverseAttr(AttrNode); },
&Map.PointerParents);
}
+ bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLocNode) {
+ return TraverseNode(
+ ProtocolLocNode, DynTypedNode::create(ProtocolLocNode),
+ [&] { return VisitorBase::TraverseObjCProtocolLoc(ProtocolLocNode); },
+ &Map.OtherParents);
+ }
// Using generic TraverseNode for Stmt would prevent data-recursion.
bool dataTraverseStmtPre(Stmt *StmtNode) {
EndTraverseEnum,
StartTraverseTypedefType,
EndTraverseTypedefType,
+ StartTraverseObjCInterface,
+ EndTraverseObjCInterface,
+ StartTraverseObjCProtocol,
+ EndTraverseObjCProtocol,
+ StartTraverseObjCProtocolLoc,
+ EndTraverseObjCProtocolLoc,
};
class CollectInterestingEvents
return Ret;
}
+ bool TraverseObjCInterfaceDecl(ObjCInterfaceDecl *ID) {
+ Events.push_back(VisitEvent::StartTraverseObjCInterface);
+ bool Ret = RecursiveASTVisitor::TraverseObjCInterfaceDecl(ID);
+ Events.push_back(VisitEvent::EndTraverseObjCInterface);
+
+ return Ret;
+ }
+
+ bool TraverseObjCProtocolDecl(ObjCProtocolDecl *PD) {
+ Events.push_back(VisitEvent::StartTraverseObjCProtocol);
+ bool Ret = RecursiveASTVisitor::TraverseObjCProtocolDecl(PD);
+ Events.push_back(VisitEvent::EndTraverseObjCProtocol);
+
+ return Ret;
+ }
+
+ bool TraverseObjCProtocolLoc(ObjCProtocolLoc ProtocolLoc) {
+ Events.push_back(VisitEvent::StartTraverseObjCProtocolLoc);
+ bool Ret = RecursiveASTVisitor::TraverseObjCProtocolLoc(ProtocolLoc);
+ Events.push_back(VisitEvent::EndTraverseObjCProtocolLoc);
+
+ return Ret;
+ }
+
std::vector<VisitEvent> takeEvents() && { return std::move(Events); }
private:
std::vector<VisitEvent> Events;
};
-std::vector<VisitEvent> collectEvents(llvm::StringRef Code) {
+std::vector<VisitEvent> collectEvents(llvm::StringRef Code,
+ const Twine &FileName = "input.cc") {
CollectInterestingEvents Visitor;
clang::tooling::runToolOnCode(
std::make_unique<ProcessASTAction>(
[&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }),
- Code);
+ Code, FileName);
return std::move(Visitor).takeEvents();
}
} // namespace
VisitEvent::EndTraverseTypedefType,
VisitEvent::EndTraverseEnum));
}
+
+TEST(RecursiveASTVisitorTest, InterfaceDeclWithProtocols) {
+ // Check interface and its protocols are visited.
+ llvm::StringRef Code = R"cpp(
+ @protocol Foo
+ @end
+ @protocol Bar
+ @end
+
+ @interface SomeObject <Foo, Bar>
+ @end
+ )cpp";
+
+ EXPECT_THAT(collectEvents(Code, "input.m"),
+ ElementsAre(VisitEvent::StartTraverseObjCProtocol,
+ VisitEvent::EndTraverseObjCProtocol,
+ VisitEvent::StartTraverseObjCProtocol,
+ VisitEvent::EndTraverseObjCProtocol,
+ VisitEvent::StartTraverseObjCInterface,
+ VisitEvent::StartTraverseObjCProtocolLoc,
+ VisitEvent::EndTraverseObjCProtocolLoc,
+ VisitEvent::StartTraverseObjCProtocolLoc,
+ VisitEvent::EndTraverseObjCProtocolLoc,
+ VisitEvent::EndTraverseObjCInterface));
+}