//
// immer: immutable data structures for C++
// Copyright (C) 2016, 2017, 2018 Juan Pedro Bolivar Puente
//
// This software is distributed under the Boost Software License, Version 1.0.
// See accompanying file LICENSE or copy at http://boost.org/LICENSE_1_0.txt
//
8
9#pragma once
10
11#include <immer/config.hpp>
12#include <immer/detail/hamts/node.hpp>
13
14#include <algorithm>
15
16namespace immer {
17namespace detail {
18namespace hamts {
19
20template <typename T,
21 typename Hash,
22 typename Equal,
23 typename MemoryPolicy,
24 bits_t B>
25struct champ
26{
27 static constexpr auto bits = B;
28
29 using node_t = node<T, Hash, Equal, MemoryPolicy, B>;
30 using bitmap_t = typename get_bitmap_type<B>::type;
31
32 static_assert(branches<B> <= sizeof(bitmap_t) * 8, "");
33
34 node_t* root;
35 size_t size;
36
37 static const champ& empty()
38 {
39 static const champ empty_ {
40 node_t::make_inner_n(0),
41 0,
42 };
43 return empty_;
44 }
45
46 champ(node_t* r, size_t sz)
47 : root{r}, size{sz}
48 {
49 }
50
51 champ(const champ& other)
52 : champ{other.root, other.size}
53 {
54 inc();
55 }
56
57 champ(champ&& other)
58 : champ{empty()}
59 {
60 swap(*this, other);
61 }
62
63 champ& operator=(const champ& other)
64 {
65 auto next = other;
66 swap(*this, next);
67 return *this;
68 }
69
70 champ& operator=(champ&& other)
71 {
72 swap(*this, other);
73 return *this;
74 }
75
76 friend void swap(champ& x, champ& y)
77 {
78 using std::swap;
79 swap(x.root, y.root);
80 swap(x.size, y.size);
81 }
82
83 ~champ()
84 {
85 dec();
86 }
87
88 void inc() const
89 {
90 root->inc();
91 }
92
93 void dec() const
94 {
95 if (root->dec())
96 node_t::delete_deep(root, 0);
97 }
98
99 template <typename Fn>
100 void for_each_chunk(Fn&& fn) const
101 {
102 for_each_chunk_traversal(root, 0, fn);
103 }
104
105 template <typename Fn>
106 void for_each_chunk_traversal(node_t* node, count_t depth, Fn&& fn) const
107 {
108 if (depth < max_depth<B>) {
109 auto datamap = node->datamap();
110 if (datamap)
111 fn(node->values(), node->values() + popcount(datamap));
112 auto nodemap = node->nodemap();
113 if (nodemap) {
114 auto fst = node->children();
115 auto lst = fst + popcount(nodemap);
116 for (; fst != lst; ++fst)
117 for_each_chunk_traversal(*fst, depth + 1, fn);
118 }
119 } else {
120 fn(node->collisions(), node->collisions() + node->collision_count());
121 }
122 }
123
124 template <typename Project, typename Default, typename K>
125 decltype(auto) get(const K& k) const
126 {
127 auto node = root;
128 auto hash = Hash{}(k);
129 for (auto i = count_t{}; i < max_depth<B>; ++i) {
130 auto bit = bitmap_t{1u} << (hash & mask<B>);
131 if (node->nodemap() & bit) {
132 auto offset = popcount(node->nodemap() & (bit - 1));
133 node = node->children() [offset];
134 hash = hash >> B;
135 } else if (node->datamap() & bit) {
136 auto offset = popcount(node->datamap() & (bit - 1));
137 auto val = node->values() + offset;
138 if (Equal{}(*val, k))
139 return Project{}(*val);
140 else
141 return Default{}();
142 } else {
143 return Default{}();
144 }
145 }
146 auto fst = node->collisions();
147 auto lst = fst + node->collision_count();
148 for (; fst != lst; ++fst)
149 if (Equal{}(*fst, k))
150 return Project{}(*fst);
151 return Default{}();
152 }
153
154 std::pair<node_t*, bool>
155 do_add(node_t* node, T v, hash_t hash, shift_t shift) const
156 {
157 if (shift == max_shift<B>) {
158 auto fst = node->collisions();
159 auto lst = fst + node->collision_count();
160 for (; fst != lst; ++fst)
161 if (Equal{}(*fst, v))
162 return {
163 node_t::copy_collision_replace(node, fst, std::move(v)),
164 false
165 };
166 return {
167 node_t::copy_collision_insert(node, std::move(v)),
168 true
169 };
170 } else {
171 auto idx = (hash & (mask<B> << shift)) >> shift;
172 auto bit = bitmap_t{1u} << idx;
173 if (node->nodemap() & bit) {
174 auto offset = popcount(node->nodemap() & (bit - 1));
175 auto result = do_add(node->children() [offset],
176 std::move(v), hash,
177 shift + B);
178 try {
179 result.first = node_t::copy_inner_replace(
180 node, offset, result.first);
181 return result;
182 } catch (...) {
183 node_t::delete_deep_shift(result.first, shift + B);
184 throw;
185 }
186 } else if (node->datamap() & bit) {
187 auto offset = popcount(node->datamap() & (bit - 1));
188 auto val = node->values() + offset;
189 if (Equal{}(*val, v))
190 return {
191 node_t::copy_inner_replace_value(
192 node, offset, std::move(v)),
193 false
194 };
195 else {
196 auto child = node_t::make_merged(shift + B,
197 std::move(v), hash,
198 *val, Hash{}(*val));
199 try {
200 return {
201 node_t::copy_inner_replace_merged(
202 node, bit, offset, child),
203 true
204 };
205 } catch (...) {
206 node_t::delete_deep_shift(child, shift + B);
207 throw;
208 }
209 }
210 } else {
211 return {
212 node_t::copy_inner_insert_value(node, bit, std::move(v)),
213 true
214 };
215 }
216 }
217 }
218
219 champ add(T v) const
220 {
221 auto hash = Hash{}(v);
222 auto res = do_add(root, std::move(v), hash, 0);
223 auto new_size = size + (res.second ? 1 : 0);
224 return { res.first, new_size };
225 }
226
227 template <typename Project, typename Default, typename Combine,
228 typename K, typename Fn>
229 std::pair<node_t*, bool>
230 do_update(node_t* node, K&& k, Fn&& fn,
231 hash_t hash, shift_t shift) const
232 {
233 if (shift == max_shift<B>) {
234 auto fst = node->collisions();
235 auto lst = fst + node->collision_count();
236 for (; fst != lst; ++fst)
237 if (Equal{}(*fst, k))
238 return {
239 node_t::copy_collision_replace(
240 node, fst, Combine{}(std::forward<K>(k),
241 std::forward<Fn>(fn)(
242 Project{}(*fst)))),
243 false
244 };
245 return {
246 node_t::copy_collision_insert(
247 node, Combine{}(std::forward<K>(k),
248 std::forward<Fn>(fn)(
249 Default{}()))),
250 true
251 };
252 } else {
253 auto idx = (hash & (mask<B> << shift)) >> shift;
254 auto bit = bitmap_t{1u} << idx;
255 if (node->nodemap() & bit) {
256 auto offset = popcount(node->nodemap() & (bit - 1));
257 auto result = do_update<Project, Default, Combine>(
258 node->children() [offset], k, std::forward<Fn>(fn),
259 hash, shift + B);
260 try {
261 result.first = node_t::copy_inner_replace(
262 node, offset, result.first);
263 return result;
264 } catch (...) {
265 node_t::delete_deep_shift(result.first, shift + B);
266 throw;
267 }
268 } else if (node->datamap() & bit) {
269 auto offset = popcount(node->datamap() & (bit - 1));
270 auto val = node->values() + offset;
271 if (Equal{}(*val, k))
272 return {
273 node_t::copy_inner_replace_value(
274 node, offset, Combine{}(std::forward<K>(k),
275 std::forward<Fn>(fn)(
276 Project{}(*val)))),
277 false
278 };
279 else {
280 auto child = node_t::make_merged(
281 shift + B, Combine{}(std::forward<K>(k),
282 std::forward<Fn>(fn)(
283 Default{}())),
284 hash, *val, Hash{}(*val));
285 try {
286 return {
287 node_t::copy_inner_replace_merged(
288 node, bit, offset, child),
289 true
290 };
291 } catch (...) {
292 node_t::delete_deep_shift(child, shift + B);
293 throw;
294 }
295 }
296 } else {
297 return {
298 node_t::copy_inner_insert_value(
299 node, bit, Combine{}(std::forward<K>(k),
300 std::forward<Fn>(fn)(
301 Default{}()))),
302 true
303 };
304 }
305 }
306 }
307
308 template <typename Project, typename Default, typename Combine,
309 typename K, typename Fn>
310 champ update(const K& k, Fn&& fn) const
311 {
312 auto hash = Hash{}(k);
313 auto res = do_update<Project, Default, Combine>(
314 root, k, std::forward<Fn>(fn), hash, 0);
315 auto new_size = size + (res.second ? 1 : 0);
316 return { res.first, new_size };
317 }
318
319 // basically:
320 // variant<monostate_t, T*, node_t*>
321 // boo bad we are not using... C++17 :'(
322 struct sub_result
323 {
324 enum kind_t
325 {
326 nothing,
327 singleton,
328 tree
329 };
330
331 union data_t
332 {
333 T* singleton;
334 node_t* tree;
335 };
336
337 kind_t kind;
338 data_t data;
339
340 sub_result() : kind{nothing} {};
341 sub_result(T* x) : kind{singleton} { data.singleton = x; };
342 sub_result(node_t* x) : kind{tree} { data.tree = x; };
343 };
344
345 template <typename K>
346 sub_result do_sub(node_t* node, const K& k, hash_t hash, shift_t shift) const
347 {
348 if (shift == max_shift<B>) {
349 auto fst = node->collisions();
350 auto lst = fst + node->collision_count();
351 for (auto cur = fst; cur != lst; ++cur)
352 if (Equal{}(*cur, k))
353 return node->collision_count() > 2
354 ? node_t::copy_collision_remove(node, cur)
355 : sub_result{fst + (cur == fst)};
356 return {};
357 } else {
358 auto idx = (hash & (mask<B> << shift)) >> shift;
359 auto bit = bitmap_t{1u} << idx;
360 if (node->nodemap() & bit) {
361 auto offset = popcount(node->nodemap() & (bit - 1));
362 auto result = do_sub(node->children() [offset],
363 k, hash, shift + B);
364 switch (result.kind) {
365 case sub_result::nothing:
366 return {};
367 case sub_result::singleton:
368 return node->datamap() == 0 &&
369 popcount(node->nodemap()) == 1 &&
370 shift > 0
371 ? result
372 : node_t::copy_inner_replace_inline(
373 node, bit, offset, *result.data.singleton);
374 case sub_result::tree:
375 try {
376 return node_t::copy_inner_replace(node, offset,
377 result.data.tree);
378 } catch (...) {
379 node_t::delete_deep_shift(result.data.tree, shift + B);
380 throw;
381 }
382 }
383 } else if (node->datamap() & bit) {
384 auto offset = popcount(node->datamap() & (bit - 1));
385 auto val = node->values() + offset;
386 if (Equal{}(*val, k)) {
387 auto nv = popcount(node->datamap());
388 if (node->nodemap() || nv > 2)
389 return node_t::copy_inner_remove_value(node, bit, offset);
390 else if (nv == 2) {
391 return shift > 0
392 ? sub_result{node->values() + !offset}
393 : node_t::make_inner_n(0,
394 node->datamap() & ~bit,
395 node->values()[!offset]);
396 } else {
397 assert(shift == 0);
398 return empty().root->inc();
399 }
400 }
401 }
402 return {};
403 }
404 }
405
406 template <typename K>
407 champ sub(const K& k) const
408 {
409 auto hash = Hash{}(k);
410 auto res = do_sub(root, k, hash, 0);
411 switch (res.kind) {
412 case sub_result::nothing:
413 return *this;
414 case sub_result::tree:
415 return {
416 res.data.tree,
417 size - 1
418 };
419 default:
420 IMMER_UNREACHABLE;
421 }
422 }
423
424 template <typename Eq=Equal>
425 bool equals(const champ& other) const
426 {
427 return size == other.size && equals_tree<Eq>(root, other.root, 0);
428 }
429
430 template <typename Eq>
431 static bool equals_tree(const node_t* a, const node_t* b, count_t depth)
432 {
433 if (a == b)
434 return true;
435 else if (depth == max_depth<B>) {
436 auto nv = a->collision_count();
437 return nv == b->collision_count() &&
438 equals_collisions<Eq>(a->collisions(), b->collisions(), nv);
439 } else {
440 if (a->nodemap() != b->nodemap() ||
441 a->datamap() != b->datamap())
442 return false;
443 auto n = popcount(a->nodemap());
444 for (auto i = count_t{}; i < n; ++i)
445 if (!equals_tree<Eq>(a->children()[i], b->children()[i], depth + 1))
446 return false;
447 auto nv = popcount(a->datamap());
448 return !nv || equals_values<Eq>(a->values(), b->values(), nv);
449 }
450 }
451
452 template <typename Eq>
453 static bool equals_values(const T* a, const T* b, count_t n)
454 {
455 return std::equal(a, a + n, b, Eq{});
456 }
457
458 template <typename Eq>
459 static bool equals_collisions(const T* a, const T* b, count_t n)
460 {
461 auto ae = a + n;
462 auto be = b + n;
463 for (; a != ae; ++a) {
464 for (auto fst = b; fst != be; ++fst)
465 if (Eq{}(*a, *fst))
466 goto good;
467 return false;
468 good: continue;
469 }
470 return true;
471 }
472};
473
474} // namespace hamts
475} // namespace detail
476} // namespace immer
477