1//===-- IncludeFixer.cpp - Include inserter based on sema callbacks -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "IncludeFixer.h"
10#include "clang/Format/Format.h"
11#include "clang/Frontend/CompilerInstance.h"
12#include "clang/Lex/HeaderSearch.h"
13#include "clang/Lex/Preprocessor.h"
14#include "clang/Parse/ParseAST.h"
15#include "clang/Sema/Sema.h"
16#include "llvm/Support/Debug.h"
17#include "llvm/Support/raw_ostream.h"
18
19#define DEBUG_TYPE "clang-include-fixer"
20
21using namespace clang;
22
23namespace clang {
24namespace include_fixer {
25namespace {
26/// Manages the parse, gathers include suggestions.
27class Action : public clang::ASTFrontendAction {
28public:
29 explicit Action(SymbolIndexManager &SymbolIndexMgr, bool MinimizeIncludePaths)
30 : SemaSource(new IncludeFixerSemaSource(SymbolIndexMgr,
31 MinimizeIncludePaths,
32 /*GenerateDiagnostics=*/false)) {}
33
34 std::unique_ptr<clang::ASTConsumer>
35 CreateASTConsumer(clang::CompilerInstance &Compiler,
36 StringRef InFile) override {
37 SemaSource->setFilePath(InFile);
38 return std::make_unique<clang::ASTConsumer>();
39 }
40
41 void ExecuteAction() override {
42 clang::CompilerInstance *Compiler = &getCompilerInstance();
43 assert(!Compiler->hasSema() && "CI already has Sema");
44
45 // Set up our hooks into sema and parse the AST.
46 if (hasCodeCompletionSupport() &&
47 !Compiler->getFrontendOpts().CodeCompletionAt.FileName.empty())
48 Compiler->createCodeCompletionConsumer();
49
50 clang::CodeCompleteConsumer *CompletionConsumer = nullptr;
51 if (Compiler->hasCodeCompletionConsumer())
52 CompletionConsumer = &Compiler->getCodeCompletionConsumer();
53
54 Compiler->createSema(TUKind: getTranslationUnitKind(), CompletionConsumer);
55 SemaSource->setCompilerInstance(Compiler);
56 Compiler->getSema().addExternalSource(E: SemaSource.get());
57
58 clang::ParseAST(S&: Compiler->getSema(), PrintStats: Compiler->getFrontendOpts().ShowStats,
59 SkipFunctionBodies: Compiler->getFrontendOpts().SkipFunctionBodies);
60 }
61
62 IncludeFixerContext
63 getIncludeFixerContext(const clang::SourceManager &SourceManager,
64 clang::HeaderSearch &HeaderSearch) const {
65 return SemaSource->getIncludeFixerContext(SourceManager, HeaderSearch,
66 MatchedSymbols: SemaSource->getMatchedSymbols());
67 }
68
69private:
70 IntrusiveRefCntPtr<IncludeFixerSemaSource> SemaSource;
71};
72
73} // namespace
74
75IncludeFixerActionFactory::IncludeFixerActionFactory(
76 SymbolIndexManager &SymbolIndexMgr,
77 std::vector<IncludeFixerContext> &Contexts, StringRef StyleName,
78 bool MinimizeIncludePaths)
79 : SymbolIndexMgr(SymbolIndexMgr), Contexts(Contexts),
80 MinimizeIncludePaths(MinimizeIncludePaths) {}
81
82IncludeFixerActionFactory::~IncludeFixerActionFactory() = default;
83
84bool IncludeFixerActionFactory::runInvocation(
85 std::shared_ptr<clang::CompilerInvocation> Invocation,
86 clang::FileManager *Files,
87 std::shared_ptr<clang::PCHContainerOperations> PCHContainerOps,
88 clang::DiagnosticConsumer *Diagnostics) {
89 assert(Invocation->getFrontendOpts().Inputs.size() == 1);
90
91 // Set up Clang.
92 clang::CompilerInstance Compiler(PCHContainerOps);
93 Compiler.setInvocation(std::move(Invocation));
94 Compiler.setFileManager(Files);
95
96 // Create the compiler's actual diagnostics engine. We want to drop all
97 // diagnostics here.
98 Compiler.createDiagnostics(VFS&: Files->getVirtualFileSystem(),
99 Client: new clang::IgnoringDiagConsumer,
100 /*ShouldOwnClient=*/true);
101 Compiler.createSourceManager(FileMgr&: *Files);
102
103 // We abort on fatal errors so don't let a large number of errors become
104 // fatal. A missing #include can cause thousands of errors.
105 Compiler.getDiagnostics().setErrorLimit(0);
106
107 // Run the parser, gather missing includes.
108 auto ScopedToolAction =
109 std::make_unique<Action>(args&: SymbolIndexMgr, args&: MinimizeIncludePaths);
110 Compiler.ExecuteAction(Act&: *ScopedToolAction);
111
112 Contexts.push_back(x: ScopedToolAction->getIncludeFixerContext(
113 SourceManager: Compiler.getSourceManager(),
114 HeaderSearch&: Compiler.getPreprocessor().getHeaderSearchInfo()));
115
116 // Technically this should only return true if we're sure that we have a
117 // parseable file. We don't know that though. Only inform users of fatal
118 // errors.
119 return !Compiler.getDiagnostics().hasFatalErrorOccurred();
120}
121
122static bool addDiagnosticsForContext(TypoCorrection &Correction,
123 const IncludeFixerContext &Context,
124 StringRef Code, SourceLocation StartOfFile,
125 ASTContext &Ctx) {
126 auto Reps = createIncludeFixerReplacements(
127 Code, Context, Style: format::getLLVMStyle(), /*AddQualifiers=*/false);
128 if (!Reps || Reps->size() != 1)
129 return false;
130
131 unsigned DiagID = Ctx.getDiagnostics().getCustomDiagID(
132 L: DiagnosticsEngine::Note, FormatString: "Add '#include %0' to provide the missing "
133 "declaration [clang-include-fixer]");
134
135 // FIXME: Currently we only generate a diagnostic for the first header. Give
136 // the user choices.
137 const tooling::Replacement &Placed = *Reps->begin();
138
139 auto Begin = StartOfFile.getLocWithOffset(Offset: Placed.getOffset());
140 auto End = Begin.getLocWithOffset(Offset: std::max(a: 0, b: (int)Placed.getLength() - 1));
141 PartialDiagnostic PD(DiagID, Ctx.getDiagAllocator());
142 PD << Context.getHeaderInfos().front().Header
143 << FixItHint::CreateReplacement(RemoveRange: CharSourceRange::getCharRange(B: Begin, E: End),
144 Code: Placed.getReplacementText());
145 Correction.addExtraDiagnostic(PD: std::move(PD));
146 return true;
147}
148
149/// Callback for incomplete types. If we encounter a forward declaration we
150/// have the fully qualified name ready. Just query that.
151bool IncludeFixerSemaSource::MaybeDiagnoseMissingCompleteType(
152 clang::SourceLocation Loc, clang::QualType T) {
153 // Ignore spurious callbacks from SFINAE contexts.
154 if (CI->getSema().isSFINAEContext())
155 return false;
156
157 clang::ASTContext &context = CI->getASTContext();
158 std::string QueryString = QualType(T->getUnqualifiedDesugaredType(), 0)
159 .getAsString(Policy: context.getPrintingPolicy());
160 LLVM_DEBUG(llvm::dbgs() << "Query missing complete type '" << QueryString
161 << "'");
162 // Pass an empty range here since we don't add qualifier in this case.
163 std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
164 query(Query: QueryString, ScopedQualifiers: "", Range: tooling::Range());
165
166 if (!MatchedSymbols.empty() && GenerateDiagnostics) {
167 TypoCorrection Correction;
168 FileID FID = CI->getSourceManager().getFileID(SpellingLoc: Loc);
169 StringRef Code = CI->getSourceManager().getBufferData(FID);
170 SourceLocation StartOfFile =
171 CI->getSourceManager().getLocForStartOfFile(FID);
172 addDiagnosticsForContext(
173 Correction,
174 Context: getIncludeFixerContext(SourceManager: CI->getSourceManager(),
175 HeaderSearch&: CI->getPreprocessor().getHeaderSearchInfo(),
176 MatchedSymbols),
177 Code, StartOfFile, Ctx&: CI->getASTContext());
178 for (const PartialDiagnostic &PD : Correction.getExtraDiagnostics())
179 CI->getSema().Diag(Loc, PD);
180 }
181 return true;
182}
183
184/// Callback for unknown identifiers. Try to piece together as much
185/// qualification as we can get and do a query.
186clang::TypoCorrection IncludeFixerSemaSource::CorrectTypo(
187 const DeclarationNameInfo &Typo, int LookupKind, Scope *S, CXXScopeSpec *SS,
188 CorrectionCandidateCallback &CCC, DeclContext *MemberContext,
189 bool EnteringContext, const ObjCObjectPointerType *OPT) {
190 // Ignore spurious callbacks from SFINAE contexts.
191 if (CI->getSema().isSFINAEContext())
192 return clang::TypoCorrection();
193
194 // We currently ignore the unidentified symbol which is not from the
195 // main file.
196 //
197 // However, this is not always true due to templates in a non-self contained
198 // header, consider the case:
199 //
200 // // header.h
201 // template <typename T>
202 // class Foo {
203 // T t;
204 // };
205 //
206 // // test.cc
207 // // We need to add <bar.h> in test.cc instead of header.h.
208 // class Bar;
209 // Foo<Bar> foo;
210 //
211 // FIXME: Add the missing header to the header file where the symbol comes
212 // from.
213 if (!CI->getSourceManager().isWrittenInMainFile(Loc: Typo.getLoc()))
214 return clang::TypoCorrection();
215
216 std::string TypoScopeString;
217 if (S) {
218 // FIXME: Currently we only use namespace contexts. Use other context
219 // types for query.
220 for (const auto *Context = S->getEntity(); Context;
221 Context = Context->getParent()) {
222 if (const auto *ND = dyn_cast<NamespaceDecl>(Val: Context)) {
223 if (!ND->getName().empty())
224 TypoScopeString = ND->getNameAsString() + "::" + TypoScopeString;
225 }
226 }
227 }
228
229 auto ExtendNestedNameSpecifier = [this](CharSourceRange Range) {
230 StringRef Source =
231 Lexer::getSourceText(Range, SM: CI->getSourceManager(), LangOpts: CI->getLangOpts());
232
233 // Skip forward until we find a character that's neither identifier nor
234 // colon. This is a bit of a hack around the fact that we will only get a
235 // single callback for a long nested name if a part of the beginning is
236 // unknown. For example:
237 //
238 // llvm::sys::path::parent_path(...)
239 // ^~~~ ^~~
240 // known
241 // ^~~~
242 // unknown, last callback
243 // ^~~~~~~~~~~
244 // no callback
245 //
246 // With the extension we get the full nested name specifier including
247 // parent_path.
248 // FIXME: Don't rely on source text.
249 const char *End = Source.end();
250 while (isAsciiIdentifierContinue(c: *End) || *End == ':')
251 ++End;
252
253 return std::string(Source.begin(), End);
254 };
255
256 /// If we have a scope specification, use that to get more precise results.
257 std::string QueryString;
258 tooling::Range SymbolRange;
259 const auto &SM = CI->getSourceManager();
260 auto CreateToolingRange = [&QueryString, &SM](SourceLocation BeginLoc) {
261 return tooling::Range(SM.getDecomposedLoc(Loc: BeginLoc).second,
262 QueryString.size());
263 };
264 if (SS && SS->getRange().isValid()) {
265 auto Range = CharSourceRange::getTokenRange(B: SS->getRange().getBegin(),
266 E: Typo.getLoc());
267
268 QueryString = ExtendNestedNameSpecifier(Range);
269 SymbolRange = CreateToolingRange(Range.getBegin());
270 } else if (Typo.getName().isIdentifier() && !Typo.getLoc().isMacroID()) {
271 auto Range =
272 CharSourceRange::getTokenRange(B: Typo.getBeginLoc(), E: Typo.getEndLoc());
273
274 QueryString = ExtendNestedNameSpecifier(Range);
275 SymbolRange = CreateToolingRange(Range.getBegin());
276 } else {
277 QueryString = Typo.getAsString();
278 SymbolRange = CreateToolingRange(Typo.getLoc());
279 }
280
281 LLVM_DEBUG(llvm::dbgs() << "TypoScopeQualifiers: " << TypoScopeString
282 << "\n");
283 std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
284 query(Query: QueryString, ScopedQualifiers: TypoScopeString, Range: SymbolRange);
285
286 if (!MatchedSymbols.empty() && GenerateDiagnostics) {
287 TypoCorrection Correction(Typo.getName());
288 Correction.setCorrectionRange(SS, TypoName: Typo);
289 FileID FID = SM.getFileID(SpellingLoc: Typo.getLoc());
290 StringRef Code = SM.getBufferData(FID);
291 SourceLocation StartOfFile = SM.getLocForStartOfFile(FID);
292 if (addDiagnosticsForContext(
293 Correction, Context: getIncludeFixerContext(
294 SourceManager: SM, HeaderSearch&: CI->getPreprocessor().getHeaderSearchInfo(),
295 MatchedSymbols),
296 Code, StartOfFile, Ctx&: CI->getASTContext()))
297 return Correction;
298 }
299 return TypoCorrection();
300}
301
302/// Get the minimal include for a given path.
303std::string IncludeFixerSemaSource::minimizeInclude(
304 StringRef Include, const clang::SourceManager &SourceManager,
305 clang::HeaderSearch &HeaderSearch) const {
306 if (!MinimizeIncludePaths)
307 return std::string(Include);
308
309 // Get the FileEntry for the include.
310 StringRef StrippedInclude = Include.trim(Chars: "\"<>");
311 auto Entry =
312 SourceManager.getFileManager().getOptionalFileRef(Filename: StrippedInclude);
313
314 // If the file doesn't exist return the path from the database.
315 // FIXME: This should never happen.
316 if (!Entry)
317 return std::string(Include);
318
319 bool IsAngled = false;
320 std::string Suggestion =
321 HeaderSearch.suggestPathToFileForDiagnostics(File: *Entry, MainFile: "", IsAngled: &IsAngled);
322
323 return IsAngled ? '<' + Suggestion + '>' : '"' + Suggestion + '"';
324}
325
326/// Get the include fixer context for the queried symbol.
327IncludeFixerContext IncludeFixerSemaSource::getIncludeFixerContext(
328 const clang::SourceManager &SourceManager,
329 clang::HeaderSearch &HeaderSearch,
330 ArrayRef<find_all_symbols::SymbolInfo> MatchedSymbols) const {
331 std::vector<find_all_symbols::SymbolInfo> SymbolCandidates;
332 for (const auto &Symbol : MatchedSymbols) {
333 std::string FilePath = Symbol.getFilePath().str();
334 std::string MinimizedFilePath = minimizeInclude(
335 Include: ((FilePath[0] == '"' || FilePath[0] == '<') ? FilePath
336 : "\"" + FilePath + "\""),
337 SourceManager, HeaderSearch);
338 SymbolCandidates.emplace_back(args: Symbol.getName(), args: Symbol.getSymbolKind(),
339 args&: MinimizedFilePath, args: Symbol.getContexts());
340 }
341 return IncludeFixerContext(FilePath, QuerySymbolInfos, SymbolCandidates);
342}
343
344std::vector<find_all_symbols::SymbolInfo>
345IncludeFixerSemaSource::query(StringRef Query, StringRef ScopedQualifiers,
346 tooling::Range Range) {
347 assert(!Query.empty() && "Empty query!");
348
349 // Save all instances of an unidentified symbol.
350 //
351 // We use conservative behavior for detecting the same unidentified symbol
352 // here. The symbols which have the same ScopedQualifier and RawIdentifier
353 // are considered equal. So that clang-include-fixer avoids false positives,
354 // and always adds missing qualifiers to correct symbols.
355 if (!GenerateDiagnostics && !QuerySymbolInfos.empty()) {
356 if (ScopedQualifiers == QuerySymbolInfos.front().ScopedQualifiers &&
357 Query == QuerySymbolInfos.front().RawIdentifier) {
358 QuerySymbolInfos.push_back(
359 x: {.RawIdentifier: Query.str(), .ScopedQualifiers: std::string(ScopedQualifiers), .Range: Range});
360 }
361 return {};
362 }
363
364 LLVM_DEBUG(llvm::dbgs() << "Looking up '" << Query << "' at ");
365 LLVM_DEBUG(CI->getSourceManager()
366 .getLocForStartOfFile(CI->getSourceManager().getMainFileID())
367 .getLocWithOffset(Range.getOffset())
368 .print(llvm::dbgs(), CI->getSourceManager()));
369 LLVM_DEBUG(llvm::dbgs() << " ...");
370 llvm::StringRef FileName = CI->getSourceManager().getFilename(
371 SpellingLoc: CI->getSourceManager().getLocForStartOfFile(
372 FID: CI->getSourceManager().getMainFileID()));
373
374 QuerySymbolInfos.push_back(
375 x: {.RawIdentifier: Query.str(), .ScopedQualifiers: std::string(ScopedQualifiers), .Range: Range});
376
377 // Query the symbol based on C++ name Lookup rules.
378 // Firstly, lookup the identifier with scoped namespace contexts;
379 // If that fails, falls back to look up the identifier directly.
380 //
381 // For example:
382 //
383 // namespace a {
384 // b::foo f;
385 // }
386 //
387 // 1. lookup a::b::foo.
388 // 2. lookup b::foo.
389 std::string QueryString = ScopedQualifiers.str() + Query.str();
390 // It's unsafe to do nested search for the identifier with scoped namespace
391 // context, it might treat the identifier as a nested class of the scoped
392 // namespace.
393 std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
394 SymbolIndexMgr.search(Identifier: QueryString, /*IsNestedSearch=*/false, FileName);
395 if (MatchedSymbols.empty())
396 MatchedSymbols =
397 SymbolIndexMgr.search(Identifier: Query, /*IsNestedSearch=*/true, FileName);
398 LLVM_DEBUG(llvm::dbgs() << "Having found " << MatchedSymbols.size()
399 << " symbols\n");
400 // We store a copy of MatchedSymbols in a place where it's globally reachable.
401 // This is used by the standalone version of the tool.
402 this->MatchedSymbols = MatchedSymbols;
403 return MatchedSymbols;
404}
405
406llvm::Expected<tooling::Replacements> createIncludeFixerReplacements(
407 StringRef Code, const IncludeFixerContext &Context,
408 const clang::format::FormatStyle &Style, bool AddQualifiers) {
409 if (Context.getHeaderInfos().empty())
410 return tooling::Replacements();
411 StringRef FilePath = Context.getFilePath();
412 std::string IncludeName =
413 "#include " + Context.getHeaderInfos().front().Header + "\n";
414 // Create replacements for the new header.
415 clang::tooling::Replacements Insertions;
416 auto Err =
417 Insertions.add(R: tooling::Replacement(FilePath, UINT_MAX, 0, IncludeName));
418 if (Err)
419 return std::move(Err);
420
421 auto CleanReplaces = cleanupAroundReplacements(Code, Replaces: Insertions, Style);
422 if (!CleanReplaces)
423 return CleanReplaces;
424
425 auto Replaces = std::move(*CleanReplaces);
426 if (AddQualifiers) {
427 for (const auto &Info : Context.getQuerySymbolInfos()) {
428 // Ignore the empty range.
429 if (Info.Range.getLength() > 0) {
430 auto R = tooling::Replacement(
431 {FilePath, Info.Range.getOffset(), Info.Range.getLength(),
432 Context.getHeaderInfos().front().QualifiedName});
433 auto Err = Replaces.add(R);
434 if (Err) {
435 llvm::consumeError(Err: std::move(Err));
436 R = tooling::Replacement(
437 R.getFilePath(), Replaces.getShiftedCodePosition(Position: R.getOffset()),
438 R.getLength(), R.getReplacementText());
439 Replaces = Replaces.merge(Replaces: tooling::Replacements(R));
440 }
441 }
442 }
443 }
444 return formatReplacements(Code, Replaces, Style);
445}
446
447} // namespace include_fixer
448} // namespace clang
449