1// SPDX-License-Identifier: Apache-2.0
2// ----------------------------------------------------------------------------
3// Copyright 2011-2023 Arm Limited
4//
5// Licensed under the Apache License, Version 2.0 (the "License"); you may not
6// use this file except in compliance with the License. You may obtain a copy
7// of the License at:
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14// License for the specific language governing permissions and limitations
15// under the License.
16// ----------------------------------------------------------------------------
17
18/**
19 * @brief Functions to generate block size descriptor and decimation tables.
20 */
21
22#include "astcenc_internal.h"
23
24/**
25 * @brief Decode the properties of an encoded 2D block mode.
26 *
27 * @param block_mode The encoded block mode.
28 * @param[out] x_weights The number of weights in the X dimension.
29 * @param[out] y_weights The number of weights in the Y dimension.
30 * @param[out] is_dual_plane True if this block mode has two weight planes.
31 * @param[out] quant_mode The quantization level for the weights.
32 * @param[out] weight_bits The storage bit count for the weights.
33 *
34 * @return Returns true if a valid mode, false otherwise.
35 */
36static bool decode_block_mode_2d(
37 unsigned int block_mode,
38 unsigned int& x_weights,
39 unsigned int& y_weights,
40 bool& is_dual_plane,
41 unsigned int& quant_mode,
42 unsigned int& weight_bits
43) {
44 unsigned int base_quant_mode = (block_mode >> 4) & 1;
45 unsigned int H = (block_mode >> 9) & 1;
46 unsigned int D = (block_mode >> 10) & 1;
47 unsigned int A = (block_mode >> 5) & 0x3;
48
49 x_weights = 0;
50 y_weights = 0;
51
52 if ((block_mode & 3) != 0)
53 {
54 base_quant_mode |= (block_mode & 3) << 1;
55 unsigned int B = (block_mode >> 7) & 3;
56 switch ((block_mode >> 2) & 3)
57 {
58 case 0:
59 x_weights = B + 4;
60 y_weights = A + 2;
61 break;
62 case 1:
63 x_weights = B + 8;
64 y_weights = A + 2;
65 break;
66 case 2:
67 x_weights = A + 2;
68 y_weights = B + 8;
69 break;
70 case 3:
71 B &= 1;
72 if (block_mode & 0x100)
73 {
74 x_weights = B + 2;
75 y_weights = A + 2;
76 }
77 else
78 {
79 x_weights = A + 2;
80 y_weights = B + 6;
81 }
82 break;
83 }
84 }
85 else
86 {
87 base_quant_mode |= ((block_mode >> 2) & 3) << 1;
88 if (((block_mode >> 2) & 3) == 0)
89 {
90 return false;
91 }
92
93 unsigned int B = (block_mode >> 9) & 3;
94 switch ((block_mode >> 7) & 3)
95 {
96 case 0:
97 x_weights = 12;
98 y_weights = A + 2;
99 break;
100 case 1:
101 x_weights = A + 2;
102 y_weights = 12;
103 break;
104 case 2:
105 x_weights = A + 6;
106 y_weights = B + 6;
107 D = 0;
108 H = 0;
109 break;
110 case 3:
111 switch ((block_mode >> 5) & 3)
112 {
113 case 0:
114 x_weights = 6;
115 y_weights = 10;
116 break;
117 case 1:
118 x_weights = 10;
119 y_weights = 6;
120 break;
121 case 2:
122 case 3:
123 return false;
124 }
125 break;
126 }
127 }
128
129 unsigned int weight_count = x_weights * y_weights * (D + 1);
130 quant_mode = (base_quant_mode - 2) + 6 * H;
131 is_dual_plane = D != 0;
132
133 weight_bits = get_ise_sequence_bitcount(weight_count, static_cast<quant_method>(quant_mode));
134 return (weight_count <= BLOCK_MAX_WEIGHTS &&
135 weight_bits >= BLOCK_MIN_WEIGHT_BITS &&
136 weight_bits <= BLOCK_MAX_WEIGHT_BITS);
137}
138
139/**
140 * @brief Decode the properties of an encoded 3D block mode.
141 *
142 * @param block_mode The encoded block mode.
143 * @param[out] x_weights The number of weights in the X dimension.
144 * @param[out] y_weights The number of weights in the Y dimension.
145 * @param[out] z_weights The number of weights in the Z dimension.
146 * @param[out] is_dual_plane True if this block mode has two weight planes.
147 * @param[out] quant_mode The quantization level for the weights.
148 * @param[out] weight_bits The storage bit count for the weights.
149 *
150 * @return Returns true if a valid mode, false otherwise.
151 */
152static bool decode_block_mode_3d(
153 unsigned int block_mode,
154 unsigned int& x_weights,
155 unsigned int& y_weights,
156 unsigned int& z_weights,
157 bool& is_dual_plane,
158 unsigned int& quant_mode,
159 unsigned int& weight_bits
160) {
161 unsigned int base_quant_mode = (block_mode >> 4) & 1;
162 unsigned int H = (block_mode >> 9) & 1;
163 unsigned int D = (block_mode >> 10) & 1;
164 unsigned int A = (block_mode >> 5) & 0x3;
165
166 x_weights = 0;
167 y_weights = 0;
168 z_weights = 0;
169
170 if ((block_mode & 3) != 0)
171 {
172 base_quant_mode |= (block_mode & 3) << 1;
173 unsigned int B = (block_mode >> 7) & 3;
174 unsigned int C = (block_mode >> 2) & 0x3;
175 x_weights = A + 2;
176 y_weights = B + 2;
177 z_weights = C + 2;
178 }
179 else
180 {
181 base_quant_mode |= ((block_mode >> 2) & 3) << 1;
182 if (((block_mode >> 2) & 3) == 0)
183 {
184 return false;
185 }
186
187 int B = (block_mode >> 9) & 3;
188 if (((block_mode >> 7) & 3) != 3)
189 {
190 D = 0;
191 H = 0;
192 }
193 switch ((block_mode >> 7) & 3)
194 {
195 case 0:
196 x_weights = 6;
197 y_weights = B + 2;
198 z_weights = A + 2;
199 break;
200 case 1:
201 x_weights = A + 2;
202 y_weights = 6;
203 z_weights = B + 2;
204 break;
205 case 2:
206 x_weights = A + 2;
207 y_weights = B + 2;
208 z_weights = 6;
209 break;
210 case 3:
211 x_weights = 2;
212 y_weights = 2;
213 z_weights = 2;
214 switch ((block_mode >> 5) & 3)
215 {
216 case 0:
217 x_weights = 6;
218 break;
219 case 1:
220 y_weights = 6;
221 break;
222 case 2:
223 z_weights = 6;
224 break;
225 case 3:
226 return false;
227 }
228 break;
229 }
230 }
231
232 unsigned int weight_count = x_weights * y_weights * z_weights * (D + 1);
233 quant_mode = (base_quant_mode - 2) + 6 * H;
234 is_dual_plane = D != 0;
235
236 weight_bits = get_ise_sequence_bitcount(weight_count, static_cast<quant_method>(quant_mode));
237 return (weight_count <= BLOCK_MAX_WEIGHTS &&
238 weight_bits >= BLOCK_MIN_WEIGHT_BITS &&
239 weight_bits <= BLOCK_MAX_WEIGHT_BITS);
240}
241
242/**
243 * @brief Create a 2D decimation entry for a block-size and weight-decimation pair.
244 *
245 * @param x_texels The number of texels in the X dimension.
246 * @param y_texels The number of texels in the Y dimension.
247 * @param x_weights The number of weights in the X dimension.
248 * @param y_weights The number of weights in the Y dimension.
249 * @param[out] di The decimation info structure to populate.
250 * @param[out] wb The decimation table init scratch working buffers.
251 */
252static void init_decimation_info_2d(
253 unsigned int x_texels,
254 unsigned int y_texels,
255 unsigned int x_weights,
256 unsigned int y_weights,
257 decimation_info& di,
258 dt_init_working_buffers& wb
259) {
260 unsigned int texels_per_block = x_texels * y_texels;
261 unsigned int weights_per_block = x_weights * y_weights;
262
263 uint8_t max_texel_count_of_weight = 0;
264
265 promise(weights_per_block > 0);
266 promise(texels_per_block > 0);
267 promise(x_texels > 0);
268 promise(y_texels > 0);
269
270 for (unsigned int i = 0; i < weights_per_block; i++)
271 {
272 wb.texel_count_of_weight[i] = 0;
273 }
274
275 for (unsigned int i = 0; i < texels_per_block; i++)
276 {
277 wb.weight_count_of_texel[i] = 0;
278 }
279
280 for (unsigned int y = 0; y < y_texels; y++)
281 {
282 for (unsigned int x = 0; x < x_texels; x++)
283 {
284 unsigned int texel = y * x_texels + x;
285
286 unsigned int x_weight = (((1024 + x_texels / 2) / (x_texels - 1)) * x * (x_weights - 1) + 32) >> 6;
287 unsigned int y_weight = (((1024 + y_texels / 2) / (y_texels - 1)) * y * (y_weights - 1) + 32) >> 6;
288
289 unsigned int x_weight_frac = x_weight & 0xF;
290 unsigned int y_weight_frac = y_weight & 0xF;
291 unsigned int x_weight_int = x_weight >> 4;
292 unsigned int y_weight_int = y_weight >> 4;
293
294 unsigned int qweight[4];
295 qweight[0] = x_weight_int + y_weight_int * x_weights;
296 qweight[1] = qweight[0] + 1;
297 qweight[2] = qweight[0] + x_weights;
298 qweight[3] = qweight[2] + 1;
299
300 // Truncated-precision bilinear interpolation
301 unsigned int prod = x_weight_frac * y_weight_frac;
302
303 unsigned int weight[4];
304 weight[3] = (prod + 8) >> 4;
305 weight[1] = x_weight_frac - weight[3];
306 weight[2] = y_weight_frac - weight[3];
307 weight[0] = 16 - x_weight_frac - y_weight_frac + weight[3];
308
309 for (unsigned int i = 0; i < 4; i++)
310 {
311 if (weight[i] != 0)
312 {
313 wb.grid_weights_of_texel[texel][wb.weight_count_of_texel[texel]] = static_cast<uint8_t>(qweight[i]);
314 wb.weights_of_texel[texel][wb.weight_count_of_texel[texel]] = static_cast<uint8_t>(weight[i]);
315 wb.weight_count_of_texel[texel]++;
316 wb.texels_of_weight[qweight[i]][wb.texel_count_of_weight[qweight[i]]] = static_cast<uint8_t>(texel);
317 wb.texel_weights_of_weight[qweight[i]][wb.texel_count_of_weight[qweight[i]]] = static_cast<uint8_t>(weight[i]);
318 wb.texel_count_of_weight[qweight[i]]++;
319 max_texel_count_of_weight = astc::max(max_texel_count_of_weight, wb.texel_count_of_weight[qweight[i]]);
320 }
321 }
322 }
323 }
324
325 uint8_t max_texel_weight_count = 0;
326 for (unsigned int i = 0; i < texels_per_block; i++)
327 {
328 di.texel_weight_count[i] = wb.weight_count_of_texel[i];
329 max_texel_weight_count = astc::max(max_texel_weight_count, di.texel_weight_count[i]);
330
331 for (unsigned int j = 0; j < wb.weight_count_of_texel[i]; j++)
332 {
333 di.texel_weight_contribs_int_tr[j][i] = wb.weights_of_texel[i][j];
334 di.texel_weight_contribs_float_tr[j][i] = static_cast<float>(wb.weights_of_texel[i][j]) * (1.0f / WEIGHTS_TEXEL_SUM);
335 di.texel_weights_tr[j][i] = wb.grid_weights_of_texel[i][j];
336 }
337
338 // Init all 4 entries so we can rely on zeros for vectorization
339 for (unsigned int j = wb.weight_count_of_texel[i]; j < 4; j++)
340 {
341 di.texel_weight_contribs_int_tr[j][i] = 0;
342 di.texel_weight_contribs_float_tr[j][i] = 0.0f;
343 di.texel_weights_tr[j][i] = 0;
344 }
345 }
346
347 di.max_texel_weight_count = max_texel_weight_count;
348
349 for (unsigned int i = 0; i < weights_per_block; i++)
350 {
351 unsigned int texel_count_wt = wb.texel_count_of_weight[i];
352 di.weight_texel_count[i] = static_cast<uint8_t>(texel_count_wt);
353
354 for (unsigned int j = 0; j < texel_count_wt; j++)
355 {
356 uint8_t texel = wb.texels_of_weight[i][j];
357
358 // Create transposed versions of these for better vectorization
359 di.weight_texels_tr[j][i] = texel;
360 di.weights_texel_contribs_tr[j][i] = static_cast<float>(wb.texel_weights_of_weight[i][j]);
361
362 // Store the per-texel contribution of this weight for each texel it contributes to
363 di.texel_contrib_for_weight[j][i] = 0.0f;
364 for (unsigned int k = 0; k < 4; k++)
365 {
366 uint8_t dttw = di.texel_weights_tr[k][texel];
367 float dttwf = di.texel_weight_contribs_float_tr[k][texel];
368 if (dttw == i && dttwf != 0.0f)
369 {
370 di.texel_contrib_for_weight[j][i] = di.texel_weight_contribs_float_tr[k][texel];
371 break;
372 }
373 }
374 }
375
376 // Initialize array tail so we can over-fetch with SIMD later to avoid loop tails
377 // Match last texel in active lane in SIMD group, for better gathers
378 uint8_t last_texel = di.weight_texels_tr[texel_count_wt - 1][i];
379 for (unsigned int j = texel_count_wt; j < max_texel_count_of_weight; j++)
380 {
381 di.weight_texels_tr[j][i] = last_texel;
382 di.weights_texel_contribs_tr[j][i] = 0.0f;
383 }
384 }
385
386 // Initialize array tail so we can over-fetch with SIMD later to avoid loop tails
387 unsigned int texels_per_block_simd = round_up_to_simd_multiple_vla(texels_per_block);
388 for (unsigned int i = texels_per_block; i < texels_per_block_simd; i++)
389 {
390 di.texel_weight_count[i] = 0;
391
392 for (unsigned int j = 0; j < 4; j++)
393 {
394 di.texel_weight_contribs_float_tr[j][i] = 0;
395 di.texel_weights_tr[j][i] = 0;
396 di.texel_weight_contribs_int_tr[j][i] = 0;
397 }
398 }
399
400 // Initialize array tail so we can over-fetch with SIMD later to avoid loop tails
401 // Match last texel in active lane in SIMD group, for better gathers
402 unsigned int last_texel_count_wt = wb.texel_count_of_weight[weights_per_block - 1];
403 uint8_t last_texel = di.weight_texels_tr[last_texel_count_wt - 1][weights_per_block - 1];
404
405 unsigned int weights_per_block_simd = round_up_to_simd_multiple_vla(weights_per_block);
406 for (unsigned int i = weights_per_block; i < weights_per_block_simd; i++)
407 {
408 di.weight_texel_count[i] = 0;
409
410 for (unsigned int j = 0; j < max_texel_count_of_weight; j++)
411 {
412 di.weight_texels_tr[j][i] = last_texel;
413 di.weights_texel_contribs_tr[j][i] = 0.0f;
414 }
415 }
416
417 di.texel_count = static_cast<uint8_t>(texels_per_block);
418 di.weight_count = static_cast<uint8_t>(weights_per_block);
419 di.weight_x = static_cast<uint8_t>(x_weights);
420 di.weight_y = static_cast<uint8_t>(y_weights);
421 di.weight_z = 1;
422}
423
424/**
425 * @brief Create a 3D decimation entry for a block-size and weight-decimation pair.
426 *
427 * @param x_texels The number of texels in the X dimension.
428 * @param y_texels The number of texels in the Y dimension.
429 * @param z_texels The number of texels in the Z dimension.
430 * @param x_weights The number of weights in the X dimension.
431 * @param y_weights The number of weights in the Y dimension.
432 * @param z_weights The number of weights in the Z dimension.
433 * @param[out] di The decimation info structure to populate.
434 @param[out] wb The decimation table init scratch working buffers.
435 */
436static void init_decimation_info_3d(
437 unsigned int x_texels,
438 unsigned int y_texels,
439 unsigned int z_texels,
440 unsigned int x_weights,
441 unsigned int y_weights,
442 unsigned int z_weights,
443 decimation_info& di,
444 dt_init_working_buffers& wb
445) {
446 unsigned int texels_per_block = x_texels * y_texels * z_texels;
447 unsigned int weights_per_block = x_weights * y_weights * z_weights;
448
449 uint8_t max_texel_count_of_weight = 0;
450
451 promise(weights_per_block > 0);
452 promise(texels_per_block > 0);
453
454 for (unsigned int i = 0; i < weights_per_block; i++)
455 {
456 wb.texel_count_of_weight[i] = 0;
457 }
458
459 for (unsigned int i = 0; i < texels_per_block; i++)
460 {
461 wb.weight_count_of_texel[i] = 0;
462 }
463
464 for (unsigned int z = 0; z < z_texels; z++)
465 {
466 for (unsigned int y = 0; y < y_texels; y++)
467 {
468 for (unsigned int x = 0; x < x_texels; x++)
469 {
470 int texel = (z * y_texels + y) * x_texels + x;
471
472 int x_weight = (((1024 + x_texels / 2) / (x_texels - 1)) * x * (x_weights - 1) + 32) >> 6;
473 int y_weight = (((1024 + y_texels / 2) / (y_texels - 1)) * y * (y_weights - 1) + 32) >> 6;
474 int z_weight = (((1024 + z_texels / 2) / (z_texels - 1)) * z * (z_weights - 1) + 32) >> 6;
475
476 int x_weight_frac = x_weight & 0xF;
477 int y_weight_frac = y_weight & 0xF;
478 int z_weight_frac = z_weight & 0xF;
479 int x_weight_int = x_weight >> 4;
480 int y_weight_int = y_weight >> 4;
481 int z_weight_int = z_weight >> 4;
482 int qweight[4];
483 int weight[4];
484 qweight[0] = (z_weight_int * y_weights + y_weight_int) * x_weights + x_weight_int;
485 qweight[3] = ((z_weight_int + 1) * y_weights + (y_weight_int + 1)) * x_weights + (x_weight_int + 1);
486
487 // simplex interpolation
488 int fs = x_weight_frac;
489 int ft = y_weight_frac;
490 int fp = z_weight_frac;
491
492 int cas = ((fs > ft) << 2) + ((ft > fp) << 1) + ((fs > fp));
493 int N = x_weights;
494 int NM = x_weights * y_weights;
495
496 int s1, s2, w0, w1, w2, w3;
497 switch (cas)
498 {
499 case 7:
500 s1 = 1;
501 s2 = N;
502 w0 = 16 - fs;
503 w1 = fs - ft;
504 w2 = ft - fp;
505 w3 = fp;
506 break;
507 case 3:
508 s1 = N;
509 s2 = 1;
510 w0 = 16 - ft;
511 w1 = ft - fs;
512 w2 = fs - fp;
513 w3 = fp;
514 break;
515 case 5:
516 s1 = 1;
517 s2 = NM;
518 w0 = 16 - fs;
519 w1 = fs - fp;
520 w2 = fp - ft;
521 w3 = ft;
522 break;
523 case 4:
524 s1 = NM;
525 s2 = 1;
526 w0 = 16 - fp;
527 w1 = fp - fs;
528 w2 = fs - ft;
529 w3 = ft;
530 break;
531 case 2:
532 s1 = N;
533 s2 = NM;
534 w0 = 16 - ft;
535 w1 = ft - fp;
536 w2 = fp - fs;
537 w3 = fs;
538 break;
539 case 0:
540 s1 = NM;
541 s2 = N;
542 w0 = 16 - fp;
543 w1 = fp - ft;
544 w2 = ft - fs;
545 w3 = fs;
546 break;
547 default:
548 s1 = NM;
549 s2 = N;
550 w0 = 16 - fp;
551 w1 = fp - ft;
552 w2 = ft - fs;
553 w3 = fs;
554 break;
555 }
556
557 qweight[1] = qweight[0] + s1;
558 qweight[2] = qweight[1] + s2;
559 weight[0] = w0;
560 weight[1] = w1;
561 weight[2] = w2;
562 weight[3] = w3;
563
564 for (unsigned int i = 0; i < 4; i++)
565 {
566 if (weight[i] != 0)
567 {
568 wb.grid_weights_of_texel[texel][wb.weight_count_of_texel[texel]] = static_cast<uint8_t>(qweight[i]);
569 wb.weights_of_texel[texel][wb.weight_count_of_texel[texel]] = static_cast<uint8_t>(weight[i]);
570 wb.weight_count_of_texel[texel]++;
571 wb.texels_of_weight[qweight[i]][wb.texel_count_of_weight[qweight[i]]] = static_cast<uint8_t>(texel);
572 wb.texel_weights_of_weight[qweight[i]][wb.texel_count_of_weight[qweight[i]]] = static_cast<uint8_t>(weight[i]);
573 wb.texel_count_of_weight[qweight[i]]++;
574 max_texel_count_of_weight = astc::max(max_texel_count_of_weight, wb.texel_count_of_weight[qweight[i]]);
575 }
576 }
577 }
578 }
579 }
580
581 uint8_t max_texel_weight_count = 0;
582 for (unsigned int i = 0; i < texels_per_block; i++)
583 {
584 di.texel_weight_count[i] = wb.weight_count_of_texel[i];
585 max_texel_weight_count = astc::max(max_texel_weight_count, di.texel_weight_count[i]);
586
587 // Init all 4 entries so we can rely on zeros for vectorization
588 for (unsigned int j = 0; j < 4; j++)
589 {
590 di.texel_weight_contribs_int_tr[j][i] = 0;
591 di.texel_weight_contribs_float_tr[j][i] = 0.0f;
592 di.texel_weights_tr[j][i] = 0;
593 }
594
595 for (unsigned int j = 0; j < wb.weight_count_of_texel[i]; j++)
596 {
597 di.texel_weight_contribs_int_tr[j][i] = wb.weights_of_texel[i][j];
598 di.texel_weight_contribs_float_tr[j][i] = static_cast<float>(wb.weights_of_texel[i][j]) * (1.0f / WEIGHTS_TEXEL_SUM);
599 di.texel_weights_tr[j][i] = wb.grid_weights_of_texel[i][j];
600 }
601 }
602
603 di.max_texel_weight_count = max_texel_weight_count;
604
605 for (unsigned int i = 0; i < weights_per_block; i++)
606 {
607 unsigned int texel_count_wt = wb.texel_count_of_weight[i];
608 di.weight_texel_count[i] = static_cast<uint8_t>(texel_count_wt);
609
610 for (unsigned int j = 0; j < texel_count_wt; j++)
611 {
612 unsigned int texel = wb.texels_of_weight[i][j];
613
614 // Create transposed versions of these for better vectorization
615 di.weight_texels_tr[j][i] = static_cast<uint8_t>(texel);
616 di.weights_texel_contribs_tr[j][i] = static_cast<float>(wb.texel_weights_of_weight[i][j]);
617
618 // Store the per-texel contribution of this weight for each texel it contributes to
619 di.texel_contrib_for_weight[j][i] = 0.0f;
620 for (unsigned int k = 0; k < 4; k++)
621 {
622 uint8_t dttw = di.texel_weights_tr[k][texel];
623 float dttwf = di.texel_weight_contribs_float_tr[k][texel];
624 if (dttw == i && dttwf != 0.0f)
625 {
626 di.texel_contrib_for_weight[j][i] = di.texel_weight_contribs_float_tr[k][texel];
627 break;
628 }
629 }
630 }
631
632 // Initialize array tail so we can over-fetch with SIMD later to avoid loop tails
633 // Match last texel in active lane in SIMD group, for better gathers
634 uint8_t last_texel = di.weight_texels_tr[texel_count_wt - 1][i];
635 for (unsigned int j = texel_count_wt; j < max_texel_count_of_weight; j++)
636 {
637 di.weight_texels_tr[j][i] = last_texel;
638 di.weights_texel_contribs_tr[j][i] = 0.0f;
639 }
640 }
641
642 // Initialize array tail so we can over-fetch with SIMD later to avoid loop tails
643 unsigned int texels_per_block_simd = round_up_to_simd_multiple_vla(texels_per_block);
644 for (unsigned int i = texels_per_block; i < texels_per_block_simd; i++)
645 {
646 di.texel_weight_count[i] = 0;
647
648 for (unsigned int j = 0; j < 4; j++)
649 {
650 di.texel_weight_contribs_float_tr[j][i] = 0;
651 di.texel_weights_tr[j][i] = 0;
652 di.texel_weight_contribs_int_tr[j][i] = 0;
653 }
654 }
655
656 // Initialize array tail so we can over-fetch with SIMD later to avoid loop tails
657 // Match last texel in active lane in SIMD group, for better gathers
658 int last_texel_count_wt = wb.texel_count_of_weight[weights_per_block - 1];
659 uint8_t last_texel = di.weight_texels_tr[last_texel_count_wt - 1][weights_per_block - 1];
660
661 unsigned int weights_per_block_simd = round_up_to_simd_multiple_vla(weights_per_block);
662 for (unsigned int i = weights_per_block; i < weights_per_block_simd; i++)
663 {
664 di.weight_texel_count[i] = 0;
665
666 for (int j = 0; j < max_texel_count_of_weight; j++)
667 {
668 di.weight_texels_tr[j][i] = last_texel;
669 di.weights_texel_contribs_tr[j][i] = 0.0f;
670 }
671 }
672
673 di.texel_count = static_cast<uint8_t>(texels_per_block);
674 di.weight_count = static_cast<uint8_t>(weights_per_block);
675 di.weight_x = static_cast<uint8_t>(x_weights);
676 di.weight_y = static_cast<uint8_t>(y_weights);
677 di.weight_z = static_cast<uint8_t>(z_weights);
678}
679
680/**
681 * @brief Assign the texels to use for kmeans clustering.
682 *
683 * The max limit is @c BLOCK_MAX_KMEANS_TEXELS; above this a random selection is used.
684 * The @c bsd.texel_count is an input and must be populated beforehand.
685 *
686 * @param[in,out] bsd The block size descriptor to populate.
687 */
688static void assign_kmeans_texels(
689 block_size_descriptor& bsd
690) {
691 // Use all texels for kmeans on a small block
692 if (bsd.texel_count <= BLOCK_MAX_KMEANS_TEXELS)
693 {
694 for (uint8_t i = 0; i < bsd.texel_count; i++)
695 {
696 bsd.kmeans_texels[i] = i;
697 }
698
699 return;
700 }
701
702 // Select a random subset of BLOCK_MAX_KMEANS_TEXELS for kmeans on a large block
703 uint64_t rng_state[2];
704 astc::rand_init(rng_state);
705
706 // Initialize array used for tracking used indices
707 bool seen[BLOCK_MAX_TEXELS];
708 for (uint8_t i = 0; i < bsd.texel_count; i++)
709 {
710 seen[i] = false;
711 }
712
713 // Assign 64 random indices, retrying if we see repeats
714 unsigned int arr_elements_set = 0;
715 while (arr_elements_set < BLOCK_MAX_KMEANS_TEXELS)
716 {
717 uint8_t texel = static_cast<uint8_t>(astc::rand(rng_state));
718 texel = texel % bsd.texel_count;
719 if (!seen[texel])
720 {
721 bsd.kmeans_texels[arr_elements_set++] = texel;
722 seen[texel] = true;
723 }
724 }
725}
726
727/**
728 * @brief Allocate a single 2D decimation table entry.
729 *
730 * @param x_texels The number of texels in the X dimension.
731 * @param y_texels The number of texels in the Y dimension.
732 * @param x_weights The number of weights in the X dimension.
733 * @param y_weights The number of weights in the Y dimension.
734 * @param bsd The block size descriptor we are populating.
735 * @param wb The decimation table init scratch working buffers.
736 * @param index The packed array index to populate.
737 */
738static void construct_dt_entry_2d(
739 unsigned int x_texels,
740 unsigned int y_texels,
741 unsigned int x_weights,
742 unsigned int y_weights,
743 block_size_descriptor& bsd,
744 dt_init_working_buffers& wb,
745 unsigned int index
746) {
747 unsigned int weight_count = x_weights * y_weights;
748 assert(weight_count <= BLOCK_MAX_WEIGHTS);
749
750 bool try_2planes = (2 * weight_count) <= BLOCK_MAX_WEIGHTS;
751
752 decimation_info& di = bsd.decimation_tables[index];
753 init_decimation_info_2d(x_texels, y_texels, x_weights, y_weights, di, wb);
754
755 int maxprec_1plane = -1;
756 int maxprec_2planes = -1;
757 for (int i = 0; i < 12; i++)
758 {
759 unsigned int bits_1plane = get_ise_sequence_bitcount(weight_count, static_cast<quant_method>(i));
760 if (bits_1plane >= BLOCK_MIN_WEIGHT_BITS && bits_1plane <= BLOCK_MAX_WEIGHT_BITS)
761 {
762 maxprec_1plane = i;
763 }
764
765 if (try_2planes)
766 {
767 unsigned int bits_2planes = get_ise_sequence_bitcount(2 * weight_count, static_cast<quant_method>(i));
768 if (bits_2planes >= BLOCK_MIN_WEIGHT_BITS && bits_2planes <= BLOCK_MAX_WEIGHT_BITS)
769 {
770 maxprec_2planes = i;
771 }
772 }
773 }
774
775 // At least one of the two should be valid ...
776 assert(maxprec_1plane >= 0 || maxprec_2planes >= 0);
777 bsd.decimation_modes[index].maxprec_1plane = static_cast<int8_t>(maxprec_1plane);
778 bsd.decimation_modes[index].maxprec_2planes = static_cast<int8_t>(maxprec_2planes);
779 bsd.decimation_modes[index].refprec_1plane = 0;
780 bsd.decimation_modes[index].refprec_2planes = 0;
781}
782
783/**
784 * @brief Allocate block modes and decimation tables for a single 2D block size.
785 *
786 * @param x_texels The number of texels in the X dimension.
787 * @param y_texels The number of texels in the Y dimension.
788 * @param can_omit_modes Can we discard modes that astcenc won't use, even if legal?
789 * @param mode_cutoff Percentile cutoff in range [0,1]. Low values more likely to be used.
790 * @param[out] bsd The block size descriptor to populate.
791 */
792static void construct_block_size_descriptor_2d(
793 unsigned int x_texels,
794 unsigned int y_texels,
795 bool can_omit_modes,
796 float mode_cutoff,
797 block_size_descriptor& bsd
798) {
799 // Store a remap table for storing packed decimation modes.
800 // Indexing uses [Y * 16 + X] and max size for each axis is 12.
801 static const unsigned int MAX_DMI = 12 * 16 + 12;
802 int decimation_mode_index[MAX_DMI];
803
804 dt_init_working_buffers* wb = new dt_init_working_buffers;
805
806 bsd.xdim = static_cast<uint8_t>(x_texels);
807 bsd.ydim = static_cast<uint8_t>(y_texels);
808 bsd.zdim = 1;
809 bsd.texel_count = static_cast<uint8_t>(x_texels * y_texels);
810
811 for (unsigned int i = 0; i < MAX_DMI; i++)
812 {
813 decimation_mode_index[i] = -1;
814 }
815
816 // Gather all the decimation grids that can be used with the current block
817#if !defined(ASTCENC_DECOMPRESS_ONLY)
818 const float *percentiles = get_2d_percentile_table(x_texels, y_texels);
819 float always_cutoff = 0.0f;
820#else
821 // Unused in decompress-only builds
822 (void)can_omit_modes;
823 (void)mode_cutoff;
824#endif
825
826 // Construct the list of block formats referencing the decimation tables
827 unsigned int packed_bm_idx = 0;
828 unsigned int packed_dm_idx = 0;
829
830 // Trackers
831 unsigned int bm_counts[4] { 0 };
832 unsigned int dm_counts[4] { 0 };
833
834 // Clear the list to a known-bad value
835 for (unsigned int i = 0; i < WEIGHTS_MAX_BLOCK_MODES; i++)
836 {
837 bsd.block_mode_packed_index[i] = BLOCK_BAD_BLOCK_MODE;
838 }
839
840 // Iterate four times to build a usefully ordered list:
841 // - Pass 0 - keep selected single plane "always" block modes
842 // - Pass 1 - keep selected single plane "non-always" block modes
843 // - Pass 2 - keep select dual plane block modes
844 // - Pass 3 - keep everything else that's legal
845 unsigned int limit = can_omit_modes ? 3 : 4;
846 for (unsigned int j = 0; j < limit; j ++)
847 {
848 for (unsigned int i = 0; i < WEIGHTS_MAX_BLOCK_MODES; i++)
849 {
850 // Skip modes we've already included in a previous pass
851 if (bsd.block_mode_packed_index[i] != BLOCK_BAD_BLOCK_MODE)
852 {
853 continue;
854 }
855
856 // Decode parameters
857 unsigned int x_weights;
858 unsigned int y_weights;
859 bool is_dual_plane;
860 unsigned int quant_mode;
861 unsigned int weight_bits;
862 bool valid = decode_block_mode_2d(i, x_weights, y_weights, is_dual_plane, quant_mode, weight_bits);
863
864 // Always skip invalid encodings for the current block size
865 if (!valid || (x_weights > x_texels) || (y_weights > y_texels))
866 {
867 continue;
868 }
869
870 // Selectively skip dual plane encodings
871 if (((j <= 1) && is_dual_plane) || (j == 2 && !is_dual_plane))
872 {
873 continue;
874 }
875
876 // Always skip encodings we can't physically encode based on
877 // generic encoding bit availability
878 if (is_dual_plane)
879 {
880 // This is the only check we need as only support 1 partition
881 if ((109 - weight_bits) <= 0)
882 {
883 continue;
884 }
885 }
886 else
887 {
888 // This is conservative - fewer bits may be available for > 1 partition
889 if ((111 - weight_bits) <= 0)
890 {
891 continue;
892 }
893 }
894
895 // Selectively skip encodings based on percentile
896 bool percentile_hit = false;
897 #if !defined(ASTCENC_DECOMPRESS_ONLY)
898 if (j == 0)
899 {
900 percentile_hit = percentiles[i] <= always_cutoff;
901 }
902 else
903 {
904 percentile_hit = percentiles[i] <= mode_cutoff;
905 }
906 #endif
907
908 if (j != 3 && !percentile_hit)
909 {
910 continue;
911 }
912
913 // Allocate and initialize the decimation table entry if we've not used it yet
914 int decimation_mode = decimation_mode_index[y_weights * 16 + x_weights];
915 if (decimation_mode < 0)
916 {
917 construct_dt_entry_2d(x_texels, y_texels, x_weights, y_weights, bsd, *wb, packed_dm_idx);
918 decimation_mode_index[y_weights * 16 + x_weights] = packed_dm_idx;
919 decimation_mode = packed_dm_idx;
920
921 dm_counts[j]++;
922 packed_dm_idx++;
923 }
924
925 auto& bm = bsd.block_modes[packed_bm_idx];
926
927 bm.decimation_mode = static_cast<uint8_t>(decimation_mode);
928 bm.quant_mode = static_cast<uint8_t>(quant_mode);
929 bm.is_dual_plane = static_cast<uint8_t>(is_dual_plane);
930 bm.weight_bits = static_cast<uint8_t>(weight_bits);
931 bm.mode_index = static_cast<uint16_t>(i);
932
933 auto& dm = bsd.decimation_modes[decimation_mode];
934
935 if (is_dual_plane)
936 {
937 dm.set_ref_2plane(bm.get_weight_quant_mode());
938 }
939 else
940 {
941 dm.set_ref_1plane(bm.get_weight_quant_mode());
942 }
943
944 bsd.block_mode_packed_index[i] = static_cast<uint16_t>(packed_bm_idx);
945
946 packed_bm_idx++;
947 bm_counts[j]++;
948 }
949 }
950
951 bsd.block_mode_count_1plane_always = bm_counts[0];
952 bsd.block_mode_count_1plane_selected = bm_counts[0] + bm_counts[1];
953 bsd.block_mode_count_1plane_2plane_selected = bm_counts[0] + bm_counts[1] + bm_counts[2];
954 bsd.block_mode_count_all = bm_counts[0] + bm_counts[1] + bm_counts[2] + bm_counts[3];
955
956 bsd.decimation_mode_count_always = dm_counts[0];
957 bsd.decimation_mode_count_selected = dm_counts[0] + dm_counts[1] + dm_counts[2];
958 bsd.decimation_mode_count_all = dm_counts[0] + dm_counts[1] + dm_counts[2] + dm_counts[3];
959
960#if !defined(ASTCENC_DECOMPRESS_ONLY)
961 assert(bsd.block_mode_count_1plane_always > 0);
962 assert(bsd.decimation_mode_count_always > 0);
963
964 delete[] percentiles;
965#endif
966
967 // Ensure the end of the array contains valid data (should never get read)
968 for (unsigned int i = bsd.decimation_mode_count_all; i < WEIGHTS_MAX_DECIMATION_MODES; i++)
969 {
970 bsd.decimation_modes[i].maxprec_1plane = -1;
971 bsd.decimation_modes[i].maxprec_2planes = -1;
972 bsd.decimation_modes[i].refprec_1plane = 0;
973 bsd.decimation_modes[i].refprec_2planes = 0;
974 }
975
976 // Determine the texels to use for kmeans clustering.
977 assign_kmeans_texels(bsd);
978
979 delete wb;
980}
981
982/**
983 * @brief Allocate block modes and decimation tables for a single 3D block size.
984 *
985 * TODO: This function doesn't include all of the heuristics that we use for 2D block sizes such as
986 * the percentile mode cutoffs. If 3D becomes more widely used we should look at this.
987 *
988 * @param x_texels The number of texels in the X dimension.
989 * @param y_texels The number of texels in the Y dimension.
990 * @param z_texels The number of texels in the Z dimension.
991 * @param[out] bsd The block size descriptor to populate.
992 */
993static void construct_block_size_descriptor_3d(
994 unsigned int x_texels,
995 unsigned int y_texels,
996 unsigned int z_texels,
997 block_size_descriptor& bsd
998) {
999 // Store a remap table for storing packed decimation modes.
1000 // Indexing uses [Z * 64 + Y * 8 + X] and max size for each axis is 6.
1001 static constexpr unsigned int MAX_DMI = 6 * 64 + 6 * 8 + 6;
1002 int decimation_mode_index[MAX_DMI];
1003 unsigned int decimation_mode_count = 0;
1004
1005 dt_init_working_buffers* wb = new dt_init_working_buffers;
1006
1007 bsd.xdim = static_cast<uint8_t>(x_texels);
1008 bsd.ydim = static_cast<uint8_t>(y_texels);
1009 bsd.zdim = static_cast<uint8_t>(z_texels);
1010 bsd.texel_count = static_cast<uint8_t>(x_texels * y_texels * z_texels);
1011
1012 for (unsigned int i = 0; i < MAX_DMI; i++)
1013 {
1014 decimation_mode_index[i] = -1;
1015 }
1016
1017 // gather all the infill-modes that can be used with the current block size
1018 for (unsigned int x_weights = 2; x_weights <= x_texels; x_weights++)
1019 {
1020 for (unsigned int y_weights = 2; y_weights <= y_texels; y_weights++)
1021 {
1022 for (unsigned int z_weights = 2; z_weights <= z_texels; z_weights++)
1023 {
1024 unsigned int weight_count = x_weights * y_weights * z_weights;
1025 if (weight_count > BLOCK_MAX_WEIGHTS)
1026 {
1027 continue;
1028 }
1029
1030 decimation_info& di = bsd.decimation_tables[decimation_mode_count];
1031 decimation_mode_index[z_weights * 64 + y_weights * 8 + x_weights] = decimation_mode_count;
1032 init_decimation_info_3d(x_texels, y_texels, z_texels, x_weights, y_weights, z_weights, di, *wb);
1033
1034 int maxprec_1plane = -1;
1035 int maxprec_2planes = -1;
1036 for (unsigned int i = 0; i < 12; i++)
1037 {
1038 unsigned int bits_1plane = get_ise_sequence_bitcount(weight_count, static_cast<quant_method>(i));
1039 if (bits_1plane >= BLOCK_MIN_WEIGHT_BITS && bits_1plane <= BLOCK_MAX_WEIGHT_BITS)
1040 {
1041 maxprec_1plane = i;
1042 }
1043
1044 unsigned int bits_2planes = get_ise_sequence_bitcount(2 * weight_count, static_cast<quant_method>(i));
1045 if (bits_2planes >= BLOCK_MIN_WEIGHT_BITS && bits_2planes <= BLOCK_MAX_WEIGHT_BITS)
1046 {
1047 maxprec_2planes = i;
1048 }
1049 }
1050
1051 if ((2 * weight_count) > BLOCK_MAX_WEIGHTS)
1052 {
1053 maxprec_2planes = -1;
1054 }
1055
1056 bsd.decimation_modes[decimation_mode_count].maxprec_1plane = static_cast<int8_t>(maxprec_1plane);
1057 bsd.decimation_modes[decimation_mode_count].maxprec_2planes = static_cast<int8_t>(maxprec_2planes);
1058 bsd.decimation_modes[decimation_mode_count].refprec_1plane = maxprec_1plane == -1 ? 0 : 0xFFFF;
1059 bsd.decimation_modes[decimation_mode_count].refprec_2planes = maxprec_2planes == -1 ? 0 : 0xFFFF;
1060 decimation_mode_count++;
1061 }
1062 }
1063 }
1064
1065 // Ensure the end of the array contains valid data (should never get read)
1066 for (unsigned int i = decimation_mode_count; i < WEIGHTS_MAX_DECIMATION_MODES; i++)
1067 {
1068 bsd.decimation_modes[i].maxprec_1plane = -1;
1069 bsd.decimation_modes[i].maxprec_2planes = -1;
1070 bsd.decimation_modes[i].refprec_1plane = 0;
1071 bsd.decimation_modes[i].refprec_2planes = 0;
1072 }
1073
1074 bsd.decimation_mode_count_always = 0; // Skipped for 3D modes
1075 bsd.decimation_mode_count_selected = decimation_mode_count;
1076 bsd.decimation_mode_count_all = decimation_mode_count;
1077
1078 // Construct the list of block formats referencing the decimation tables
1079
1080 // Clear the list to a known-bad value
1081 for (unsigned int i = 0; i < WEIGHTS_MAX_BLOCK_MODES; i++)
1082 {
1083 bsd.block_mode_packed_index[i] = BLOCK_BAD_BLOCK_MODE;
1084 }
1085
1086 unsigned int packed_idx = 0;
1087 unsigned int bm_counts[2] { 0 };
1088
1089 // Iterate two times to build a usefully ordered list:
1090 // - Pass 0 - keep valid single plane block modes
1091 // - Pass 1 - keep valid dual plane block modes
1092 for (unsigned int j = 0; j < 2; j++)
1093 {
1094 for (unsigned int i = 0; i < WEIGHTS_MAX_BLOCK_MODES; i++)
1095 {
1096 // Skip modes we've already included in a previous pass
1097 if (bsd.block_mode_packed_index[i] != BLOCK_BAD_BLOCK_MODE)
1098 {
1099 continue;
1100 }
1101
1102 unsigned int x_weights;
1103 unsigned int y_weights;
1104 unsigned int z_weights;
1105 bool is_dual_plane;
1106 unsigned int quant_mode;
1107 unsigned int weight_bits;
1108
1109 bool valid = decode_block_mode_3d(i, x_weights, y_weights, z_weights, is_dual_plane, quant_mode, weight_bits);
1110 // Skip invalid encodings
1111 if (!valid || x_weights > x_texels || y_weights > y_texels || z_weights > z_texels)
1112 {
1113 continue;
1114 }
1115
1116 // Skip encodings in the wrong iteration
1117 if ((j == 0 && is_dual_plane) || (j == 1 && !is_dual_plane))
1118 {
1119 continue;
1120 }
1121
1122 // Always skip encodings we can't physically encode based on bit availability
1123 if (is_dual_plane)
1124 {
1125 // This is the only check we need as only support 1 partition
1126 if ((109 - weight_bits) <= 0)
1127 {
1128 continue;
1129 }
1130 }
1131 else
1132 {
1133 // This is conservative - fewer bits may be available for > 1 partition
1134 if ((111 - weight_bits) <= 0)
1135 {
1136 continue;
1137 }
1138 }
1139
1140 int decimation_mode = decimation_mode_index[z_weights * 64 + y_weights * 8 + x_weights];
1141 bsd.block_modes[packed_idx].decimation_mode = static_cast<uint8_t>(decimation_mode);
1142 bsd.block_modes[packed_idx].quant_mode = static_cast<uint8_t>(quant_mode);
1143 bsd.block_modes[packed_idx].weight_bits = static_cast<uint8_t>(weight_bits);
1144 bsd.block_modes[packed_idx].is_dual_plane = static_cast<uint8_t>(is_dual_plane);
1145 bsd.block_modes[packed_idx].mode_index = static_cast<uint16_t>(i);
1146
1147 bsd.block_mode_packed_index[i] = static_cast<uint16_t>(packed_idx);
1148 bm_counts[j]++;
1149 packed_idx++;
1150 }
1151 }
1152
1153 bsd.block_mode_count_1plane_always = 0; // Skipped for 3D modes
1154 bsd.block_mode_count_1plane_selected = bm_counts[0];
1155 bsd.block_mode_count_1plane_2plane_selected = bm_counts[0] + bm_counts[1];
1156 bsd.block_mode_count_all = bm_counts[0] + bm_counts[1];
1157
1158 // Determine the texels to use for kmeans clustering.
1159 assign_kmeans_texels(bsd);
1160
1161 delete wb;
1162}
1163
1164/* See header for documentation. */
1165void init_block_size_descriptor(
1166 unsigned int x_texels,
1167 unsigned int y_texels,
1168 unsigned int z_texels,
1169 bool can_omit_modes,
1170 unsigned int partition_count_cutoff,
1171 float mode_cutoff,
1172 block_size_descriptor& bsd
1173) {
1174 if (z_texels > 1)
1175 {
1176 construct_block_size_descriptor_3d(x_texels, y_texels, z_texels, bsd);
1177 }
1178 else
1179 {
1180 construct_block_size_descriptor_2d(x_texels, y_texels, can_omit_modes, mode_cutoff, bsd);
1181 }
1182
1183 init_partition_tables(bsd, can_omit_modes, partition_count_cutoff);
1184}
1185