#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Tooling/Core/Replacement.h"
-#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include <string>
}
private:
+ class ExpectedCase {
+ public:
+ ExpectedCase(const EnumConstantDecl *Decl) : Data(Decl, false) {}
+ bool isCovered() const { return Data.getInt(); }
+ void setCovered(bool Val = true) { Data.setInt(Val); }
+ const EnumConstantDecl *getEnumConstant() const {
+ return Data.getPointer();
+ }
+
+ private:
+ llvm::PointerIntPair<const EnumConstantDecl *, 1, bool> Data;
+ };
+
const DeclContext *DeclCtx = nullptr;
const SwitchStmt *Switch = nullptr;
const CompoundStmt *Body = nullptr;
const EnumType *EnumT = nullptr;
const EnumDecl *EnumD = nullptr;
+ // Maps the Enum values to the EnumConstantDecl and a bool signifying if its
+ // covered in the switch.
+ llvm::MapVector<llvm::APSInt, ExpectedCase> ExpectedCases;
};
REGISTER_TWEAK(PopulateSwitch)
if (!EnumD)
return false;
- // We trigger if there are fewer cases than enum values (and no case covers
- // multiple values). This guarantees we'll have at least one case to insert.
- // We don't yet determine what the cases are, as that means evaluating
- // expressions.
- auto I = EnumD->enumerator_begin();
- auto E = EnumD->enumerator_end();
+ // We trigger if there are any values in the enum that aren't covered by the
+ // switch.
+
+ ASTContext &Ctx = Sel.AST->getASTContext();
+
+ unsigned EnumIntWidth = Ctx.getIntWidth(QualType(EnumT, 0));
+ bool EnumIsSigned = EnumT->isSignedIntegerOrEnumerationType();
- for (const SwitchCase *CaseList = Switch->getSwitchCaseList();
- CaseList && I != E; CaseList = CaseList->getNextSwitchCase(), I++) {
+ auto Normalize = [&](llvm::APSInt Val) {
+ Val = Val.extOrTrunc(EnumIntWidth);
+ Val.setIsSigned(EnumIsSigned);
+ return Val;
+ };
+
+ for (auto *EnumConstant : EnumD->enumerators()) {
+ ExpectedCases.insert(
+ std::make_pair(Normalize(EnumConstant->getInitVal()), EnumConstant));
+ }
+
+ for (const SwitchCase *CaseList = Switch->getSwitchCaseList(); CaseList;
+ CaseList = CaseList->getNextSwitchCase()) {
// Default likely intends to cover cases we'd insert.
if (isa<DefaultStmt>(CaseList))
return false;
const CaseStmt *CS = cast<CaseStmt>(CaseList);
- // Case statement covers multiple values, so just counting doesn't work.
+
+ // GNU range cases are rare, we don't support them.
if (CS->caseStmtIsGNURange())
return false;
const ConstantExpr *CE = dyn_cast<ConstantExpr>(CS->getLHS());
if (!CE || CE->isValueDependent())
return false;
+
+ // Unsure if this case could ever come up, but prevents an unreachable
+ // executing in getResultAsAPSInt.
+ if (CE->getResultStorageKind() == ConstantExpr::RSK_None)
+ return false;
+ auto Iter = ExpectedCases.find(Normalize(CE->getResultAsAPSInt()));
+ if (Iter != ExpectedCases.end())
+ Iter->second.setCovered();
}
- // Only suggest tweak if we have more enumerators than cases.
- return I != E;
+ return !llvm::all_of(ExpectedCases,
+ [](auto &Pair) { return Pair.second.isCovered(); });
}
Expected<Tweak::Effect> PopulateSwitch::apply(const Selection &Sel) {
ASTContext &Ctx = Sel.AST->getASTContext();
- // Get the enum's integer width and signedness, for adjusting case literals.
- unsigned EnumIntWidth = Ctx.getIntWidth(QualType(EnumT, 0));
- bool EnumIsSigned = EnumT->isSignedIntegerOrEnumerationType();
-
- llvm::SmallSet<llvm::APSInt, 32> ExistingEnumerators;
- for (const SwitchCase *CaseList = Switch->getSwitchCaseList(); CaseList;
- CaseList = CaseList->getNextSwitchCase()) {
- const CaseStmt *CS = cast<CaseStmt>(CaseList);
- assert(!CS->caseStmtIsGNURange());
- const ConstantExpr *CE = cast<ConstantExpr>(CS->getLHS());
- assert(!CE->isValueDependent());
- llvm::APSInt Val = CE->getResultAsAPSInt();
- Val = Val.extOrTrunc(EnumIntWidth);
- Val.setIsSigned(EnumIsSigned);
- ExistingEnumerators.insert(Val);
- }
-
SourceLocation Loc = Body->getRBracLoc();
ASTContext &DeclASTCtx = DeclCtx->getParentASTContext();
- std::string Text;
- for (EnumConstantDecl *Enumerator : EnumD->enumerators()) {
- if (ExistingEnumerators.contains(Enumerator->getInitVal()))
+ llvm::SmallString<256> Text;
+ for (auto &EnumConstant : ExpectedCases) {
+ // Skip any enum constants already covered
+ if (EnumConstant.second.isCovered())
continue;
- Text += "case ";
- Text += getQualification(DeclASTCtx, DeclCtx, Loc, EnumD);
- if (EnumD->isScoped()) {
- Text += EnumD->getName();
- Text += "::";
- }
- Text += Enumerator->getName();
- Text += ":";
+ Text.append({"case ", getQualification(DeclASTCtx, DeclCtx, Loc, EnumD)});
+ if (EnumD->isScoped())
+ Text.append({EnumD->getName(), "::"});
+ Text.append({EnumConstant.second.getEnumConstant()->getName(), ":"});
}
assert(!Text.empty() && "No enumerators to insert!");