Ocean
Loading...
Searching...
No Matches
SumSquareDifferencesSSE.h
Go to the documentation of this file.
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8#ifndef META_OCEAN_CV_SUM_SQUARE_DIFFERENCES_SSE_H
9#define META_OCEAN_CV_SUM_SQUARE_DIFFERENCES_SSE_H
10
11#include "ocean/cv/CV.h"
12
14
15#if defined(OCEAN_HARDWARE_SSE_VERSION) && OCEAN_HARDWARE_SSE_VERSION >= 41
16
17#include "ocean/cv/SSE.h"
18
19namespace Ocean
20{
21
22namespace CV
23{
24
/**
 * This class implements functions to calculate sum square differences using SSE instructions.
 * All functions operate on 8 bit unsigned elements and return the sum as a 32 bit unsigned integer.
 * @ingroup cv
 */
class SumSquareDifferencesSSE
{
	public:

		/**
		 * Returns the sum of square differences between two memory buffers.
		 * @param buffer0 The first memory buffer, must be valid
		 * @param buffer1 The second memory buffer, must be valid
		 * @return The resulting sum of square differences
		 * @tparam tSize The size of the buffers in elements, with range [1, infinity)
		 */
		template <unsigned int tSize>
		static inline uint32_t buffer8BitPerChannel(const uint8_t* buffer0, const uint8_t* buffer1);

		/**
		 * Returns the sum of square differences between two patches within an image.
		 * @param patch0 The top left start position of the first image patch, must be valid
		 * @param patch1 The top left start position of the second image patch, must be valid
		 * @param patch0StrideElements The number of elements between two rows for the first patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @param patch1StrideElements The number of elements between two rows for the second patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @return The resulting sum of square differences
		 * @tparam tChannels The number of channels for the given frames, with range [1, infinity)
		 * @tparam tPatchSize The size of the square patch (the edge length) in pixel, with range [1, infinity), must be odd
		 */
		template <unsigned int tChannels, unsigned int tPatchSize>
		static inline uint32_t patch8BitPerChannel(const uint8_t* patch0, const uint8_t* patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements);

		/**
		 * Returns the sum of square differences between an image patch and a buffer.
		 * @param patch0 The top left start position of the image patch, must be valid
		 * @param buffer1 The memory buffer, must be valid
		 * @param patch0StrideElements The number of elements between two rows for the image patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @return The resulting sum of square differences
		 * @tparam tChannels The number of channels for the given frames, with range [1, infinity)
		 * @tparam tPatchSize The size of the square patch (the edge length) in pixel, with range [1, infinity), must be odd
		 */
		template <unsigned int tChannels, unsigned int tPatchSize>
		static inline uint32_t patchBuffer8BitPerChannel(const uint8_t* patch0, const uint8_t* buffer1, const unsigned int patch0StrideElements);
};
68
69template <unsigned int tSize>
70inline uint32_t SumSquareDifferencesSSE::buffer8BitPerChannel(const uint8_t* buffer0, const uint8_t* buffer1)
71{
72 static_assert(tSize >= 1u, "Invalid buffer size!");
73
74 static_assert(std::is_same<short, int16_t>::value, "Invalid data type!");
75
76 const __m128i constant_signs_m128i = _mm_set1_epi16(short(0x1FF)); // -1, 1, -1, 1, -1, 1, -1, 1
77
78 __m128i sumLow_128i = _mm_setzero_si128();
79 __m128i sumHigh_128i = _mm_setzero_si128();
80
81 // first, we handle blocks with 16 elements
82
83 constexpr unsigned int blocks16 = tSize / 16u;
84
85 for (unsigned int n = 0u; n < blocks16; ++n)
86 {
87 const __m128i buffer0_128i = _mm_lddqu_si128((const __m128i*)buffer0);
88 const __m128i buffer1_128i = _mm_lddqu_si128((const __m128i*)buffer1);
89
90 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
91 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
92
93 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
94 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
95
96 buffer0 += 16;
97 buffer1 += 16;
98 }
99
100 if constexpr (blocks16 >= 1u && (tSize % 16u) >= 10u)
101 {
102 constexpr unsigned int remainingElements = tSize % 16u;
103 constexpr unsigned int overlappingElements = 16u - remainingElements;
104
105 const __m128i buffer0_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(buffer0 - overlappingElements)), overlappingElements);
106 const __m128i buffer1_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(buffer1 - overlappingElements)), overlappingElements);
107
108 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
109 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
110
111 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
112 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
113
114 const __m128i sum_128i = _mm_add_epi32(sumLow_128i, sumHigh_128i);
115
116 return SSE::sum_u32_4(sum_128i);
117 }
118 else
119 {
120 // we may handle at most one block with 8 elements
121
122 constexpr unsigned int blocks8 = (tSize % 16u) / 8u;
123 static_assert(blocks8 <= 1u, "Invalid number of blocks!");
124
125 if constexpr (blocks8 == 1u)
126 {
127 const __m128i buffer0_128i = _mm_loadl_epi64((const __m128i*)buffer0); // load for unaligned 64 bit memory
128 const __m128i buffer1_128i = _mm_loadl_epi64((const __m128i*)buffer1);
129
130 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
131
132 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
133
134 buffer0 += 8;
135 buffer1 += 8;
136 }
137
138 const __m128i sum_128i = _mm_add_epi32(sumLow_128i, sumHigh_128i);
139
140 constexpr unsigned int remainingElements = tSize - blocks16 * 16u - blocks8 * 8u;
141 static_assert(remainingElements < 8u, "Invalid number of remaining elements!");
142
143 uint32_t result = SSE::sum_u32_4(sum_128i);
144
145 // we apply the remaining elements (at most 7)
146
147 for (unsigned int n = 0u; n < remainingElements; ++n)
148 {
149 result += sqrDistance(buffer0[n], buffer1[n]);
150 }
151
152 return result;
153 }
154}
155
156template <unsigned int tChannels, unsigned int tPatchSize>
157inline uint32_t SumSquareDifferencesSSE::patch8BitPerChannel(const uint8_t* patch0, const uint8_t* patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements)
158{
159 static_assert(tChannels >= 1u, "Invalid channel number!");
160 static_assert(tPatchSize >= 1u, "Invalid buffer size!");
161
162 ocean_assert(patch0 != nullptr && patch1 != nullptr);
163
164 ocean_assert(patch0StrideElements >= tChannels * tPatchSize);
165 ocean_assert(patch1StrideElements >= tChannels * tPatchSize);
166
167 constexpr unsigned int patchWidthElements = tChannels * tPatchSize;
168
169 constexpr unsigned int blocks16 = patchWidthElements / 16u;
170 constexpr unsigned int remainingAfterBlocks16 = patchWidthElements % 16u;
171
172 constexpr bool partialBlock16 = remainingAfterBlocks16 > 8u;
173
174 constexpr bool fullBlock8 = !partialBlock16 && remainingAfterBlocks16 == 8u;
175
176 constexpr bool partialBlock8 = !partialBlock16 && !fullBlock8 && remainingAfterBlocks16 >= 3u;
177
178 constexpr unsigned int blocks1 = (!partialBlock16 && !fullBlock8 && !partialBlock8) ? remainingAfterBlocks16 : 0u;
179
180 static_assert(blocks1 <= 2u, "Invalid block size!");
181
182 static_assert(std::is_same<short, int16_t>::value, "Invalid data type!");
183
184 const __m128i constant_signs_m128i = _mm_set1_epi16(short(0x1FF)); // -1, 1, -1, 1, -1, 1, -1, 1
185
186 __m128i sumLow_128i = _mm_setzero_si128();
187 __m128i sumHigh_128i = _mm_setzero_si128();
188
189 uint32_t sumIndividual = 0u;
190
191 for (unsigned int y = 0u; y < tPatchSize; ++y)
192 {
193 SSE::prefetchT0(patch0 + patch0StrideElements);
194 SSE::prefetchT0(patch1 + patch1StrideElements);
195
196 for (unsigned int n = 0u; n < blocks16; ++n)
197 {
198 const __m128i buffer0_128i = _mm_lddqu_si128((const __m128i*)patch0);
199 const __m128i buffer1_128i = _mm_lddqu_si128((const __m128i*)patch1);
200
201 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
202 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
203
204 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
205 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
206
207 patch0 += 16;
208 patch1 += 16;
209 }
210
211 if constexpr (fullBlock8)
212 {
213 const __m128i buffer0_128i = _mm_loadl_epi64((const __m128i*)patch0); // load for unaligned 64 bit memory
214 const __m128i buffer1_128i = _mm_loadl_epi64((const __m128i*)patch1);
215
216 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
217
218 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
219
220 patch0 += 8;
221 patch1 += 8;
222 }
223
224 if constexpr (partialBlock16)
225 {
226 constexpr unsigned int overlapElements = partialBlock16 ? 16u - remainingAfterBlocks16 : 0u;
227
228 static_assert(overlapElements < 8u, "Invalid value!");
229
230 if (y < tPatchSize - 1u)
231 {
232 const __m128i buffer0_128i = _mm_slli_si128(_mm_lddqu_si128((const __m128i*)patch0), overlapElements); // loading 16 elements, but shifting `overlapElements` zeros to the left
233 const __m128i buffer1_128i = _mm_slli_si128(_mm_lddqu_si128((const __m128i*)patch1), overlapElements);
234
235 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
236 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
237
238 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
239 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
240 }
241 else
242 {
243 const __m128i buffer0_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(patch0 - overlapElements)), overlapElements); // loading 16 elements, but shifting `overlapElements` zeros to the right
244 const __m128i buffer1_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(patch1 - overlapElements)), overlapElements);
245
246 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
247 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
248
249 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
250 sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
251 }
252
253 patch0 += remainingAfterBlocks16;
254 patch1 += remainingAfterBlocks16;
255 }
256
257 if constexpr (partialBlock8)
258 {
259 constexpr unsigned int overlapElements = partialBlock8 ? 8u - remainingAfterBlocks16 : 0u;
260
261 static_assert(overlapElements < 8u, "Invalid value!");
262
263 if (y < tPatchSize - 1u)
264 {
265 const __m128i buffer0_128i = _mm_slli_si128(_mm_loadl_epi64((const __m128i*)patch0), overlapElements + 8); // loading 8 elements, but shifting `overlapElements` zeros to the left
266 const __m128i buffer1_128i = _mm_slli_si128(_mm_loadl_epi64((const __m128i*)patch1), overlapElements + 8);
267
268 const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
269
270 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
271 }
272 else
273 {
274 const __m128i buffer0_128i = _mm_srli_si128(_mm_loadl_epi64((const __m128i*)(patch0 - overlapElements)), overlapElements); // loading 8 elements, but shifting `overlapElements` zeros to the right
275 const __m128i buffer1_128i = _mm_srli_si128(_mm_loadl_epi64((const __m128i*)(patch1 - overlapElements)), overlapElements);
276
277 const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
278
279 sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
280 }
281
282 patch0 += remainingAfterBlocks16;
283 patch1 += remainingAfterBlocks16;
284 }
285
286 if constexpr (blocks1 != 0u)
287 {
288 for (unsigned int n = 0u; n < blocks1; ++n)
289 {
290 sumIndividual += sqrDistance(patch0[n], patch1[n]);
291 }
292
293 patch0 += blocks1;
294 patch1 += blocks1;
295 }
296
297 patch0 += patch0StrideElements - patchWidthElements;
298 patch1 += patch1StrideElements - patchWidthElements;
299 }
300
301 const __m128i sum_128i = _mm_add_epi32(sumLow_128i, sumHigh_128i);
302
303 return SSE::sum_u32_4(sum_128i) + sumIndividual;
304}
305
306template <unsigned int tChannels, unsigned int tPatchSize>
307inline uint32_t SumSquareDifferencesSSE::patchBuffer8BitPerChannel(const uint8_t* patch0, const uint8_t* buffer1, const unsigned int patch0StrideElements)
308{
309 return patch8BitPerChannel<tChannels, tPatchSize>(patch0, buffer1, patch0StrideElements, tChannels * tPatchSize);
310}
311
312}
313
314}
315
316#endif // OCEAN_HARDWARE_SSE_VERSION >= 41
317
318#endif // META_OCEAN_CV_SUM_SQUARE_DIFFERENCES_SSE_H
static void prefetchT0(const void *const data)
Prefetches a block of temporal memory into all cache levels.
Definition SSE.h:1255
static OCEAN_FORCE_INLINE unsigned int sum_u32_4(const __m128i &value)
Adds the four (all four) individual 32 bit unsigned integer values of a m128i value and returns the r...
Definition SSE.h:1322
This class implements functions to calculate sum square differences using SSE instructions.
Definition SumSquareDifferencesSSE.h:30
static uint32_t patchBuffer8BitPerChannel(const uint8_t *patch0, const uint8_t *buffer1, const unsigned int patch0StrideElements)
Returns the sum of square differences between an image patch and a buffer.
Definition SumSquareDifferencesSSE.h:307
static uint32_t buffer8BitPerChannel(const uint8_t *buffer0, const uint8_t *buffer1)
Returns the sum of square differences between two memory buffers.
Definition SumSquareDifferencesSSE.h:70
static uint32_t patch8BitPerChannel(const uint8_t *patch0, const uint8_t *patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements)
Returns the sum of square differences between two patches within an image.
Definition SumSquareDifferencesSSE.h:157
unsigned int sqrDistance(const char first, const char second)
Returns the square distance between two values.
Definition base/Utilities.h:1089
The namespace covering the entire Ocean framework.
Definition Accessor.h:15