1#include "duckdb/function/scalar/string_functions.hpp"
2
3#include "duckdb/common/exception.hpp"
4#include "duckdb/common/vector_operations/vector_operations.hpp"
5#include "duckdb/planner/expression/bound_function_expression.hpp"
6
7namespace duckdb {
8
9template <class UNSIGNED, int NEEDLE_SIZE>
10static idx_t ContainsUnaligned(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle,
11 idx_t base_offset) {
12 if (NEEDLE_SIZE > haystack_size) {
13 // needle is bigger than haystack: haystack cannot contain needle
14 return DConstants::INVALID_INDEX;
15 }
16 // contains for a small unaligned needle (3/5/6/7 bytes)
17 // we perform unsigned integer comparisons to check for equality of the entire needle in a single comparison
18 // this implementation is inspired by the memmem implementation of freebsd
19
20 // first we set up the needle and the first NEEDLE_SIZE characters of the haystack as UNSIGNED integers
21 UNSIGNED needle_entry = 0;
22 UNSIGNED haystack_entry = 0;
23 const UNSIGNED start = (sizeof(UNSIGNED) * 8) - 8;
24 const UNSIGNED shift = (sizeof(UNSIGNED) - NEEDLE_SIZE) * 8;
25 for (int i = 0; i < NEEDLE_SIZE; i++) {
26 needle_entry |= UNSIGNED(needle[i]) << UNSIGNED(start - i * 8);
27 haystack_entry |= UNSIGNED(haystack[i]) << UNSIGNED(start - i * 8);
28 }
29 // now we perform the actual search
30 for (idx_t offset = NEEDLE_SIZE; offset < haystack_size; offset++) {
31 // for this position we first compare the haystack with the needle
32 if (haystack_entry == needle_entry) {
33 return base_offset + offset - NEEDLE_SIZE;
34 }
35 // now we adjust the haystack entry by
36 // (1) removing the left-most character (shift by 8)
37 // (2) adding the next character (bitwise or, with potential shift)
38 // this shift is only necessary if the needle size is not aligned with the unsigned integer size
39 // (e.g. needle size 3, unsigned integer size 4, we need to shift by 1)
40 haystack_entry = (haystack_entry << 8) | ((UNSIGNED(haystack[offset])) << shift);
41 }
42 if (haystack_entry == needle_entry) {
43 return base_offset + haystack_size - NEEDLE_SIZE;
44 }
45 return DConstants::INVALID_INDEX;
46}
47
48template <class UNSIGNED>
49static idx_t ContainsAligned(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle,
50 idx_t base_offset) {
51 if (sizeof(UNSIGNED) > haystack_size) {
52 // needle is bigger than haystack: haystack cannot contain needle
53 return DConstants::INVALID_INDEX;
54 }
55 // contains for a small needle aligned with unsigned integer (2/4/8)
56 // similar to ContainsUnaligned, but simpler because we only need to do a reinterpret cast
57 auto needle_entry = Load<UNSIGNED>(needle);
58 for (idx_t offset = 0; offset <= haystack_size - sizeof(UNSIGNED); offset++) {
59 // for this position we first compare the haystack with the needle
60 auto haystack_entry = Load<UNSIGNED>(haystack + offset);
61 if (needle_entry == haystack_entry) {
62 return base_offset + offset;
63 }
64 }
65 return DConstants::INVALID_INDEX;
66}
67
68idx_t ContainsGeneric(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle,
69 idx_t needle_size, idx_t base_offset) {
70 if (needle_size > haystack_size) {
71 // needle is bigger than haystack: haystack cannot contain needle
72 return DConstants::INVALID_INDEX;
73 }
74 // this implementation is inspired by Raphael Javaux's faststrstr (https://github.com/RaphaelJ/fast_strstr)
75 // generic contains; note that we can't use strstr because we don't have null-terminated strings anymore
76 // we keep track of a shifting window sum of all characters with window size equal to needle_size
77 // this shifting sum is used to avoid calling into memcmp;
78 // we only need to call into memcmp when the window sum is equal to the needle sum
79 // when that happens, the characters are potentially the same and we call into memcmp to check if they are
80 uint32_t sums_diff = 0;
81 for (idx_t i = 0; i < needle_size; i++) {
82 sums_diff += haystack[i];
83 sums_diff -= needle[i];
84 }
85 idx_t offset = 0;
86 while (true) {
87 if (sums_diff == 0 && haystack[offset] == needle[0]) {
88 if (memcmp(s1: haystack + offset, s2: needle, n: needle_size) == 0) {
89 return base_offset + offset;
90 }
91 }
92 if (offset >= haystack_size - needle_size) {
93 return DConstants::INVALID_INDEX;
94 }
95 sums_diff -= haystack[offset];
96 sums_diff += haystack[offset + needle_size];
97 offset++;
98 }
99}
100
101idx_t ContainsFun::Find(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle,
102 idx_t needle_size) {
103 D_ASSERT(needle_size > 0);
104 // start off by performing a memchr to find the first character of the
105 auto location = memchr(s: haystack, c: needle[0], n: haystack_size);
106 if (location == nullptr) {
107 return DConstants::INVALID_INDEX;
108 }
109 idx_t base_offset = const_uchar_ptr_cast(src: location) - haystack;
110 haystack_size -= base_offset;
111 haystack = const_uchar_ptr_cast(src: location);
112 // switch algorithm depending on needle size
113 switch (needle_size) {
114 case 1:
115 return base_offset;
116 case 2:
117 return ContainsAligned<uint16_t>(haystack, haystack_size, needle, base_offset);
118 case 3:
119 return ContainsUnaligned<uint32_t, 3>(haystack, haystack_size, needle, base_offset);
120 case 4:
121 return ContainsAligned<uint32_t>(haystack, haystack_size, needle, base_offset);
122 case 5:
123 return ContainsUnaligned<uint64_t, 5>(haystack, haystack_size, needle, base_offset);
124 case 6:
125 return ContainsUnaligned<uint64_t, 6>(haystack, haystack_size, needle, base_offset);
126 case 7:
127 return ContainsUnaligned<uint64_t, 7>(haystack, haystack_size, needle, base_offset);
128 case 8:
129 return ContainsAligned<uint64_t>(haystack, haystack_size, needle, base_offset);
130 default:
131 return ContainsGeneric(haystack, haystack_size, needle, needle_size, base_offset);
132 }
133}
134
135idx_t ContainsFun::Find(const string_t &haystack_s, const string_t &needle_s) {
136 auto haystack = const_uchar_ptr_cast(src: haystack_s.GetData());
137 auto haystack_size = haystack_s.GetSize();
138 auto needle = const_uchar_ptr_cast(src: needle_s.GetData());
139 auto needle_size = needle_s.GetSize();
140 if (needle_size == 0) {
141 // empty needle: always true
142 return 0;
143 }
144 return ContainsFun::Find(haystack, haystack_size, needle, needle_size);
145}
146
147struct ContainsOperator {
148 template <class TA, class TB, class TR>
149 static inline TR Operation(TA left, TB right) {
150 return ContainsFun::Find(left, right) != DConstants::INVALID_INDEX;
151 }
152};
153
154ScalarFunction ContainsFun::GetFunction() {
155 return ScalarFunction("contains", // name of the function
156 {LogicalType::VARCHAR, LogicalType::VARCHAR}, // argument list
157 LogicalType::BOOLEAN, // return type
158 ScalarFunction::BinaryFunction<string_t, string_t, bool, ContainsOperator>);
159}
160
161void ContainsFun::RegisterFunction(BuiltinFunctions &set) {
162 set.AddFunction(function: GetFunction());
163}
164
165} // namespace duckdb
166