Ocean
Loading...
Searching...
No Matches
SumSquareDifferencesSSE.h
Go to the documentation of this file.
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8#ifndef META_OCEAN_CV_SUM_SQUARE_DIFFERENCES_SSE_H
9#define META_OCEAN_CV_SUM_SQUARE_DIFFERENCES_SSE_H
10
11#include "ocean/cv/CV.h"
12
14
15#if defined(OCEAN_HARDWARE_SSE_VERSION) && OCEAN_HARDWARE_SSE_VERSION >= 41
16
17#include "ocean/cv/SSE.h"
18
19namespace Ocean
20{
21
22namespace CV
23{
24
25/**
 * This class implements functions to calculate sums of square differences using SSE instructions.
27 * @ingroup cv
28 */
30{
31 public:
32
		/**
		 * Returns the sum of square differences between two memory buffers.
		 * The buffers are compared element by element (one uint8_t value per element); the squared differences of all tSize element pairs are summed up.
		 * @param buffer0 The first memory buffer, with at least tSize elements, must be valid
		 * @param buffer1 The second memory buffer, with at least tSize elements, must be valid
		 * @return The resulting sum of square differences
		 * @tparam tSize The size of the buffers in elements, with range [1, infinity)
		 */
		template <unsigned int tSize>
		static inline uint32_t buffer8BitPerChannel(const uint8_t* buffer0, const uint8_t* buffer1);
42
		/**
		 * Returns the sum of square differences between two patches within an image.
		 * Each patch consists of tPatchSize rows, each row holding tChannels * tPatchSize uint8_t elements.
		 * @param patch0 The top left start position of the first image patch, must be valid
		 * @param patch1 The top left start position of the second image patch, must be valid
		 * @param patch0StrideElements The number of elements between the starts of two consecutive rows of the first patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @param patch1StrideElements The number of elements between the starts of two consecutive rows of the second patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @return The resulting sum of square differences
		 * @tparam tChannels The number of channels for the given frames, with range [1, infinity)
		 * @tparam tPatchSize The size of the square patch (the edge length) in pixel, with range [1, infinity), must be odd
		 */
		template <unsigned int tChannels, unsigned int tPatchSize>
		static inline uint32_t patch8BitPerChannel(const uint8_t* patch0, const uint8_t* patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements);
55
		/**
		 * Returns the sum of square differences between an image patch and a buffer.
		 * The buffer stores the tPatchSize rows of the second patch directly one after another without any padding, so its row stride is tChannels * tPatchSize elements.
		 * @param patch0 The top left start position of the image patch, must be valid
		 * @param buffer1 The memory buffer with tChannels * tPatchSize * tPatchSize elements, must be valid
		 * @param patch0StrideElements The number of elements between the starts of two consecutive rows of the image patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @return The resulting sum of square differences
		 * @tparam tChannels The number of channels for the given frames, with range [1, infinity)
		 * @tparam tPatchSize The size of the square patch (the edge length) in pixel, with range [1, infinity), must be odd
		 */
		template <unsigned int tChannels, unsigned int tPatchSize>
		static inline uint32_t patchBuffer8BitPerChannel(const uint8_t* patch0, const uint8_t* buffer1, const unsigned int patch0StrideElements);
67
		/**
		 * Returns the sum of square differences between two patches within an image, patch pixels outside the image will be mirrored back into the image.
		 * Therefore, the patches may be located partially outside of their images, only the patch centers must be inside the images.
		 * @param image0 The image in which the first patch is located, must be valid
		 * @param image1 The image in which the second patch is located, must be valid
		 * @param width0 The width of the first image, in pixels, with range [tPatchSize, infinity)
		 * @param height0 The height of the first image, in pixels, with range [tPatchSize, infinity)
		 * @param width1 The width of the second image, in pixels, with range [tPatchSize, infinity)
		 * @param height1 The height of the second image, in pixels, with range [tPatchSize, infinity)
		 * @param centerX0 Horizontal center position of the (tPatchSize x tPatchSize) block in the first frame, with range [0, width0 - 1]
		 * @param centerY0 Vertical center position of the (tPatchSize x tPatchSize) block in the first frame, with range [0, height0 - 1]
		 * @param centerX1 Horizontal center position of the (tPatchSize x tPatchSize) block in the second frame, with range [0, width1 - 1]
		 * @param centerY1 Vertical center position of the (tPatchSize x tPatchSize) block in the second frame, with range [0, height1 - 1]
		 * @param image0PaddingElements The number of padding elements at the end of each row of the first image, in elements, with range [0, infinity)
		 * @param image1PaddingElements The number of padding elements at the end of each row of the second image, in elements, with range [0, infinity)
		 * @return The resulting sum of square differences, with range [0, infinity)
		 * @tparam tChannels The number of frame channels, with range [1, infinity)
		 * @tparam tPatchSize The size of the square patch (the edge length) in pixel, with range [1, infinity), must be odd
		 */
		template <unsigned int tChannels, unsigned int tPatchSize>
		static uint32_t patchMirroredBorder8BitPerChannel(const uint8_t* image0, const uint8_t* image1, const unsigned int width0, const unsigned int height0, const unsigned int width1, const unsigned int height1, const unsigned int centerX0, const unsigned int centerY0, const unsigned int centerX1, const unsigned int centerY1, const unsigned int image0PaddingElements, const unsigned int image1PaddingElements);
88
89 protected:
90
		/**
		 * Returns the mirrored element index for a given element index.
		 * Indices inside [0, elements) are returned unchanged, indices outside are mirrored back at the row's border (the border element itself is repeated).
		 * For one channel, the mirrored index is calculated as follows:
		 * <pre>
		 *               |<----------------------- valid value range -------------------------->|
		 *
		 * elementIndex: -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, ... elements - 3, elements - 2, elements - 1, elements + 0, elements + 1
		 * result:        2   1   0  0  1  2  3  4  5  6  7  ... elements - 3  elements - 2  elements - 1  elements - 1  elements - 2
		 * </pre>
		 * For several channels, entire pixels are mirrored while the channel within the pixel is kept.
		 * For example, with tChannels == 3 the element index -1 (last channel of the pixel left of the row) results in 2 (last channel of the first pixel).
		 * @param elementIndex The index for which the mirrored index will be returned, with range [-elements/2, elements + elements/2]
		 * @param elements The number of elements in the row (the row's pixels times tChannels), with range [1, infinity)
		 * @return The mirrored index, with range [0, elements)
		 * @tparam tChannels The number of channels the elements have, with range [1, infinity)
		 */
		template <unsigned int tChannels>
		static OCEAN_FORCE_INLINE unsigned int mirrorIndex(const int elementIndex, const unsigned int elements);
108
		/**
		 * Loads up to 8 uint8_t values from a row with mirroring pixels if necessary.
		 * The values are returned in the lower 64 bits of the resulting __m128i object, all bytes not holding a value are set to zero.
		 * @param row The row from which the values will be loaded, must be valid
		 * @param elementIndex The index of the first element to load, with range [-elements/2, elements + elements/2]
		 * @param elements The number of elements in the row, with range [4, infinity)
		 * @param intermediateBuffer An intermediate buffer with at least 8 elements, must be valid
		 * @return The __m128i object with the loaded values (in the lower 64 bits)
		 * @tparam tChannels The number of channels the row has, with range [1, infinity)
		 * @tparam tFront True, to place the tSize values at the front (lowest bytes) of the lower 64 bits; False, to place the values at the end of the lower 64 bits, preceded by 8 - tSize zero bytes
		 * @tparam tSize The number of uint8_t values to be read, with range [1, 8]
		 */
		template <unsigned int tChannels, bool tFront, unsigned int tSize>
		static OCEAN_FORCE_INLINE __m128i loadMirrored_u_8x8(const uint8_t* const row, const int elementIndex, const unsigned int elements, uint8_t* const intermediateBuffer);
122
		/**
		 * Loads up to 16 uint8_t values from a row with mirroring pixels if necessary.
		 * All bytes of the resulting __m128i object not holding a value are set to zero.
		 * @param row The row from which the values will be loaded, must be valid
		 * @param elementIndex The index of the first element to load, with range [-elements/2, elements + elements/2]
		 * @param elements The number of elements in the row, with range [8, infinity)
		 * @param intermediateBuffer An intermediate buffer with at least 16 elements, must be valid
		 * @return The __m128i object with the loaded values
		 * @tparam tChannels The number of channels the row has, with range [1, infinity)
		 * @tparam tFront True, to place the tSize values at the front (lowest bytes) of the resulting __m128i object; False, to place the values at the end, preceded by 16 - tSize zero bytes
		 * @tparam tSize The number of uint8_t values to be read, with range [9, 16]
		 */
		template <unsigned int tChannels, bool tFront, unsigned int tSize>
		static OCEAN_FORCE_INLINE __m128i loadMirrored_u_8x16(const uint8_t* const row, const int elementIndex, const unsigned int elements, uint8_t* const intermediateBuffer);
136};
137
138template <unsigned int tSize>
139inline uint32_t SumSquareDifferencesSSE::buffer8BitPerChannel(const uint8_t* buffer0, const uint8_t* buffer1)
140{
141 static_assert(tSize >= 1u, "Invalid buffer size!");
142
143 static_assert(std::is_same<short, int16_t>::value, "Invalid data type!");
144
145 const __m128i constant_signs_m128i = _mm_set1_epi16(short(0x1FF)); // -1, 1, -1, 1, -1, 1, -1, 1
146
147 __m128i sumLow_128i = _mm_setzero_si128();
148 __m128i sumHigh_128i = _mm_setzero_si128();
149
150 // first, we handle blocks with 16 elements
151
152 constexpr unsigned int blocks16 = tSize / 16u;
153
154 for (unsigned int n = 0u; n < blocks16; ++n)
155 {
156 const __m128i buffer0_128i = _mm_lddqu_si128((const __m128i*)buffer0);
157 const __m128i buffer1_128i = _mm_lddqu_si128((const __m128i*)buffer1);
158
159 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
160 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
161
162 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
163 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
164
165 buffer0 += 16;
166 buffer1 += 16;
167 }
168
169 if constexpr (blocks16 >= 1u && (tSize % 16u) >= 10u)
170 {
171 constexpr unsigned int remainingElements = tSize % 16u;
172 constexpr unsigned int overlappingElements = 16u - remainingElements;
173
174 const __m128i buffer0_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(buffer0 - overlappingElements)), overlappingElements);
175 const __m128i buffer1_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(buffer1 - overlappingElements)), overlappingElements);
176
177 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
178 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
179
180 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
181 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
182
183 const __m128i sum_128i = _mm_add_epi32(sumLow_128i, sumHigh_128i);
184
185 return SSE::sum_u32_4(sum_128i);
186 }
187 else
188 {
189 // we may handle at most one block with 8 elements
190
191 constexpr unsigned int blocks8 = (tSize % 16u) / 8u;
192 static_assert(blocks8 <= 1u, "Invalid number of blocks!");
193
194 if constexpr (blocks8 == 1u)
195 {
196 const __m128i buffer0_128i = _mm_loadl_epi64((const __m128i*)buffer0); // load for unaligned 64 bit memory
197 const __m128i buffer1_128i = _mm_loadl_epi64((const __m128i*)buffer1);
198
199 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
200
201 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
202
203 buffer0 += 8;
204 buffer1 += 8;
205 }
206
207 const __m128i sum_128i = _mm_add_epi32(sumLow_128i, sumHigh_128i);
208
209 constexpr unsigned int remainingElements = tSize - blocks16 * 16u - blocks8 * 8u;
210 static_assert(remainingElements < 8u, "Invalid number of remaining elements!");
211
212 uint32_t result = SSE::sum_u32_4(sum_128i);
213
214 // we apply the remaining elements (at most 7)
215
216 for (unsigned int n = 0u; n < remainingElements; ++n)
217 {
218 result += sqrDistance(buffer0[n], buffer1[n]);
219 }
220
221 return result;
222 }
223}
224
225template <unsigned int tChannels, unsigned int tPatchSize>
226inline uint32_t SumSquareDifferencesSSE::patch8BitPerChannel(const uint8_t* patch0, const uint8_t* patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements)
227{
228 static_assert(tChannels >= 1u, "Invalid channel number!");
229 static_assert(tPatchSize >= 1u, "Invalid buffer size!");
230
231 ocean_assert(patch0 != nullptr && patch1 != nullptr);
232
233 ocean_assert(patch0StrideElements >= tChannels * tPatchSize);
234 ocean_assert(patch1StrideElements >= tChannels * tPatchSize);
235
236 constexpr unsigned int patchWidthElements = tChannels * tPatchSize;
237
238 constexpr unsigned int blocks16 = patchWidthElements / 16u;
239 constexpr unsigned int remainingAfterBlocks16 = patchWidthElements % 16u;
240
241 constexpr bool partialBlock16 = remainingAfterBlocks16 > 8u;
242
243 constexpr bool fullBlock8 = !partialBlock16 && remainingAfterBlocks16 == 8u;
244
245 constexpr bool partialBlock8 = !partialBlock16 && !fullBlock8 && remainingAfterBlocks16 >= 3u;
246
247 constexpr unsigned int blocks1 = (!partialBlock16 && !fullBlock8 && !partialBlock8) ? remainingAfterBlocks16 : 0u;
248
249 static_assert(blocks1 <= 2u, "Invalid block size!");
250
251 static_assert(std::is_same<short, int16_t>::value, "Invalid data type!");
252
253 const __m128i constant_signs_m128i = _mm_set1_epi16(short(0x1FF)); // -1, 1, -1, 1, -1, 1, -1, 1
254
255 __m128i sumLow_128i = _mm_setzero_si128();
256 __m128i sumHigh_128i = _mm_setzero_si128();
257
258 uint32_t sumIndividual = 0u;
259
260 for (unsigned int y = 0u; y < tPatchSize; ++y)
261 {
262 SSE::prefetchT0(patch0 + patch0StrideElements);
263 SSE::prefetchT0(patch1 + patch1StrideElements);
264
265 for (unsigned int n = 0u; n < blocks16; ++n)
266 {
267 const __m128i buffer0_128i = _mm_lddqu_si128((const __m128i*)patch0);
268 const __m128i buffer1_128i = _mm_lddqu_si128((const __m128i*)patch1);
269
270 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
271 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
272
273 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
274 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
275
276 patch0 += 16;
277 patch1 += 16;
278 }
279
280 if constexpr (fullBlock8)
281 {
282 const __m128i buffer0_128i = _mm_loadl_epi64((const __m128i*)patch0); // load for unaligned 64 bit memory
283 const __m128i buffer1_128i = _mm_loadl_epi64((const __m128i*)patch1);
284
285 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
286
287 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
288
289 patch0 += 8;
290 patch1 += 8;
291 }
292
293 if constexpr (partialBlock16)
294 {
295 constexpr unsigned int overlapElements = partialBlock16 ? 16u - remainingAfterBlocks16 : 0u;
296
297 static_assert(overlapElements < 8u, "Invalid value!");
298
299 if (y < tPatchSize - 1u)
300 {
301 const __m128i buffer0_128i = _mm_slli_si128(_mm_lddqu_si128((const __m128i*)patch0), overlapElements); // loading 16 elements, but shifting `overlapElements` zeros to the left
302 const __m128i buffer1_128i = _mm_slli_si128(_mm_lddqu_si128((const __m128i*)patch1), overlapElements);
303
304 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
305 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
306
307 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
308 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
309 }
310 else
311 {
312 const __m128i buffer0_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(patch0 - overlapElements)), overlapElements); // loading 16 elements, but shifting `overlapElements` zeros to the right
313 const __m128i buffer1_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(patch1 - overlapElements)), overlapElements);
314
315 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
316 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
317
318 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
319 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
320 }
321
322 patch0 += remainingAfterBlocks16;
323 patch1 += remainingAfterBlocks16;
324 }
325
326 if constexpr (partialBlock8)
327 {
328 constexpr unsigned int overlapElements = partialBlock8 ? 8u - remainingAfterBlocks16 : 0u;
329
330 static_assert(overlapElements < 8u, "Invalid value!");
331
332 if (y < tPatchSize - 1u)
333 {
334 const __m128i buffer0_128i = _mm_slli_si128(_mm_loadl_epi64((const __m128i*)patch0), overlapElements + 8); // loading 8 elements, but shifting `overlapElements` zeros to the left
335 const __m128i buffer1_128i = _mm_slli_si128(_mm_loadl_epi64((const __m128i*)patch1), overlapElements + 8);
336
337 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
338
339 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
340 }
341 else
342 {
343 const __m128i buffer0_128i = _mm_srli_si128(_mm_loadl_epi64((const __m128i*)(patch0 - overlapElements)), overlapElements); // loading 8 elements, but shifting `overlapElements` zeros to the right
344 const __m128i buffer1_128i = _mm_srli_si128(_mm_loadl_epi64((const __m128i*)(patch1 - overlapElements)), overlapElements);
345
346 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
347
348 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
349 }
350
351 patch0 += remainingAfterBlocks16;
352 patch1 += remainingAfterBlocks16;
353 }
354
355 if constexpr (blocks1 != 0u)
356 {
357 for (unsigned int n = 0u; n < blocks1; ++n)
358 {
359 sumIndividual += sqrDistance(patch0[n], patch1[n]);
360 }
361
362 patch0 += blocks1;
363 patch1 += blocks1;
364 }
365
366 patch0 += patch0StrideElements - patchWidthElements;
367 patch1 += patch1StrideElements - patchWidthElements;
368 }
369
370 const __m128i sum_128i = _mm_add_epi32(sumLow_128i, sumHigh_128i);
371
372 return SSE::sum_u32_4(sum_128i) + sumIndividual;
373}
374
375template <unsigned int tChannels, unsigned int tPatchSize>
376inline uint32_t SumSquareDifferencesSSE::patchBuffer8BitPerChannel(const uint8_t* patch0, const uint8_t* buffer1, const unsigned int patch0StrideElements)
377{
378 return patch8BitPerChannel<tChannels, tPatchSize>(patch0, buffer1, patch0StrideElements, tChannels * tPatchSize);
379}
380
381template <unsigned int tChannels, unsigned int tPatchSize>
382uint32_t SumSquareDifferencesSSE::patchMirroredBorder8BitPerChannel(const uint8_t* image0, const uint8_t* image1, const unsigned int width0, const unsigned int height0, const unsigned int width1, const unsigned int height1, const unsigned int centerX0, const unsigned int centerY0, const unsigned int centerX1, const unsigned int centerY1, const unsigned int image0PaddingElements, const unsigned int image1PaddingElements)
383{
384 static_assert(tChannels >= 1u, "Invalid channel number!");
385 static_assert(tPatchSize % 2u == 1u, "Invalid patch size!");
386
387 ocean_assert(image0 != nullptr && image1 != nullptr);
388
389 ocean_assert(centerX0 < width0 && centerY0 < height0);
390 ocean_assert(centerX1 < width1 && centerY1 < height1);
391
392 constexpr unsigned int tPatchSize_2 = tPatchSize / 2u;
393
394 const unsigned int width0Elements = width0 * tChannels;
395 const unsigned int width1Elements = width1 * tChannels;
396
397 const unsigned int image0StrideElements = width0Elements + image0PaddingElements;
398 const unsigned int image1StrideElements = width1Elements + image1PaddingElements;
399
400 constexpr unsigned int patchWidthElements = tChannels * tPatchSize;
401
402 constexpr unsigned int blocks16 = patchWidthElements / 16u;
403 constexpr unsigned int remainingAfterBlocks16 = patchWidthElements % 16u;
404
405 constexpr bool partialBlock16 = remainingAfterBlocks16 > 8u;
406 constexpr unsigned int remainingAfterPartialBlock16 = partialBlock16 ? 0u : remainingAfterBlocks16;
407
408 constexpr unsigned int blocks8 = remainingAfterPartialBlock16 / 8u;
409 constexpr unsigned int remainingAfterBlocks8 = remainingAfterPartialBlock16 % 8u;
410
411 constexpr bool partialBlock8 = remainingAfterBlocks8 >= 3u;
412 constexpr unsigned int remainingAfterPartialBlock8 = partialBlock8 ? 0u : remainingAfterBlocks8;
413
414 constexpr unsigned int blocks1 = remainingAfterPartialBlock8;
415
416 static_assert(blocks1 <= 7u, "Invalid block size!");
417
418 static_assert(std::is_same<short, int16_t>::value, "Invalid data type!");
419
420 const __m128i constant_signs_m128i = _mm_set1_epi16(short(0x1FF)); // -1, 1, -1, 1, -1, 1, -1, 1
421
422 __m128i sumLow_128i = _mm_setzero_si128();
423 __m128i sumHigh_128i = _mm_setzero_si128();
424
425 uint32_t sumIndividual = 0u;
426
427 uint8_t intermediate[16];
428
429 int y1 = int(centerY1) - int(tPatchSize_2);
430 for (int y0 = int(centerY0) - int(tPatchSize_2); y0 <= int(centerY0) + int(tPatchSize_2); ++y0)
431 {
432 const uint8_t* const mirroredRow0 = image0 + (unsigned int)(y0 + CVUtilities::mirrorOffset(y0, height0)) * image0StrideElements;
433 const uint8_t* const mirroredRow1 = image1 + (unsigned int)(y1 + CVUtilities::mirrorOffset(y1, height1)) * image1StrideElements;
434
435 int x0 = (int(centerX0) - int(tPatchSize_2)) * int(tChannels);
436 int x1 = (int(centerX1) - int(tPatchSize_2)) * int(tChannels);
437
438 for (unsigned int n = 0u; n < blocks16; ++n)
439 {
440 const __m128i buffer0_128i = loadMirrored_u_8x16<tChannels, true, 16u>(mirroredRow0, x0, width0Elements, intermediate);
441 const __m128i buffer1_128i = loadMirrored_u_8x16<tChannels, true, 16u>(mirroredRow1, x1, width1Elements, intermediate);
442
443 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
444 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
445
446 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
447 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
448
449 x0 += 16;
450 x1 += 16;
451 }
452
453 if constexpr (partialBlock16)
454 {
455 if (y0 < int(centerY0) + int(tPatchSize_2))
456 {
457 const __m128i buffer0_128i = loadMirrored_u_8x16<tChannels, true, remainingAfterBlocks16>(mirroredRow0, x0, width0Elements, intermediate);
458 const __m128i buffer1_128i = loadMirrored_u_8x16<tChannels, true, remainingAfterBlocks16>(mirroredRow1, x1, width1Elements, intermediate);
459
460 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
461 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
462
463 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
464 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
465 }
466 else
467 {
468 const __m128i buffer0_128i = loadMirrored_u_8x16<tChannels, false, remainingAfterBlocks16>(mirroredRow0, x0, width0Elements, intermediate);
469 const __m128i buffer1_128i = loadMirrored_u_8x16<tChannels, false, remainingAfterBlocks16>(mirroredRow1, x1, width1Elements, intermediate);
470
471 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
472 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
473
474 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
475 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
476 }
477
478 x0 += remainingAfterBlocks16;
479 x1 += remainingAfterBlocks16;
480 }
481
482 for (unsigned int n = 0u; n < blocks8; ++n)
483 {
484 const __m128i buffer0_128i = loadMirrored_u_8x8<tChannels, true, 8u>(mirroredRow0, x0, width0Elements, intermediate);
485 const __m128i buffer1_128i = loadMirrored_u_8x8<tChannels, true, 8u>(mirroredRow1, x1, width1Elements, intermediate);
486
487 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
488
489 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
490
491 x0 += 8;
492 x1 += 8;
493 }
494
495 if constexpr (partialBlock8)
496 {
497 // we have enough elements left so that using SSE is still faster than handling each element individually
498
499 if (y0 < int(centerY0) + int(tPatchSize_2))
500 {
501 // Shift data to high bytes and use unpack_hi to ensure correct pairing for madd
502 const __m128i loaded0_128i = loadMirrored_u_8x8<tChannels, true, remainingAfterBlocks8>(mirroredRow0, x0, width0Elements, intermediate);
503 const __m128i loaded1_128i = loadMirrored_u_8x8<tChannels, true, remainingAfterBlocks8>(mirroredRow1, x1, width1Elements, intermediate);
504
505 constexpr unsigned int shift = 8u + (8u - remainingAfterBlocks8);
506 const __m128i remaining0_128i = _mm_slli_si128(loaded0_128i, shift);
507 const __m128i remaining1_128i = _mm_slli_si128(loaded1_128i, shift);
508
509 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(remaining0_128i, remaining1_128i), constant_signs_m128i);
510
511 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
512 }
513 else
514 {
515 // Data is in low bytes (after zero-padding at front), use unpack_lo
516 const __m128i remaining0_128i = loadMirrored_u_8x8<tChannels, false, remainingAfterBlocks8>(mirroredRow0, x0, width0Elements, intermediate);
517 const __m128i remaining1_128i = loadMirrored_u_8x8<tChannels, false, remainingAfterBlocks8>(mirroredRow1, x1, width1Elements, intermediate);
518
519 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(remaining0_128i, remaining1_128i), constant_signs_m128i);
520
521 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
522 }
523
524 x0 += remainingAfterBlocks8;
525 x1 += remainingAfterBlocks8;
526 }
527
528 if constexpr (blocks1 != 0u)
529 {
530 for (unsigned int n = 0u; n < blocks1; ++n)
531 {
532 sumIndividual += sqrDistance(mirroredRow0[mirrorIndex<tChannels>(x0 + int(n), width0Elements)], mirroredRow1[mirrorIndex<tChannels>(x1 + int(n), width1Elements)]);
533 }
534 }
535
536 ++y1;
537 }
538
539 const __m128i sum_128i = _mm_add_epi32(sumLow_128i, sumHigh_128i);
540
541 return SSE::sum_u32_4(sum_128i) + sumIndividual;
542}
543
544template <unsigned int tChannels>
545inline unsigned int SumSquareDifferencesSSE::mirrorIndex(const int elementIndex, const unsigned int elements)
546{
547 static_assert(tChannels >= 1u, "Invalid channel number!");
548
549 if ((unsigned int)(elementIndex) < elements)
550 {
551 return elementIndex;
552 }
553
554 if (elementIndex < 0)
555 {
556 const unsigned int leftElements = (unsigned int)(-elementIndex) - 1u;
557
558 const unsigned int pixelIndex = leftElements / tChannels;
559 const unsigned int channelIndex = tChannels - (leftElements % tChannels) - 1u;
560 ocean_assert(channelIndex < tChannels);
561
562 ocean_assert(pixelIndex * tChannels + channelIndex < elements);
563 return pixelIndex * tChannels + channelIndex;
564 }
565 else
566 {
567 ocean_assert((unsigned int)(elementIndex) >= elements);
568
569 const unsigned int rightElements = elementIndex - elements;
570
571 const unsigned int rightPixels = rightElements / tChannels;
572 const unsigned int channelIndex = rightElements % tChannels;
573 ocean_assert(channelIndex < tChannels);
574
575 ocean_assert(elements - (rightPixels + 1u) * tChannels + channelIndex < elements);
576 return elements - (rightPixels + 1u) * tChannels + channelIndex;
577 }
578}
579
580template <unsigned int tChannels, bool tFront, unsigned int tSize>
581OCEAN_FORCE_INLINE __m128i SumSquareDifferencesSSE::loadMirrored_u_8x8(const uint8_t* const row, const int elementIndex, const unsigned int elements, uint8_t* const intermediateBuffer)
582{
583 static_assert(tChannels >= 1u, "Invalid channel number!");
584
585 ocean_assert(tSize >= 1u && tSize <= 8u);
586
587 ocean_assert(row != nullptr && intermediateBuffer != nullptr);
588
589 constexpr unsigned int tOverlappingElements = 8u - tSize;
590
591 if (elementIndex >= 0 && elementIndex <= int(elements) - int(tSize))
592 {
593 if constexpr (tSize == 8u)
594 {
595 return _mm_loadl_epi64((const __m128i*)(row + elementIndex));
596 }
597 else
598 {
599 if constexpr (tFront)
600 {
601 // For tFront=true, keep data at the front (low bytes), zero the high bytes
602 // We load tSize bytes, they stay in the low bytes naturally
603 for (unsigned int n = 0u; n < tSize; ++n)
604 {
605 intermediateBuffer[n] = row[elementIndex + n];
606 }
607 for (unsigned int n = tSize; n < 8u; ++n)
608 {
609 intermediateBuffer[n] = 0u;
610 }
611 return _mm_loadl_epi64((const __m128i*)intermediateBuffer);
612 }
613 else
614 {
615 // For tFront=false, put zeros at the front (low bytes), data at the back (high bytes)
616 for (unsigned int n = 0u; n < tOverlappingElements; ++n)
617 {
618 intermediateBuffer[n] = 0u;
619 }
620 for (unsigned int n = 0u; n < tSize; ++n)
621 {
622 intermediateBuffer[tOverlappingElements + n] = row[elementIndex + n];
623 }
624 return _mm_loadl_epi64((const __m128i*)intermediateBuffer);
625 }
626 }
627 }
628
629 if constexpr (tFront)
630 {
631 for (unsigned int n = 0u; n < tSize; ++n)
632 {
633 const unsigned int index = mirrorIndex<tChannels>(elementIndex + int(n), elements);
634 ocean_assert(index < elements);
635
636 intermediateBuffer[n] = row[index];
637 }
638
639 for (unsigned int n = tSize; n < 8u; ++n)
640 {
641 intermediateBuffer[n] = 0u;
642 }
643 }
644 else
645 {
646 for (unsigned int n = 0u; n < tOverlappingElements; ++n)
647 {
648 intermediateBuffer[n] = 0u;
649 }
650
651 for (unsigned int n = 0u; n < tSize; ++n)
652 {
653 const unsigned int index = mirrorIndex<tChannels>(elementIndex + int(n), elements);
654 ocean_assert(index < elements);
655
656 intermediateBuffer[tOverlappingElements + n] = row[index];
657 }
658 }
659
660 return _mm_loadl_epi64((const __m128i*)intermediateBuffer);
661}
662
663template <unsigned int tChannels, bool tFront, unsigned int tSize>
664OCEAN_FORCE_INLINE __m128i SumSquareDifferencesSSE::loadMirrored_u_8x16(const uint8_t* const row, const int elementIndex, const unsigned int elements, uint8_t* const intermediateBuffer)
665{
666 static_assert(tChannels >= 1u, "Invalid channel number!");
667
668 ocean_assert(tSize > 8u && tSize <= 16u);
669
670 ocean_assert(row != nullptr && intermediateBuffer != nullptr);
671
672 constexpr unsigned int tOverlappingElements = 16u - tSize;
673
674 if (elementIndex >= 0 && elementIndex <= int(elements) - int(tSize))
675 {
676 if constexpr (tSize == 16u)
677 {
678 return _mm_lddqu_si128((const __m128i*)(row + elementIndex));
679 }
680 else
681 {
682 if constexpr (tFront)
683 {
684 // For tFront=true, keep data at the front (low bytes), zero the high bytes
685 for (unsigned int n = 0u; n < tSize; ++n)
686 {
687 intermediateBuffer[n] = row[elementIndex + n];
688 }
689 for (unsigned int n = tSize; n < 16u; ++n)
690 {
691 intermediateBuffer[n] = 0u;
692 }
693 return _mm_lddqu_si128((const __m128i*)intermediateBuffer);
694 }
695 else
696 {
697 // For tFront=false, put zeros at the front (low bytes), data at the back (high bytes)
698 for (unsigned int n = 0u; n < tOverlappingElements; ++n)
699 {
700 intermediateBuffer[n] = 0u;
701 }
702 for (unsigned int n = 0u; n < tSize; ++n)
703 {
704 intermediateBuffer[tOverlappingElements + n] = row[elementIndex + n];
705 }
706 return _mm_lddqu_si128((const __m128i*)intermediateBuffer);
707 }
708 }
709 }
710
711 if constexpr (tFront)
712 {
713 for (unsigned int n = 0u; n < tSize; ++n)
714 {
715 const unsigned int index = mirrorIndex<tChannels>(elementIndex + int(n), elements);
716 ocean_assert(index < elements);
717
718 intermediateBuffer[n] = row[index];
719 }
720
721 for (unsigned int n = tSize; n < 16u; ++n)
722 {
723 intermediateBuffer[n] = 0u;
724 }
725 }
726 else
727 {
728 for (unsigned int n = 0u; n < tOverlappingElements; ++n)
729 {
730 intermediateBuffer[n] = 0u;
731 }
732
733 for (unsigned int n = 0u; n < tSize; ++n)
734 {
735 const unsigned int index = mirrorIndex<tChannels>(elementIndex + int(n), elements);
736 ocean_assert(index < elements);
737
738 intermediateBuffer[tOverlappingElements + n] = row[index];
739 }
740 }
741
742 return _mm_lddqu_si128((const __m128i*)intermediateBuffer);
743}
744
745}
746
747}
748
749#endif // OCEAN_HARDWARE_SSE_VERSION >= 41
750
751#endif // META_OCEAN_CV_SUM_SQUARE_DIFFERENCES_SSE_H
static int mirrorOffset(const unsigned int index, const unsigned int elements)
Deprecated.
Definition CVUtilities.h:449
static void prefetchT0(const void *const data)
Prefetches a block of temporal memory into all cache levels.
Definition SSE.h:1293
static OCEAN_FORCE_INLINE unsigned int sum_u32_4(const __m128i &value)
Adds the four (all four) individual 32 bit unsigned integer values of a m128i value and returns the r...
Definition SSE.h:1360
This class implements functions to calculate sum square differences using SSE instructions.
Definition SumSquareDifferencesSSE.h:30
static uint32_t patchBuffer8BitPerChannel(const uint8_t *patch0, const uint8_t *buffer1, const unsigned int patch0StrideElements)
Returns the sum of square differences between an image patch and a buffer.
Definition SumSquareDifferencesSSE.h:376
static OCEAN_FORCE_INLINE __m128i loadMirrored_u_8x8(const uint8_t *const row, const int elementIndex, const unsigned int elements, uint8_t *const intermediateBuffer)
Loads up to 8 uint8_t values from a row with mirroring pixels if necessary.
Definition SumSquareDifferencesSSE.h:581
static OCEAN_FORCE_INLINE unsigned int mirrorIndex(const int elementIndex, const unsigned int elements)
Returns the mirrored element index for a given element index.
static OCEAN_FORCE_INLINE __m128i loadMirrored_u_8x16(const uint8_t *const row, const int elementIndex, const unsigned int elements, uint8_t *const intermediateBuffer)
Loads up to 16 uint8_t values from a row with mirroring pixels if necessary.
Definition SumSquareDifferencesSSE.h:664
static uint32_t buffer8BitPerChannel(const uint8_t *buffer0, const uint8_t *buffer1)
Returns the sum of square differences between two memory buffers.
Definition SumSquareDifferencesSSE.h:139
static uint32_t patchMirroredBorder8BitPerChannel(const uint8_t *image0, const uint8_t *image1, const unsigned int width0, const unsigned int height0, const unsigned int width1, const unsigned int height1, const unsigned int centerX0, const unsigned int centerY0, const unsigned int centerX1, const unsigned int centerY1, const unsigned int image0PaddingElements, const unsigned int image1PaddingElements)
Returns the sum of square differences between two patches within an image, patch pixels outside the i...
Definition SumSquareDifferencesSSE.h:382
static uint32_t patch8BitPerChannel(const uint8_t *patch0, const uint8_t *patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements)
Returns the sum of square differences between two patches within an image.
Definition SumSquareDifferencesSSE.h:226
unsigned int sqrDistance(const char first, const char second)
Returns the square distance between two values.
Definition base/Utilities.h:1159
The namespace covering the entire Ocean framework.
Definition Accessor.h:15