1//===-- IncludeFixer.cpp - Include inserter based on sema callbacks -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "IncludeFixer.h"
10#include "clang/Format/Format.h"
11#include "clang/Frontend/CompilerInstance.h"
12#include "clang/Lex/HeaderSearch.h"
13#include "clang/Lex/Preprocessor.h"
14#include "clang/Parse/ParseAST.h"
15#include "clang/Sema/Sema.h"
16#include "llvm/Support/Debug.h"
17#include "llvm/Support/raw_ostream.h"
18
19#define DEBUG_TYPE "clang-include-fixer"
20
21using namespace clang;
22
23namespace clang {
24namespace include_fixer {
25namespace {
26/// Manages the parse, gathers include suggestions.
27class Action : public clang::ASTFrontendAction {
28public:
29 explicit Action(SymbolIndexManager &SymbolIndexMgr, bool MinimizeIncludePaths)
30 : SemaSource(new IncludeFixerSemaSource(SymbolIndexMgr,
31 MinimizeIncludePaths,
32 /*GenerateDiagnostics=*/false)) {}
33
34 std::unique_ptr<clang::ASTConsumer>
35 CreateASTConsumer(clang::CompilerInstance &Compiler,
36 StringRef InFile) override {
37 SemaSource->setFilePath(InFile);
38 return std::make_unique<clang::ASTConsumer>();
39 }
40
41 void ExecuteAction() override {
42 clang::CompilerInstance *Compiler = &getCompilerInstance();
43 assert(!Compiler->hasSema() && "CI already has Sema");
44
45 // Set up our hooks into sema and parse the AST.
46 if (hasCodeCompletionSupport() &&
47 !Compiler->getFrontendOpts().CodeCompletionAt.FileName.empty())
48 Compiler->createCodeCompletionConsumer();
49
50 clang::CodeCompleteConsumer *CompletionConsumer = nullptr;
51 if (Compiler->hasCodeCompletionConsumer())
52 CompletionConsumer = &Compiler->getCodeCompletionConsumer();
53
54 Compiler->createSema(getTranslationUnitKind(), CompletionConsumer);
55 SemaSource->setCompilerInstance(Compiler);
56 Compiler->getSema().addExternalSource(SemaSource.get());
57
58 clang::ParseAST(Compiler->getSema(), Compiler->getFrontendOpts().ShowStats,
59 Compiler->getFrontendOpts().SkipFunctionBodies);
60 }
61
62 IncludeFixerContext
63 getIncludeFixerContext(const clang::SourceManager &SourceManager,
64 clang::HeaderSearch &HeaderSearch) const {
65 return SemaSource->getIncludeFixerContext(SourceManager, HeaderSearch,
66 SemaSource->getMatchedSymbols());
67 }
68
69private:
70 IntrusiveRefCntPtr<IncludeFixerSemaSource> SemaSource;
71};
72
73} // namespace
74
75IncludeFixerActionFactory::IncludeFixerActionFactory(
76 SymbolIndexManager &SymbolIndexMgr,
77 std::vector<IncludeFixerContext> &Contexts, StringRef StyleName,
78 bool MinimizeIncludePaths)
79 : SymbolIndexMgr(SymbolIndexMgr), Contexts(Contexts),
80 MinimizeIncludePaths(MinimizeIncludePaths) {}
81
82IncludeFixerActionFactory::~IncludeFixerActionFactory() = default;
83
84bool IncludeFixerActionFactory::runInvocation(
85 std::shared_ptr<clang::CompilerInvocation> Invocation,
86 clang::FileManager *Files,
87 std::shared_ptr<clang::PCHContainerOperations> PCHContainerOps,
88 clang::DiagnosticConsumer *Diagnostics) {
89 assert(Invocation->getFrontendOpts().Inputs.size() == 1);
90
91 // Set up Clang.
92 clang::CompilerInstance Compiler(PCHContainerOps);
93 Compiler.setInvocation(std::move(Invocation));
94 Compiler.setFileManager(Files);
95
96 // Create the compiler's actual diagnostics engine. We want to drop all
97 // diagnostics here.
98 Compiler.createDiagnostics(new clang::IgnoringDiagConsumer,
99 /*ShouldOwnClient=*/true);
100 Compiler.createSourceManager(*Files);
101
102 // We abort on fatal errors so don't let a large number of errors become
103 // fatal. A missing #include can cause thousands of errors.
104 Compiler.getDiagnostics().setErrorLimit(0);
105
106 // Run the parser, gather missing includes.
107 auto ScopedToolAction =
108 std::make_unique<Action>(SymbolIndexMgr, MinimizeIncludePaths);
109 Compiler.ExecuteAction(*ScopedToolAction);
110
111 Contexts.push_back(ScopedToolAction->getIncludeFixerContext(
112 Compiler.getSourceManager(),
113 Compiler.getPreprocessor().getHeaderSearchInfo()));
114
115 // Technically this should only return true if we're sure that we have a
116 // parseable file. We don't know that though. Only inform users of fatal
117 // errors.
118 return !Compiler.getDiagnostics().hasFatalErrorOccurred();
119}
120
121static bool addDiagnosticsForContext(TypoCorrection &Correction,
122 const IncludeFixerContext &Context,
123 StringRef Code, SourceLocation StartOfFile,
124 ASTContext &Ctx) {
125 auto Reps = createIncludeFixerReplacements(
126 Code, Context, format::getLLVMStyle(), /*AddQualifiers=*/false);
127 if (!Reps || Reps->size() != 1)
128 return false;
129
130 unsigned DiagID = Ctx.getDiagnostics().getCustomDiagID(
131 DiagnosticsEngine::Note, "Add '#include %0' to provide the missing "
132 "declaration [clang-include-fixer]");
133
134 // FIXME: Currently we only generate a diagnostic for the first header. Give
135 // the user choices.
136 const tooling::Replacement &Placed = *Reps->begin();
137
138 auto Begin = StartOfFile.getLocWithOffset(Placed.getOffset());
139 auto End = Begin.getLocWithOffset(std::max(0, (int)Placed.getLength() - 1));
140 PartialDiagnostic PD(DiagID, Ctx.getDiagAllocator());
141 PD << Context.getHeaderInfos().front().Header
142 << FixItHint::CreateReplacement(CharSourceRange::getCharRange(Begin, End),
143 Placed.getReplacementText());
144 Correction.addExtraDiagnostic(std::move(PD));
145 return true;
146}
147
148/// Callback for incomplete types. If we encounter a forward declaration we
149/// have the fully qualified name ready. Just query that.
150bool IncludeFixerSemaSource::MaybeDiagnoseMissingCompleteType(
151 clang::SourceLocation Loc, clang::QualType T) {
152 // Ignore spurious callbacks from SFINAE contexts.
153 if (CI->getSema().isSFINAEContext())
154 return false;
155
156 clang::ASTContext &context = CI->getASTContext();
157 std::string QueryString = QualType(T->getUnqualifiedDesugaredType(), 0)
158 .getAsString(context.getPrintingPolicy());
159 LLVM_DEBUG(llvm::dbgs() << "Query missing complete type '" << QueryString
160 << "'");
161 // Pass an empty range here since we don't add qualifier in this case.
162 std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
163 query(QueryString, "", tooling::Range());
164
165 if (!MatchedSymbols.empty() && GenerateDiagnostics) {
166 TypoCorrection Correction;
167 FileID FID = CI->getSourceManager().getFileID(Loc);
168 StringRef Code = CI->getSourceManager().getBufferData(FID);
169 SourceLocation StartOfFile =
170 CI->getSourceManager().getLocForStartOfFile(FID);
171 addDiagnosticsForContext(
172 Correction,
173 getIncludeFixerContext(CI->getSourceManager(),
174 CI->getPreprocessor().getHeaderSearchInfo(),
175 MatchedSymbols),
176 Code, StartOfFile, CI->getASTContext());
177 for (const PartialDiagnostic &PD : Correction.getExtraDiagnostics())
178 CI->getSema().Diag(Loc, PD);
179 }
180 return true;
181}
182
183/// Callback for unknown identifiers. Try to piece together as much
184/// qualification as we can get and do a query.
185clang::TypoCorrection IncludeFixerSemaSource::CorrectTypo(
186 const DeclarationNameInfo &Typo, int LookupKind, Scope *S, CXXScopeSpec *SS,
187 CorrectionCandidateCallback &CCC, DeclContext *MemberContext,
188 bool EnteringContext, const ObjCObjectPointerType *OPT) {
189 // Ignore spurious callbacks from SFINAE contexts.
190 if (CI->getSema().isSFINAEContext())
191 return clang::TypoCorrection();
192
193 // We currently ignore the unidentified symbol which is not from the
194 // main file.
195 //
196 // However, this is not always true due to templates in a non-self contained
197 // header, consider the case:
198 //
199 // // header.h
200 // template <typename T>
201 // class Foo {
202 // T t;
203 // };
204 //
205 // // test.cc
206 // // We need to add <bar.h> in test.cc instead of header.h.
207 // class Bar;
208 // Foo<Bar> foo;
209 //
210 // FIXME: Add the missing header to the header file where the symbol comes
211 // from.
212 if (!CI->getSourceManager().isWrittenInMainFile(Typo.getLoc()))
213 return clang::TypoCorrection();
214
215 std::string TypoScopeString;
216 if (S) {
217 // FIXME: Currently we only use namespace contexts. Use other context
218 // types for query.
219 for (const auto *Context = S->getEntity(); Context;
220 Context = Context->getParent()) {
221 if (const auto *ND = dyn_cast<NamespaceDecl>(Context)) {
222 if (!ND->getName().empty())
223 TypoScopeString = ND->getNameAsString() + "::" + TypoScopeString;
224 }
225 }
226 }
227
228 auto ExtendNestedNameSpecifier = [this](CharSourceRange Range) {
229 StringRef Source =
230 Lexer::getSourceText(Range, CI->getSourceManager(), CI->getLangOpts());
231
232 // Skip forward until we find a character that's neither identifier nor
233 // colon. This is a bit of a hack around the fact that we will only get a
234 // single callback for a long nested name if a part of the beginning is
235 // unknown. For example:
236 //
237 // llvm::sys::path::parent_path(...)
238 // ^~~~ ^~~
239 // known
240 // ^~~~
241 // unknown, last callback
242 // ^~~~~~~~~~~
243 // no callback
244 //
245 // With the extension we get the full nested name specifier including
246 // parent_path.
247 // FIXME: Don't rely on source text.
248 const char *End = Source.end();
249 while (isAsciiIdentifierContinue(*End) || *End == ':')
250 ++End;
251
252 return std::string(Source.begin(), End);
253 };
254
255 /// If we have a scope specification, use that to get more precise results.
256 std::string QueryString;
257 tooling::Range SymbolRange;
258 const auto &SM = CI->getSourceManager();
259 auto CreateToolingRange = [&QueryString, &SM](SourceLocation BeginLoc) {
260 return tooling::Range(SM.getDecomposedLoc(BeginLoc).second,
261 QueryString.size());
262 };
263 if (SS && SS->getRange().isValid()) {
264 auto Range = CharSourceRange::getTokenRange(SS->getRange().getBegin(),
265 Typo.getLoc());
266
267 QueryString = ExtendNestedNameSpecifier(Range);
268 SymbolRange = CreateToolingRange(Range.getBegin());
269 } else if (Typo.getName().isIdentifier() && !Typo.getLoc().isMacroID()) {
270 auto Range =
271 CharSourceRange::getTokenRange(Typo.getBeginLoc(), Typo.getEndLoc());
272
273 QueryString = ExtendNestedNameSpecifier(Range);
274 SymbolRange = CreateToolingRange(Range.getBegin());
275 } else {
276 QueryString = Typo.getAsString();
277 SymbolRange = CreateToolingRange(Typo.getLoc());
278 }
279
280 LLVM_DEBUG(llvm::dbgs() << "TypoScopeQualifiers: " << TypoScopeString
281 << "\n");
282 std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
283 query(QueryString, TypoScopeString, SymbolRange);
284
285 if (!MatchedSymbols.empty() && GenerateDiagnostics) {
286 TypoCorrection Correction(Typo.getName());
287 Correction.setCorrectionRange(SS, Typo);
288 FileID FID = SM.getFileID(Typo.getLoc());
289 StringRef Code = SM.getBufferData(FID);
290 SourceLocation StartOfFile = SM.getLocForStartOfFile(FID);
291 if (addDiagnosticsForContext(
292 Correction, getIncludeFixerContext(
293 SM, CI->getPreprocessor().getHeaderSearchInfo(),
294 MatchedSymbols),
295 Code, StartOfFile, CI->getASTContext()))
296 return Correction;
297 }
298 return TypoCorrection();
299}
300
301/// Get the minimal include for a given path.
302std::string IncludeFixerSemaSource::minimizeInclude(
303 StringRef Include, const clang::SourceManager &SourceManager,
304 clang::HeaderSearch &HeaderSearch) const {
305 if (!MinimizeIncludePaths)
306 return std::string(Include);
307
308 // Get the FileEntry for the include.
309 StringRef StrippedInclude = Include.trim("\"<>");
310 auto Entry = SourceManager.getFileManager().getFile(StrippedInclude);
311
312 // If the file doesn't exist return the path from the database.
313 // FIXME: This should never happen.
314 if (!Entry)
315 return std::string(Include);
316
317 bool IsSystem = false;
318 std::string Suggestion =
319 HeaderSearch.suggestPathToFileForDiagnostics(*Entry, "", &IsSystem);
320
321 return IsSystem ? '<' + Suggestion + '>' : '"' + Suggestion + '"';
322}
323
324/// Get the include fixer context for the queried symbol.
325IncludeFixerContext IncludeFixerSemaSource::getIncludeFixerContext(
326 const clang::SourceManager &SourceManager,
327 clang::HeaderSearch &HeaderSearch,
328 ArrayRef<find_all_symbols::SymbolInfo> MatchedSymbols) const {
329 std::vector<find_all_symbols::SymbolInfo> SymbolCandidates;
330 for (const auto &Symbol : MatchedSymbols) {
331 std::string FilePath = Symbol.getFilePath().str();
332 std::string MinimizedFilePath = minimizeInclude(
333 ((FilePath[0] == '"' || FilePath[0] == '<') ? FilePath
334 : "\"" + FilePath + "\""),
335 SourceManager, HeaderSearch);
336 SymbolCandidates.emplace_back(Symbol.getName(), Symbol.getSymbolKind(),
337 MinimizedFilePath, Symbol.getContexts());
338 }
339 return IncludeFixerContext(FilePath, QuerySymbolInfos, SymbolCandidates);
340}
341
342std::vector<find_all_symbols::SymbolInfo>
343IncludeFixerSemaSource::query(StringRef Query, StringRef ScopedQualifiers,
344 tooling::Range Range) {
345 assert(!Query.empty() && "Empty query!");
346
347 // Save all instances of an unidentified symbol.
348 //
349 // We use conservative behavior for detecting the same unidentified symbol
350 // here. The symbols which have the same ScopedQualifier and RawIdentifier
351 // are considered equal. So that clang-include-fixer avoids false positives,
352 // and always adds missing qualifiers to correct symbols.
353 if (!GenerateDiagnostics && !QuerySymbolInfos.empty()) {
354 if (ScopedQualifiers == QuerySymbolInfos.front().ScopedQualifiers &&
355 Query == QuerySymbolInfos.front().RawIdentifier) {
356 QuerySymbolInfos.push_back(
357 {Query.str(), std::string(ScopedQualifiers), Range});
358 }
359 return {};
360 }
361
362 LLVM_DEBUG(llvm::dbgs() << "Looking up '" << Query << "' at ");
363 LLVM_DEBUG(CI->getSourceManager()
364 .getLocForStartOfFile(CI->getSourceManager().getMainFileID())
365 .getLocWithOffset(Range.getOffset())
366 .print(llvm::dbgs(), CI->getSourceManager()));
367 LLVM_DEBUG(llvm::dbgs() << " ...");
368 llvm::StringRef FileName = CI->getSourceManager().getFilename(
369 CI->getSourceManager().getLocForStartOfFile(
370 CI->getSourceManager().getMainFileID()));
371
372 QuerySymbolInfos.push_back(
373 {Query.str(), std::string(ScopedQualifiers), Range});
374
375 // Query the symbol based on C++ name Lookup rules.
376 // Firstly, lookup the identifier with scoped namespace contexts;
377 // If that fails, falls back to look up the identifier directly.
378 //
379 // For example:
380 //
381 // namespace a {
382 // b::foo f;
383 // }
384 //
385 // 1. lookup a::b::foo.
386 // 2. lookup b::foo.
387 std::string QueryString = ScopedQualifiers.str() + Query.str();
388 // It's unsafe to do nested search for the identifier with scoped namespace
389 // context, it might treat the identifier as a nested class of the scoped
390 // namespace.
391 std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
392 SymbolIndexMgr.search(QueryString, /*IsNestedSearch=*/false, FileName);
393 if (MatchedSymbols.empty())
394 MatchedSymbols =
395 SymbolIndexMgr.search(Query, /*IsNestedSearch=*/true, FileName);
396 LLVM_DEBUG(llvm::dbgs() << "Having found " << MatchedSymbols.size()
397 << " symbols\n");
398 // We store a copy of MatchedSymbols in a place where it's globally reachable.
399 // This is used by the standalone version of the tool.
400 this->MatchedSymbols = MatchedSymbols;
401 return MatchedSymbols;
402}
403
404llvm::Expected<tooling::Replacements> createIncludeFixerReplacements(
405 StringRef Code, const IncludeFixerContext &Context,
406 const clang::format::FormatStyle &Style, bool AddQualifiers) {
407 if (Context.getHeaderInfos().empty())
408 return tooling::Replacements();
409 StringRef FilePath = Context.getFilePath();
410 std::string IncludeName =
411 "#include " + Context.getHeaderInfos().front().Header + "\n";
412 // Create replacements for the new header.
413 clang::tooling::Replacements Insertions;
414 auto Err =
415 Insertions.add(tooling::Replacement(FilePath, UINT_MAX, 0, IncludeName));
416 if (Err)
417 return std::move(Err);
418
419 auto CleanReplaces = cleanupAroundReplacements(Code, Insertions, Style);
420 if (!CleanReplaces)
421 return CleanReplaces;
422
423 auto Replaces = std::move(*CleanReplaces);
424 if (AddQualifiers) {
425 for (const auto &Info : Context.getQuerySymbolInfos()) {
426 // Ignore the empty range.
427 if (Info.Range.getLength() > 0) {
428 auto R = tooling::Replacement(
429 {FilePath, Info.Range.getOffset(), Info.Range.getLength(),
430 Context.getHeaderInfos().front().QualifiedName});
431 auto Err = Replaces.add(R);
432 if (Err) {
433 llvm::consumeError(std::move(Err));
434 R = tooling::Replacement(
435 R.getFilePath(), Replaces.getShiftedCodePosition(R.getOffset()),
436 R.getLength(), R.getReplacementText());
437 Replaces = Replaces.merge(tooling::Replacements(R));
438 }
439 }
440 }
441 }
442 return formatReplacements(Code, Replaces, Style);
443}
444
445} // namespace include_fixer
446} // namespace clang
447