1// basisu_ssim.cpp
2// Copyright (C) 2019 Binomial LLC. All Rights Reserved.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15#include "basisu_ssim.h"
16
17#ifndef M_PI
18#define M_PI 3.14159265358979323846
19#endif
20
21namespace basisu
22{
// Evaluates a Gaussian at the integer offset (x, y) from the kernel center:
//   exp(-(x*x + y*y) / (2 * sigma_sqr)) / sqrt(2 * pi * sigma_sqr)
// sigma_sqr is the variance (sigma squared), not the standard deviation.
float gauss(int x, int y, float sigma_sqr)
{
    // Squared distance from the center, kept as an integer exactly as the exponent expects it.
    const int dist_sqr = x * x + y * y;

    const float falloff = expf(-(dist_sqr / (2.0f * sigma_sqr)));
    const float peak = 1.0f / sqrtf(static_cast<float>(2.0f * M_PI * sigma_sqr));

    return peak * falloff;
}
29
30 // size_x/y should be odd
31 void compute_gaussian_kernel(float *pDst, int size_x, int size_y, float sigma_sqr, uint32_t flags)
32 {
33 assert(size_x & size_y & 1);
34
35 if (!(size_x | size_y))
36 return;
37
38 int mid_x = size_x / 2;
39 int mid_y = size_y / 2;
40
41 double sum = 0;
42 for (int x = 0; x < size_x; x++)
43 {
44 for (int y = 0; y < size_y; y++)
45 {
46 float g;
47 if ((x > mid_x) && (y < mid_y))
48 g = pDst[(size_x - x - 1) + y * size_x];
49 else if ((x < mid_x) && (y > mid_y))
50 g = pDst[x + (size_y - y - 1) * size_x];
51 else if ((x > mid_x) && (y > mid_y))
52 g = pDst[(size_x - x - 1) + (size_y - y - 1) * size_x];
53 else
54 g = gauss(x - mid_x, y - mid_y, sigma_sqr);
55
56 pDst[x + y * size_x] = g;
57 sum += g;
58 }
59 }
60
61 if (flags & cComputeGaussianFlagNormalizeCenterToOne)
62 {
63 sum = pDst[mid_x + mid_y * size_x];
64 }
65
66 if (flags & (cComputeGaussianFlagNormalizeCenterToOne | cComputeGaussianFlagNormalize))
67 {
68 double one_over_sum = 1.0f / sum;
69 for (int i = 0; i < size_x * size_y; i++)
70 pDst[i] = static_cast<float>(pDst[i] * one_over_sum);
71
72 if (flags & cComputeGaussianFlagNormalizeCenterToOne)
73 pDst[mid_x + mid_y * size_x] = 1.0f;
74 }
75
76 if (flags & cComputeGaussianFlagPrint)
77 {
78 printf("{\n");
79 for (int y = 0; y < size_y; y++)
80 {
81 printf(" ");
82 for (int x = 0; x < size_x; x++)
83 {
84 printf("%f, ", pDst[x + y * size_x]);
85 }
86 printf("\n");
87 }
88 printf("}");
89 }
90 }
91
92 void gaussian_filter(imagef &dst, const imagef &orig_img, uint32_t odd_filter_width, float sigma_sqr, bool wrapping, uint32_t width_divisor, uint32_t height_divisor)
93 {
94 assert(odd_filter_width && (odd_filter_width & 1));
95 odd_filter_width |= 1;
96
97 vector2D<float> kernel(odd_filter_width, odd_filter_width);
98 compute_gaussian_kernel(kernel.get_ptr(), odd_filter_width, odd_filter_width, sigma_sqr, cComputeGaussianFlagNormalize);
99
100 const int dst_width = orig_img.get_width() / width_divisor;
101 const int dst_height = orig_img.get_height() / height_divisor;
102
103 const int H = odd_filter_width / 2;
104 const int L = -H;
105
106 dst.crop(dst_width, dst_height);
107
108//#pragma omp parallel for
109 for (int oy = 0; oy < dst_height; oy++)
110 {
111 for (int ox = 0; ox < dst_width; ox++)
112 {
113 vec4F c(0.0f);
114
115 for (int yd = L; yd <= H; yd++)
116 {
117 int y = oy * height_divisor + (height_divisor >> 1) + yd;
118
119 for (int xd = L; xd <= H; xd++)
120 {
121 int x = ox * width_divisor + (width_divisor >> 1) + xd;
122
123 const vec4F &p = orig_img.get_clamped_or_wrapped(x, y, wrapping, wrapping);
124
125 float w = kernel(xd + H, yd + H);
126 c[0] += p[0] * w;
127 c[1] += p[1] * w;
128 c[2] += p[2] * w;
129 c[3] += p[3] * w;
130 }
131 }
132
133 dst(ox, oy).set(c[0], c[1], c[2], c[3]);
134 }
135 }
136 }
137
138 void pow_image(const imagef &src, imagef &dst, const vec4F &power)
139 {
140 dst.resize(src);
141
142//#pragma omp parallel for
143 for (int y = 0; y < (int)dst.get_height(); y++)
144 {
145 for (uint32_t x = 0; x < dst.get_width(); x++)
146 {
147 const vec4F &p = src(x, y);
148
149 if ((power[0] == 2.0f) && (power[1] == 2.0f) && (power[2] == 2.0f) && (power[3] == 2.0f))
150 dst(x, y).set(p[0] * p[0], p[1] * p[1], p[2] * p[2], p[3] * p[3]);
151 else
152 dst(x, y).set(powf(p[0], power[0]), powf(p[1], power[1]), powf(p[2], power[2]), powf(p[3], power[3]));
153 }
154 }
155 }
156
157 void mul_image(const imagef &src, imagef &dst, const vec4F &mul)
158 {
159 dst.resize(src);
160
161//#pragma omp parallel for
162 for (int y = 0; y < (int)dst.get_height(); y++)
163 {
164 for (uint32_t x = 0; x < dst.get_width(); x++)
165 {
166 const vec4F &p = src(x, y);
167 dst(x, y).set(p[0] * mul[0], p[1] * mul[1], p[2] * mul[2], p[3] * mul[3]);
168 }
169 }
170 }
171
172 void scale_image(const imagef &src, imagef &dst, const vec4F &scale, const vec4F &shift)
173 {
174 dst.resize(src);
175
176//#pragma omp parallel for
177 for (int y = 0; y < (int)dst.get_height(); y++)
178 {
179 for (uint32_t x = 0; x < dst.get_width(); x++)
180 {
181 const vec4F &p = src(x, y);
182
183 vec4F d;
184
185 for (uint32_t c = 0; c < 4; c++)
186 d[c] = scale[c] * p[c] + shift[c];
187
188 dst(x, y).set(d[0], d[1], d[2], d[3]);
189 }
190 }
191 }
192
193 void add_weighted_image(const imagef &src1, const vec4F &alpha, const imagef &src2, const vec4F &beta, const vec4F &gamma, imagef &dst)
194 {
195 dst.resize(src1);
196
197//#pragma omp parallel for
198 for (int y = 0; y < (int)dst.get_height(); y++)
199 {
200 for (uint32_t x = 0; x < dst.get_width(); x++)
201 {
202 const vec4F &s1 = src1(x, y);
203 const vec4F &s2 = src2(x, y);
204
205 dst(x, y).set(
206 s1[0] * alpha[0] + s2[0] * beta[0] + gamma[0],
207 s1[1] * alpha[1] + s2[1] * beta[1] + gamma[1],
208 s1[2] * alpha[2] + s2[2] * beta[2] + gamma[2],
209 s1[3] * alpha[3] + s2[3] * beta[3] + gamma[3]);
210 }
211 }
212 }
213
214 void add_image(const imagef &src1, const imagef &src2, imagef &dst)
215 {
216 dst.resize(src1);
217
218//#pragma omp parallel for
219 for (int y = 0; y < (int)dst.get_height(); y++)
220 {
221 for (uint32_t x = 0; x < dst.get_width(); x++)
222 {
223 const vec4F &s1 = src1(x, y);
224 const vec4F &s2 = src2(x, y);
225
226 dst(x, y).set(s1[0] + s2[0], s1[1] + s2[1], s1[2] + s2[2], s1[3] + s2[3]);
227 }
228 }
229 }
230
231 void adds_image(const imagef &src, const vec4F &value, imagef &dst)
232 {
233 dst.resize(src);
234
235//#pragma omp parallel for
236 for (int y = 0; y < (int)dst.get_height(); y++)
237 {
238 for (uint32_t x = 0; x < dst.get_width(); x++)
239 {
240 const vec4F &p = src(x, y);
241
242 dst(x, y).set(p[0] + value[0], p[1] + value[1], p[2] + value[2], p[3] + value[3]);
243 }
244 }
245 }
246
247 void mul_image(const imagef &src1, const imagef &src2, imagef &dst, const vec4F &scale)
248 {
249 dst.resize(src1);
250
251//#pragma omp parallel for
252 for (int y = 0; y < (int)dst.get_height(); y++)
253 {
254 for (uint32_t x = 0; x < dst.get_width(); x++)
255 {
256 const vec4F &s1 = src1(x, y);
257 const vec4F &s2 = src2(x, y);
258
259 vec4F d;
260
261 for (uint32_t c = 0; c < 4; c++)
262 {
263 float v1 = s1[c];
264 float v2 = s2[c];
265 d[c] = v1 * v2 * scale[c];
266 }
267
268 dst(x, y) = d;
269 }
270 }
271 }
272
273 void div_image(const imagef &src1, const imagef &src2, imagef &dst, const vec4F &scale)
274 {
275 dst.resize(src1);
276
277//#pragma omp parallel for
278 for (int y = 0; y < (int)dst.get_height(); y++)
279 {
280 for (uint32_t x = 0; x < dst.get_width(); x++)
281 {
282 const vec4F &s1 = src1(x, y);
283 const vec4F &s2 = src2(x, y);
284
285 vec4F d;
286
287 for (uint32_t c = 0; c < 4; c++)
288 {
289 float v = s2[c];
290 if (v == 0.0f)
291 d[c] = 0.0f;
292 else
293 d[c] = (s1[c] * scale[c]) / v;
294 }
295
296 dst(x, y) = d;
297 }
298 }
299 }
300
301 vec4F avg_image(const imagef &src)
302 {
303 vec4F avg(0.0f);
304
305 for (uint32_t y = 0; y < src.get_height(); y++)
306 {
307 for (uint32_t x = 0; x < src.get_width(); x++)
308 {
309 const vec4F &s = src(x, y);
310
311 avg += vec4F(s[0], s[1], s[2], s[3]);
312 }
313 }
314
315 avg /= static_cast<float>(src.get_total_pixels());
316
317 return avg;
318 }
319
320 // Reference: https://ece.uwaterloo.ca/~z70wang/research/ssim/index.html
321 vec4F compute_ssim(const imagef &a, const imagef &b)
322 {
323 imagef axb, a_sq, b_sq, mu1, mu2, mu1_sq, mu2_sq, mu1_mu2, s1_sq, s2_sq, s12, smap, t1, t2, t3;
324
325 const float C1 = 6.50250f, C2 = 58.52250f;
326
327 pow_image(a, a_sq, vec4F(2));
328 pow_image(b, b_sq, vec4F(2));
329 mul_image(a, b, axb, vec4F(1.0f));
330
331 gaussian_filter(mu1, a, 11, 1.5f * 1.5f);
332 gaussian_filter(mu2, b, 11, 1.5f * 1.5f);
333
334 pow_image(mu1, mu1_sq, vec4F(2));
335 pow_image(mu2, mu2_sq, vec4F(2));
336 mul_image(mu1, mu2, mu1_mu2, vec4F(1.0f));
337
338 gaussian_filter(s1_sq, a_sq, 11, 1.5f * 1.5f);
339 add_weighted_image(s1_sq, vec4F(1), mu1_sq, vec4F(-1), vec4F(0), s1_sq);
340
341 gaussian_filter(s2_sq, b_sq, 11, 1.5f * 1.5f);
342 add_weighted_image(s2_sq, vec4F(1), mu2_sq, vec4F(-1), vec4F(0), s2_sq);
343
344 gaussian_filter(s12, axb, 11, 1.5f * 1.5f);
345 add_weighted_image(s12, vec4F(1), mu1_mu2, vec4F(-1), vec4F(0), s12);
346
347 scale_image(mu1_mu2, t1, vec4F(2), vec4F(0));
348 adds_image(t1, vec4F(C1), t1);
349
350 scale_image(s12, t2, vec4F(2), vec4F(0));
351 adds_image(t2, vec4F(C2), t2);
352
353 mul_image(t1, t2, t3, vec4F(1));
354
355 add_image(mu1_sq, mu2_sq, t1);
356 adds_image(t1, vec4F(C1), t1);
357
358 add_image(s1_sq, s2_sq, t2);
359 adds_image(t2, vec4F(C2), t2);
360
361 mul_image(t1, t2, t1, vec4F(1));
362
363 div_image(t3, t1, smap, vec4F(1));
364
365 return avg_image(smap);
366 }
367
368 vec4F compute_ssim(const image &a, const image &b, bool luma, bool luma_601)
369 {
370 image ta(a), tb(b);
371
372 if ((ta.get_width() != tb.get_width()) || (ta.get_height() != tb.get_height()))
373 {
374 debug_printf("compute_ssim: Cropping input images to equal dimensions\n");
375
376 const uint32_t w = minimum(a.get_width(), b.get_width());
377 const uint32_t h = minimum(a.get_height(), b.get_height());
378 ta.crop(w, h);
379 tb.crop(w, h);
380 }
381
382 if (!ta.get_width() || !ta.get_height())
383 {
384 assert(0);
385 return vec4F(0);
386 }
387
388 if (luma)
389 {
390 for (uint32_t y = 0; y < ta.get_height(); y++)
391 {
392 for (uint32_t x = 0; x < ta.get_width(); x++)
393 {
394 ta(x, y).set(ta(x, y).get_luma(luma_601), ta(x, y).a);
395 tb(x, y).set(tb(x, y).get_luma(luma_601), tb(x, y).a);
396 }
397 }
398 }
399
400 imagef fta, ftb;
401
402 fta.set(ta);
403 ftb.set(tb);
404
405 return compute_ssim(fta, ftb);
406 }
407
408} // namespace basisu
409