/*
 * Copyright (c) 2015-2017, Intel Corporation
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *  * Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *  * Neither the name of Intel Corporation nor the names of its contributors
 *    may be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */

/**
 * \file
 * \brief Small-write engine build code.
 */

#include "smallwrite/smallwrite_build.h"

#include "grey.h"
#include "ue2common.h"
#include "compiler/compiler.h"
#include "nfa/dfa_min.h"
#include "nfa/mcclellancompile.h"
#include "nfa/mcclellancompile_util.h"
#include "nfa/nfa_internal.h"
#include "nfa/rdfa_merge.h"
#include "nfa/shengcompile.h"
#include "nfagraph/ng.h"
#include "nfagraph/ng_depth.h"
#include "nfagraph/ng_holder.h"
#include "nfagraph/ng_mcclellan.h"
#include "nfagraph/ng_prune.h"
#include "nfagraph/ng_reports.h"
#include "nfagraph/ng_util.h"
#include "smallwrite/smallwrite_internal.h"
#include "util/alloc.h"
#include "util/bytecode_ptr.h"
#include "util/charreach.h"
#include "util/compare.h"
#include "util/compile_context.h"
#include "util/container.h"
#include "util/make_unique.h"
#include "util/ue2_graph.h"
#include "util/ue2string.h"
#include "util/verify_types.h"

#include <array>
#include <deque>
#include <map>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

#include <boost/graph/breadth_first_search.hpp>

using namespace std;

namespace ue2 {

#define DFA_MERGE_MAX_STATES 8000
#define MAX_TRIE_VERTICES 8000

struct LitTrieVertexProps {
    LitTrieVertexProps() = default;
    explicit LitTrieVertexProps(u8 c_in) : c(c_in) {}
    size_t index; // managed by ue2_graph
    u8 c = 0; //!< character reached on this vertex
    flat_set<ReportID> reports; //!< managed reports fired on this vertex
};

struct LitTrieEdgeProps {
    size_t index; // managed by ue2_graph
};

/**
 * \brief BGL graph used to store a trie of literals (for later AC construction
 * into a DFA).
 */
struct LitTrie
    : public ue2_graph<LitTrie, LitTrieVertexProps, LitTrieEdgeProps> {

    LitTrie() : root(add_vertex(*this)) {}

    const vertex_descriptor root; //!< Root vertex for the trie.
};
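
/* Illustrative sketch only (not used by the build): literals that share a
 * prefix share trie vertices. Using add_to_trie() (defined below), adding
 * "he" and then "hers" gives:
 *
 * \code
 *     LitTrie trie;
 *     add_to_trie(ue2_literal("he", false), 0, trie);   // root -> h -> e
 *     add_to_trie(ue2_literal("hers", false), 1, trie); // reuses h and e
 *     assert(num_vertices(trie) == 5); // root + {h, e, r, s}
 * \endcode
 */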

static
bool is_empty(const LitTrie &trie) {
    return num_vertices(trie) <= 1;
}

static
std::set<ReportID> all_reports(const LitTrie &trie) {
    std::set<ReportID> reports;
    for (auto v : vertices_range(trie)) {
        insert(&reports, trie[v].reports);
    }
    return reports;
}

using LitTrieVertex = LitTrie::vertex_descriptor;
using LitTrieEdge = LitTrie::edge_descriptor;

namespace { // unnamed

// Concrete impl class
class SmallWriteBuildImpl : public SmallWriteBuild {
public:
    SmallWriteBuildImpl(size_t num_patterns, const ReportManager &rm,
                        const CompileContext &cc);

    // Construct a runtime implementation.
    bytecode_ptr<SmallWriteEngine> build(u32 roseQuality) override;

    void add(const NGHolder &g, const ExpressionInfo &expr) override;
    void add(const ue2_literal &literal, ReportID r) override;

    set<ReportID> all_reports() const override;

    const ReportManager &rm;
    const CompileContext &cc;

    vector<unique_ptr<raw_dfa>> dfas;
    LitTrie lit_trie;
    LitTrie lit_trie_nocase;
    size_t num_literals = 0;
    bool poisoned;
};

} // namespace

SmallWriteBuild::~SmallWriteBuild() = default;

SmallWriteBuildImpl::SmallWriteBuildImpl(size_t num_patterns,
                                         const ReportManager &rm_in,
                                         const CompileContext &cc_in)
    : rm(rm_in), cc(cc_in),
      /* small write is block mode only */
      poisoned(!cc.grey.allowSmallWrite
               || cc.streaming
               || num_patterns > cc.grey.smallWriteMaxPatterns) {
}

/**
 * \brief Remove any reports from the given vertex that cannot match within
 * max_depth due to their constraints.
 */
static
bool pruneOverlongReports(NFAVertex v, NGHolder &g, const depth &max_depth,
                          const ReportManager &rm) {
    assert(!g[v].reports.empty());

    vector<ReportID> bad_reports;

    for (ReportID id : g[v].reports) {
        const auto &report = rm.getReport(id);
        if (report.minOffset > max_depth) {
            bad_reports.push_back(id);
        }
    }

    for (ReportID id : bad_reports) {
        g[v].reports.erase(id);
    }

    if (g[v].reports.empty()) {
        DEBUG_PRINTF("none of vertex %zu's reports can match, cut accepts\n",
                     g[v].index);
        remove_edge(v, g.accept, g);
        remove_edge(v, g.acceptEod, g);
    }

    return !bad_reports.empty();
}

/**
 * \brief Prune vertices and reports from the graph that cannot match within
 * max_depth.
 */
static
bool pruneOverlong(NGHolder &g, const depth &max_depth,
                   const ReportManager &rm) {
    bool modified = false;
    auto depths = calcBidiDepths(g);

    for (auto v : vertices_range(g)) {
        if (is_special(v, g)) {
            continue;
        }
        const auto &d = depths.at(g[v].index);
        depth min_match_offset = min(d.fromStart.min, d.fromStartDotStar.min)
                                 + min(d.toAccept.min, d.toAcceptEod.min);
        if (min_match_offset > max_depth) {
            clear_vertex(v, g);
            modified = true;
            continue;
        }

        if (is_match_vertex(v, g)) {
            modified |= pruneOverlongReports(v, g, max_depth, rm);
        }
    }

    if (modified) {
        pruneUseless(g);
        DEBUG_PRINTF("pruned graph down to %zu vertices\n", num_vertices(g));
    }

    return modified;
}
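
/* Worked example (sketch): for a pattern like /abc.{40}def/, every match
 * ends at offset 46 or later, so with max_depth = 32 each non-special
 * vertex has min_match_offset 46 > 32. All of them are cleared, and
 * pruneUseless() then removes what remains of the graph. */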

/**
 * \brief Attempt to merge the given set of DFAs into a single raw_dfa.
 * Returns false on failure.
 */
static
bool mergeDfas(vector<unique_ptr<raw_dfa>> &dfas, const ReportManager &rm,
               const CompileContext &cc) {
    assert(!dfas.empty());

    if (dfas.size() == 1) {
        return true;
    }

    DEBUG_PRINTF("attempting to merge %zu DFAs\n", dfas.size());

    vector<const raw_dfa *> dfa_ptrs;
    dfa_ptrs.reserve(dfas.size());
    for (auto &d : dfas) {
        dfa_ptrs.push_back(d.get());
    }

    auto merged = mergeAllDfas(dfa_ptrs, DFA_MERGE_MAX_STATES, &rm, cc.grey);
    if (!merged) {
        DEBUG_PRINTF("merge failed\n");
        return false;
    }

    DEBUG_PRINTF("merge succeeded, result has %zu states\n",
                 merged->states.size());
    dfas.clear();
    dfas.push_back(std::move(merged));
    return true;
}

void SmallWriteBuildImpl::add(const NGHolder &g, const ExpressionInfo &expr) {
    // If the build is poisoned (i.e. we can't build a SmallWrite version),
    // we don't even try.
    if (poisoned) {
        return;
    }

    if (expr.som) {
        DEBUG_PRINTF("no SOM support in small-write engine\n");
        poisoned = true;
        return;
    }

    if (isVacuous(g)) {
        DEBUG_PRINTF("no vacuous graph support in small-write engine\n");
        poisoned = true;
        return;
    }

    if (any_of_in(::ue2::all_reports(g), [&](ReportID id) {
            return rm.getReport(id).minLength > 0;
        })) {
        DEBUG_PRINTF("no min_length extparam support in small-write engine\n");
        poisoned = true;
        return;
    }

    DEBUG_PRINTF("g=%p\n", &g);

    // Make a copy of the graph so that we can modify it for our purposes.
    unique_ptr<NGHolder> h = cloneHolder(g);

    pruneOverlong(*h, depth(cc.grey.smallWriteLargestBuffer), rm);

    reduceGraph(*h, SOM_NONE, expr.utf8, cc);

    if (can_never_match(*h)) {
        DEBUG_PRINTF("graph can never match in small block\n");
        return;
    }

    // Now we can actually build the McClellan DFA.
    assert(h->kind == NFA_OUTFIX);
    auto r = buildMcClellan(*h, &rm, cc.grey);

    // If we couldn't build a McClellan DFA for this portion, we won't be able
    // to build a smwr which represents the pattern set.
    if (!r) {
        DEBUG_PRINTF("failed to determinise\n");
        poisoned = true;
        return;
    }

    if (clear_deeper_reports(*r, cc.grey.smallWriteLargestBuffer)) {
        minimize_hopcroft(*r, cc.grey);
    }

    dfas.push_back(std::move(r));

    if (dfas.size() >= cc.grey.smallWriteMergeBatchSize) {
        if (!mergeDfas(dfas, rm, cc)) {
            dfas.clear();
            poisoned = true;
            return;
        }
    }
}

static
bool add_to_trie(const ue2_literal &literal, ReportID report, LitTrie &trie) {
    auto u = trie.root;
    for (const auto &c : literal) {
        auto next = LitTrie::null_vertex();
        for (auto v : adjacent_vertices_range(u, trie)) {
            if (trie[v].c == (u8)c.c) {
                next = v;
                break;
            }
        }
        if (!next) {
            next = add_vertex(LitTrieVertexProps((u8)c.c), trie);
            add_edge(u, next, trie);
        }
        u = next;
    }

    trie[u].reports.insert(report);

    DEBUG_PRINTF("added '%s' (report %u) to trie, now %zu vertices\n",
                 escapeString(literal).c_str(), report, num_vertices(trie));
    return num_vertices(trie) <= MAX_TRIE_VERTICES;
}

void SmallWriteBuildImpl::add(const ue2_literal &literal, ReportID r) {
    // If the build is poisoned (i.e. we can't build a SmallWrite version),
    // we don't even try.
    if (poisoned) {
        DEBUG_PRINTF("poisoned\n");
        return;
    }

    if (literal.length() > cc.grey.smallWriteLargestBuffer) {
        DEBUG_PRINTF("exceeded length limit\n");
        return; /* too long */
    }

    if (++num_literals > cc.grey.smallWriteMaxLiterals) {
        DEBUG_PRINTF("exceeded literal limit\n");
        poisoned = true;
        return;
    }

    auto &trie = literal.any_nocase() ? lit_trie_nocase : lit_trie;
    if (!add_to_trie(literal, r, trie)) {
        DEBUG_PRINTF("trie add failed\n");
        poisoned = true;
    }
}

namespace {

/**
 * \brief BFS visitor for Aho-Corasick automaton construction.
 *
 * This is doing two things:
 *
 * - Computing the failure edges (also called fall or supply edges) for each
 *   vertex: the failure target is the vertex for the longest proper suffix
 *   of the path to that point that is also a path from the root of the trie.
 *   The BFS traversal makes it possible to build these from earlier failure
 *   paths.
 *
 * - Computing the output function for each vertex, which is done by
 *   propagating the reports from failure paths as well. This ensures that
 *   substrings of the current path also report correctly.
 */
struct ACVisitor : public boost::default_bfs_visitor {
    ACVisitor(LitTrie &trie_in,
              unordered_map<LitTrieVertex, LitTrieVertex> &failure_map_in,
              vector<LitTrieVertex> &ordering_in)
        : mutable_trie(trie_in), failure_map(failure_map_in),
          ordering(ordering_in) {}

    LitTrieVertex find_failure_target(LitTrieVertex u, LitTrieVertex v,
                                      const LitTrie &trie) {
        assert(u == trie.root || contains(failure_map, u));
        assert(!contains(failure_map, v));

        const auto &c = trie[v].c;

        while (u != trie.root) {
            auto f = failure_map.at(u);
            for (auto w : adjacent_vertices_range(f, trie)) {
                if (trie[w].c == c) {
                    return w;
                }
            }
            u = f;
        }

        DEBUG_PRINTF("no failure edge\n");
        return LitTrie::null_vertex();
    }

    void tree_edge(LitTrieEdge e, const LitTrie &trie) {
        auto u = source(e, trie);
        auto v = target(e, trie);
        DEBUG_PRINTF("bfs (%zu, %zu) on '%c'\n", trie[u].index, trie[v].index,
                     trie[v].c);
        ordering.push_back(v);

        auto f = find_failure_target(u, v, trie);

        if (f) {
            DEBUG_PRINTF("final failure vertex %zu\n", trie[f].index);
            failure_map.emplace(v, f);

            // Propagate reports from the failure path to ensure that we
            // correctly report substrings.
            insert(&mutable_trie[v].reports, mutable_trie[f].reports);
        } else {
            DEBUG_PRINTF("final failure vertex root\n");
            failure_map.emplace(v, trie.root);
        }
    }

private:
    LitTrie &mutable_trie; //!< For setting reports property.
    unordered_map<LitTrieVertex, LitTrieVertex> &failure_map;
    vector<LitTrieVertex> &ordering; //!< BFS ordering for vertices.
};

} // namespace
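
/* Worked example (sketch): for the literals "he" and "she", the BFS visits
 * the vertex for "sh" and finds failure target "h" (the longest proper
 * suffix of "sh" that is also a path from the root). The vertex for "she"
 * then gets failure target "he", and because "he" carries a report, that
 * report is copied onto "she": an input containing "she" must also report
 * the substring "he". */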

static UNUSED
bool isSaneTrie(const LitTrie &trie) {
    CharReach seen;
    for (auto u : vertices_range(trie)) {
        seen.clear();
        for (auto v : adjacent_vertices_range(u, trie)) {
            if (seen.test(trie[v].c)) {
                return false;
            }
            seen.set(trie[v].c);
        }
    }
    return true;
}

/**
 * \brief Turn the given literal trie into an AC automaton by adding
 * additional edges and reports.
 */
static
void buildAutomaton(LitTrie &trie,
                    unordered_map<LitTrieVertex, LitTrieVertex> &failure_map,
                    vector<LitTrieVertex> &ordering) {
    assert(isSaneTrie(trie));

    // Find our failure transitions and reports.
    failure_map.reserve(num_vertices(trie));
    ordering.reserve(num_vertices(trie));
    ACVisitor ac_vis(trie, failure_map, ordering);
    boost::breadth_first_search(trie, trie.root, visitor(ac_vis));

    // Compute missing edges from failure map.
    for (auto v : ordering) {
        DEBUG_PRINTF("vertex %zu\n", trie[v].index);
        CharReach seen;
        for (auto w : adjacent_vertices_range(v, trie)) {
            DEBUG_PRINTF("edge to %zu with reach 0x%02x\n", trie[w].index,
                         trie[w].c);
            assert(!seen.test(trie[w].c));
            seen.set(trie[w].c);
        }
        auto parent = failure_map.at(v);
        for (auto w : adjacent_vertices_range(parent, trie)) {
            if (!seen.test(trie[w].c)) {
                add_edge(v, w, trie);
            }
        }
    }
}
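
/* Sketch of the edge-completion step above: if the vertex for "she" has
 * failure target "he", and "he" has an outgoing edge on 'r' (towards
 * "her"), then "she" gains an edge on 'r' to that same vertex unless it
 * already has one of its own. Because vertices are processed in BFS order,
 * each failure target has already been completed, so copying one level is
 * enough to capture the whole failure chain. */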

static
vector<u32> findDistFromRoot(const LitTrie &trie) {
    vector<u32> dist(num_vertices(trie), UINT32_MAX);
    dist[trie[trie.root].index] = 0;

    // BFS to find dist from root.
    breadth_first_search(
        trie, trie.root,
        visitor(make_bfs_visitor(record_distances(
            make_iterator_property_map(dist.begin(),
                                       get(&LitTrieVertexProps::index, trie)),
            boost::on_tree_edge()))));

    return dist;
}

static
vector<u32> findDistToAccept(const LitTrie &trie) {
    vector<u32> dist(num_vertices(trie), UINT32_MAX);

    // Start with all reporting vertices.
    deque<LitTrieVertex> q;
    for (auto v : vertices_range(trie)) {
        if (!trie[v].reports.empty()) {
            q.push_back(v);
            dist[trie[v].index] = 0;
        }
    }

    // Custom BFS, since we have a pile of sources.
    while (!q.empty()) {
        auto v = q.front();
        q.pop_front();
        u32 d = dist[trie[v].index];

        for (auto u : inv_adjacent_vertices_range(v, trie)) {
            auto &u_dist = dist[trie[u].index];
            if (u_dist == UINT32_MAX) {
                q.push_back(u);
                u_dist = d + 1;
            }
        }
    }

    return dist;
}

/**
 * \brief Prune all vertices from the trie that do not lie on a path from
 * root to accept of length <= max_depth.
 */
static
void pruneTrie(LitTrie &trie, u32 max_depth) {
    DEBUG_PRINTF("pruning trie to %u\n", max_depth);

    auto dist_from_root = findDistFromRoot(trie);
    auto dist_to_accept = findDistToAccept(trie);

    vector<LitTrieVertex> dead;
    for (auto v : vertices_range(trie)) {
        if (v == trie.root) {
            continue;
        }
        auto v_index = trie[v].index;
        DEBUG_PRINTF("vertex %zu: from_start=%u, to_accept=%u\n",
                     trie[v].index, dist_from_root[v_index],
                     dist_to_accept[v_index]);
        assert(dist_from_root[v_index] != UINT32_MAX);
        assert(dist_to_accept[v_index] != UINT32_MAX);
        u32 min_path_len = dist_from_root[v_index] + dist_to_accept[v_index];
        if (min_path_len > max_depth) {
            DEBUG_PRINTF("pruning vertex %zu (min path len %u)\n",
                         trie[v].index, min_path_len);
            clear_vertex(v, trie);
            dead.push_back(v);
        }
    }

    if (dead.empty()) {
        return;
    }

    for (auto v : dead) {
        remove_vertex(v, trie);
    }

    DEBUG_PRINTF("%zu vertices remain\n", num_vertices(trie));

    renumber_edges(trie);
    renumber_vertices(trie);
}
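
/* Worked example (sketch): take literals "hers" and "she" with
 * max_depth = 3. The vertex for "her" has dist_from_root = 3 and
 * dist_to_accept = 1 (the only report below it is on "hers"), giving a
 * minimum path length of 4 > 3, so the entire "hers" chain is pruned.
 * Every vertex of "she" has a minimum path length of exactly 3 and
 * survives. */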

static
vector<CharReach> getAlphabet(const LitTrie &trie, bool nocase) {
    vector<CharReach> esets = {CharReach::dot()};
    for (auto v : vertices_range(trie)) {
        if (v == trie.root) {
            continue;
        }

        CharReach cr;
        if (nocase) {
            cr.set(mytoupper(trie[v].c));
            cr.set(mytolower(trie[v].c));
        } else {
            cr.set(trie[v].c);
        }

        for (size_t i = 0; i < esets.size(); i++) {
            if (esets[i].count() == 1) {
                continue;
            }

            CharReach t = cr & esets[i];
            if (t.any() && t != esets[i]) {
                esets[i] &= ~t;
                esets.push_back(t);
            }
        }
    }

    // For deterministic compiles.
    sort(esets.begin(), esets.end());
    return esets;
}
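
/* Sketch of the partition refinement above: we start with one class
 * containing all characters, and each character used in the trie splits any
 * class it partially overlaps. For a caseful trie over "ab", the classes
 * end up as {a}, {b} and the 254-character remainder, so the DFA only needs
 * three input symbols (plus the special TOP symbol) instead of 256. */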

static
u16 buildAlphabet(const LitTrie &trie, bool nocase,
                  array<u16, ALPHABET_SIZE> &alpha,
                  array<u16, ALPHABET_SIZE> &unalpha) {
    const auto &esets = getAlphabet(trie, nocase);

    u16 i = 0;
    for (const auto &cr : esets) {
        u16 leader = verify_u16(cr.find_first());
        for (size_t s = cr.find_first(); s != cr.npos; s = cr.find_next(s)) {
            alpha[s] = i;
        }
        unalpha[i] = leader;
        i++;
    }

    for (u16 j = N_CHARS; j < ALPHABET_SIZE; j++, i++) {
        alpha[j] = i;
        unalpha[i] = j;
    }

    DEBUG_PRINTF("alphabet size %u\n", i);
    return i;
}
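
/* Continuing the "ab" sketch: alpha[] sends 'a' and 'b' to their own class
 * indices and every other byte to the remainder class, while unalpha[] maps
 * each class back to a representative byte (its lowest member). TOP is
 * appended as the final class, so buildAlphabet() returns an implementation
 * alphabet size of 4. */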

/**
 * \brief Calculate state mapping, from vertex in trie to state index in BFS
 * ordering.
 */
static
unordered_map<LitTrieVertex, u32>
makeStateMap(const LitTrie &trie, const vector<LitTrieVertex> &ordering) {
    unordered_map<LitTrieVertex, u32> state_ids;
    state_ids.reserve(num_vertices(trie));
    u32 idx = DEAD_STATE + 1;
    state_ids.emplace(trie.root, idx++);
    for (auto v : ordering) {
        state_ids.emplace(v, idx++);
    }
    assert(state_ids.size() == num_vertices(trie));
    return state_ids;
}

/** \brief Construct a raw_dfa from a literal trie. */
static
unique_ptr<raw_dfa> buildDfa(LitTrie &trie, bool nocase) {
    DEBUG_PRINTF("trie has %zu states\n", num_vertices(trie));

    vector<LitTrieVertex> ordering;
    unordered_map<LitTrieVertex, LitTrieVertex> failure_map;
    buildAutomaton(trie, failure_map, ordering);

    // Construct DFA states in BFS order.
    const auto state_ids = makeStateMap(trie, ordering);

    auto rdfa = make_unique<raw_dfa>(NFA_OUTFIX);

    // Calculate alphabet.
    array<u16, ALPHABET_SIZE> unalpha;
    auto &alpha = rdfa->alpha_remap;
    rdfa->alpha_size = buildAlphabet(trie, nocase, alpha, unalpha);

    // Construct states and transitions.
    const u16 root_state = state_ids.at(trie.root);
    assert(root_state == DEAD_STATE + 1);
    rdfa->start_anchored = root_state;
    rdfa->start_floating = root_state;
    rdfa->states.resize(num_vertices(trie) + 1, dstate(rdfa->alpha_size));

    // Dead state.
    fill(rdfa->states[DEAD_STATE].next.begin(),
         rdfa->states[DEAD_STATE].next.end(), DEAD_STATE);

    for (auto u : vertices_range(trie)) {
        auto u_state = state_ids.at(u);
        DEBUG_PRINTF("state %u\n", u_state);
        assert(u_state < rdfa->states.size());
        auto &ds = rdfa->states[u_state];
        ds.reports = trie[u].reports;
        if (!ds.reports.empty()) {
            DEBUG_PRINTF("reports: %s\n", as_string_list(ds.reports).c_str());
        }

        // Set daddy state from failure map.
        if (u == trie.root) {
            ds.daddy = DEAD_STATE;
        } else {
            assert(contains(failure_map, u));
            ds.daddy = state_ids.at(failure_map.at(u));
        }

        // By default, transition back to the root.
        fill(ds.next.begin(), ds.next.end(), root_state);
        // TOP should be a self-loop.
        ds.next[alpha[TOP]] = u_state;

        // Add in the real transitions.
        for (auto v : adjacent_vertices_range(u, trie)) {
            if (v == trie.root) {
                continue;
            }
            auto v_state = state_ids.at(v);
            u16 sym = alpha[trie[v].c];
            DEBUG_PRINTF("edge to %u on 0x%02x (sym %u)\n", v_state,
                         trie[v].c, sym);
            assert(sym < ds.next.size());
            assert(ds.next[sym] == root_state);
            ds.next[sym] = v_state;
        }
    }

    return rdfa;
}
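
/* Note on the construction above: buildAutomaton() has already copied each
 * vertex's missing goto edges from its failure target, so defaulting the
 * remaining transitions to root_state yields a complete DFA with no
 * failure-edge chasing at runtime. The TOP symbol (the start-of-data event)
 * is made a self-loop so that it never changes state. */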

#define MAX_GOOD_ACCEL_DEPTH 4

static
bool is_slow(const raw_dfa &rdfa, const set<dstate_id_t> &accel,
             u32 roseQuality) {
    /* We consider a DFA to be slow if there is no way to quickly get into an
     * accel state or the dead state. In these cases, it is more likely that
     * we will be running at our unaccelerated DFA speeds, so the small write
     * engine is only competitive over a small region where start-up costs
     * are dominant. */

    if (roseQuality) {
        return true;
    }

    set<dstate_id_t> visited;
    set<dstate_id_t> next;
    set<dstate_id_t> curr;
    curr.insert(rdfa.start_anchored);

    u32 ialpha_size = rdfa.getImplAlphaSize();

    for (u32 i = 0; i < MAX_GOOD_ACCEL_DEPTH; i++) {
        next.clear();
        for (dstate_id_t s : curr) {
            if (contains(visited, s)) {
                continue;
            }
            visited.insert(s);
            if (s == DEAD_STATE || contains(accel, s)) {
                return false;
            }

            for (size_t j = 0; j < ialpha_size; j++) {
                next.insert(rdfa.states[s].next[j]);
            }
        }
        curr.swap(next);
    }

    return true;
}
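
/* Example of the check above: with MAX_GOOD_ACCEL_DEPTH at 4, a DFA counts
 * as "fast" if some state reachable from the anchored start within four
 * bytes is dead (scanning can stop) or accelerable (scanning can skip
 * ahead). Otherwise we assume the engine runs at plain DFA speed and only
 * pays off over the smaller "bad" buffer region. */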

static
bytecode_ptr<NFA> getDfa(raw_dfa &rdfa, const CompileContext &cc,
                         const ReportManager &rm, bool has_non_literals,
                         set<dstate_id_t> &accel_states) {
    // If we determinised only literals, then we only need to consider the
    // init states for acceleration.
    bool only_accel_init = !has_non_literals;
    bool trust_daddy_states = !has_non_literals;

    bytecode_ptr<NFA> dfa = nullptr;
    if (cc.grey.allowSmallWriteSheng) {
        dfa = shengCompile(rdfa, cc, rm, only_accel_init, &accel_states);
    }
    if (!dfa) {
        dfa = mcclellanCompile(rdfa, cc, rm, only_accel_init,
                               trust_daddy_states, &accel_states);
    }
    return dfa;
}

static
bytecode_ptr<NFA> prepEngine(raw_dfa &rdfa, u32 roseQuality,
                             const CompileContext &cc, const ReportManager &rm,
                             bool has_non_literals, u32 *start_offset,
                             u32 *small_region) {
    *start_offset = remove_leading_dots(rdfa);

    // Unleash the McClellan!
    set<dstate_id_t> accel_states;

    auto nfa = getDfa(rdfa, cc, rm, has_non_literals, accel_states);
    if (!nfa) {
        DEBUG_PRINTF("DFA compile failed for smallwrite NFA\n");
        return nullptr;
    }

    if (is_slow(rdfa, accel_states, roseQuality)) {
        DEBUG_PRINTF("is slow\n");
        *small_region = cc.grey.smallWriteLargestBufferBad;
        if (*small_region <= *start_offset) {
            return nullptr;
        }
        if (clear_deeper_reports(rdfa, *small_region - *start_offset)) {
            minimize_hopcroft(rdfa, cc.grey);
            if (rdfa.start_anchored == DEAD_STATE) {
                DEBUG_PRINTF("all patterns pruned out\n");
                return nullptr;
            }

            nfa = getDfa(rdfa, cc, rm, has_non_literals, accel_states);
            if (!nfa) {
                DEBUG_PRINTF("DFA compile failed for smallwrite NFA\n");
                assert(0); /* able to build orig dfa but not the trimmed? */
                return nullptr;
            }
        }
    } else {
        *small_region = cc.grey.smallWriteLargestBuffer;
    }

    assert(isDfaType(nfa->type));
    if (nfa->length > cc.grey.limitSmallWriteOutfixSize
        || nfa->length > cc.grey.limitDFASize) {
        DEBUG_PRINTF("smallwrite outfix size too large\n");
        return nullptr; /* this is just a soft failure - don't build smwr */
    }

    nfa->queueIndex = 0; /* dummy, small write API does not use queue */
    return nfa;
}

// SmallWriteBuild factory
unique_ptr<SmallWriteBuild> makeSmallWriteBuilder(size_t num_patterns,
                                                  const ReportManager &rm,
                                                  const CompileContext &cc) {
    return ue2::make_unique<SmallWriteBuildImpl>(num_patterns, rm, cc);
}

bytecode_ptr<SmallWriteEngine> SmallWriteBuildImpl::build(u32 roseQuality) {
    const bool has_literals = !is_empty(lit_trie) || !is_empty(lit_trie_nocase);
    const bool has_non_literals = !dfas.empty();
    if (dfas.empty() && !has_literals) {
        DEBUG_PRINTF("no smallwrite engine\n");
        poisoned = true;
        return nullptr;
    }

    if (poisoned) {
        DEBUG_PRINTF("some pattern could not be made into a smallwrite dfa\n");
        return nullptr;
    }

    // We happen to know that if the rose is high quality, we're going to
    // limit depth further.
    if (roseQuality) {
        u32 max_depth = cc.grey.smallWriteLargestBufferBad;
        if (!is_empty(lit_trie)) {
            pruneTrie(lit_trie, max_depth);
        }
        if (!is_empty(lit_trie_nocase)) {
            pruneTrie(lit_trie_nocase, max_depth);
        }
    }

    if (!is_empty(lit_trie)) {
        dfas.push_back(buildDfa(lit_trie, false));
        DEBUG_PRINTF("caseful literal dfa with %zu states\n",
                     dfas.back()->states.size());
    }
    if (!is_empty(lit_trie_nocase)) {
        dfas.push_back(buildDfa(lit_trie_nocase, true));
        DEBUG_PRINTF("nocase literal dfa with %zu states\n",
                     dfas.back()->states.size());
    }

    if (dfas.empty()) {
        DEBUG_PRINTF("no dfa, pruned everything away\n");
        return nullptr;
    }

    if (!mergeDfas(dfas, rm, cc)) {
        dfas.clear();
        return nullptr;
    }

    assert(dfas.size() == 1);
    auto rdfa = std::move(dfas.front());
    dfas.clear();

    DEBUG_PRINTF("building rdfa %p\n", rdfa.get());

    u32 start_offset;
    u32 small_region;
    auto nfa = prepEngine(*rdfa, roseQuality, cc, rm, has_non_literals,
                          &start_offset, &small_region);
    if (!nfa) {
        DEBUG_PRINTF("some smallwrite outfix could not be prepped\n");
        /* just skip the smallwrite optimization */
        poisoned = true;
        return nullptr;
    }

    u32 size = sizeof(SmallWriteEngine) + nfa->length;
    auto smwr = make_zeroed_bytecode_ptr<SmallWriteEngine>(size);

    smwr->size = size;
    smwr->start_offset = start_offset;
    smwr->largestBuffer = small_region;

    /* copy in the NFA after the smwr header */
    assert(ISALIGNED_CL(smwr.get() + 1));
    memcpy(smwr.get() + 1, nfa.get(), nfa->length);

    DEBUG_PRINTF("smallwrite done %p\n", smwr.get());
    return smwr;
}

set<ReportID> SmallWriteBuildImpl::all_reports() const {
    set<ReportID> reports;
    if (poisoned) {
        return reports;
    }

    for (const auto &rdfa : dfas) {
        insert(&reports, ::ue2::all_reports(*rdfa));
    }

    insert(&reports, ::ue2::all_reports(lit_trie));
    insert(&reports, ::ue2::all_reports(lit_trie_nocase));

    return reports;
}

} // namespace ue2