Ocean
SumSquareDifferencesSSE.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #ifndef META_OCEAN_CV_SUM_SQUARE_DIFFERENCES_SSE_H
9 #define META_OCEAN_CV_SUM_SQUARE_DIFFERENCES_SSE_H
10 
11 #include "ocean/cv/CV.h"
12 
13 #include "ocean/base/Utilities.h"
14 
15 #if defined(OCEAN_HARDWARE_SSE_VERSION) && OCEAN_HARDWARE_SSE_VERSION >= 41
16 
17 #include "ocean/cv/SSE.h"
18 
19 namespace Ocean
20 {
21 
22 namespace CV
23 {
24 
/**
 * This class implements functions to calculate sum square differences using SSE instructions.
 * @ingroup cv
 */
class SumSquareDifferencesSSE
{
	public:

		/**
		 * Returns the sum of square differences between two memory buffers.
		 * @param buffer0 The first memory buffer, must be valid
		 * @param buffer1 The second memory buffer, must be valid
		 * @return The resulting sum of square differences
		 * @tparam tSize The size of the buffers in elements, with range [1, infinity)
		 */
		template <unsigned int tSize>
		static inline uint32_t buffer8BitPerChannel(const uint8_t* buffer0, const uint8_t* buffer1);

		/**
		 * Returns the sum of square differences between two patches within an image.
		 * @param patch0 The top left start position of the first image patch, must be valid
		 * @param patch1 The top left start position of the second image patch, must be valid
		 * @param patch0StrideElements The number of elements between two rows for the first patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @param patch1StrideElements The number of elements between two rows for the second patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @return The resulting sum of square differences
		 * @tparam tChannels The number of channels for the given frames, with range [1, infinity)
		 * @tparam tPatchSize The size of the square patch (the edge length) in pixel, with range [1, infinity), must be odd
		 */
		template <unsigned int tChannels, unsigned int tPatchSize>
		static inline uint32_t patch8BitPerChannel(const uint8_t* patch0, const uint8_t* patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements);

		/**
		 * Returns the sum of square differences between an image patch and a buffer.
		 * @param patch0 The top left start position of the image patch, must be valid
		 * @param buffer1 The memory buffer holding tPatchSize rows of tChannels * tPatchSize elements each, stored without padding between rows, must be valid
		 * @param patch0StrideElements The number of elements between two rows for the image patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @return The resulting sum of square differences
		 * @tparam tChannels The number of channels for the given frames, with range [1, infinity)
		 * @tparam tPatchSize The size of the square patch (the edge length) in pixel, with range [1, infinity), must be odd
		 */
		template <unsigned int tChannels, unsigned int tPatchSize>
		static inline uint32_t patchBuffer8BitPerChannel(const uint8_t* patch0, const uint8_t* buffer1, const unsigned int patch0StrideElements);
};
68 
69 template <unsigned int tSize>
70 inline uint32_t SumSquareDifferencesSSE::buffer8BitPerChannel(const uint8_t* buffer0, const uint8_t* buffer1)
71 {
72  static_assert(tSize >= 1u, "Invalid buffer size!");
73 
74  static_assert(std::is_same<short, int16_t>::value, "Invalid data type!");
75 
76  const __m128i constant_signs_m128i = _mm_set1_epi16(short(0x1FF)); // -1, 1, -1, 1, -1, 1, -1, 1
77 
78  __m128i sumLow_128i = _mm_setzero_si128();
79  __m128i sumHigh_128i = _mm_setzero_si128();
80 
81  // first, we handle blocks with 16 elements
82 
83  constexpr unsigned int blocks16 = tSize / 16u;
84 
85  for (unsigned int n = 0u; n < blocks16; ++n)
86  {
87  const __m128i buffer0_128i = _mm_lddqu_si128((const __m128i*)buffer0);
88  const __m128i buffer1_128i = _mm_lddqu_si128((const __m128i*)buffer1);
89 
90  const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
91  const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
92 
93  sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
94  sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
95 
96  buffer0 += 16;
97  buffer1 += 16;
98  }
99 
100  if constexpr (blocks16 >= 1u && (tSize % 16u) >= 10u)
101  {
102  constexpr unsigned int remainingElements = tSize % 16u;
103  constexpr unsigned int overlappingElements = 16u - remainingElements;
104 
105  const __m128i buffer0_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(buffer0 - overlappingElements)), overlappingElements);
106  const __m128i buffer1_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(buffer1 - overlappingElements)), overlappingElements);
107 
108  const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
109  const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
110 
111  sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
112  sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
113 
114  const __m128i sum_128i = _mm_add_epi32(sumLow_128i, sumHigh_128i);
115 
116  return SSE::sum_u32_4(sum_128i);
117  }
118  else
119  {
120  // we may handle at most one block with 8 elements
121 
122  constexpr unsigned int blocks8 = (tSize % 16u) / 8u;
123  static_assert(blocks8 <= 1u, "Invalid number of blocks!");
124 
125  if constexpr (blocks8 == 1u)
126  {
127  const __m128i buffer0_128i = _mm_loadl_epi64((const __m128i*)buffer0); // load for unaligned 64 bit memory
128  const __m128i buffer1_128i = _mm_loadl_epi64((const __m128i*)buffer1);
129 
130  const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
131 
132  sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
133 
134  buffer0 += 8;
135  buffer1 += 8;
136  }
137 
138  const __m128i sum_128i = _mm_add_epi32(sumLow_128i, sumHigh_128i);
139 
140  constexpr unsigned int remainingElements = tSize - blocks16 * 16u - blocks8 * 8u;
141  static_assert(remainingElements < 8u, "Invalid number of remaining elements!");
142 
143  uint32_t result = SSE::sum_u32_4(sum_128i);
144 
145  // we apply the remaining elements (at most 7)
146 
147  for (unsigned int n = 0u; n < remainingElements; ++n)
148  {
149  result += sqrDistance(buffer0[n], buffer1[n]);
150  }
151 
152  return result;
153  }
154 }
155 
156 template <unsigned int tChannels, unsigned int tPatchSize>
157 inline uint32_t SumSquareDifferencesSSE::patch8BitPerChannel(const uint8_t* patch0, const uint8_t* patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements)
158 {
159  static_assert(tChannels >= 1u, "Invalid channel number!");
160  static_assert(tPatchSize >= 1u, "Invalid buffer size!");
161 
162  ocean_assert(patch0 != nullptr && patch1 != nullptr);
163 
164  ocean_assert(patch0StrideElements >= tChannels * tPatchSize);
165  ocean_assert(patch1StrideElements >= tChannels * tPatchSize);
166 
167  constexpr unsigned int patchWidthElements = tChannels * tPatchSize;
168 
169  constexpr unsigned int blocks16 = patchWidthElements / 16u;
170  constexpr unsigned int remainingAfterBlocks16 = patchWidthElements % 16u;
171 
172  constexpr bool partialBlock16 = remainingAfterBlocks16 > 8u;
173 
174  constexpr bool fullBlock8 = !partialBlock16 && remainingAfterBlocks16 == 8u;
175 
176  constexpr bool partialBlock8 = !partialBlock16 && !fullBlock8 && remainingAfterBlocks16 >= 3u;
177 
178  constexpr unsigned int blocks1 = (!partialBlock16 && !fullBlock8 && !partialBlock8) ? remainingAfterBlocks16 : 0u;
179 
180  static_assert(blocks1 <= 2u, "Invalid block size!");
181 
182  static_assert(std::is_same<short, int16_t>::value, "Invalid data type!");
183 
184  const __m128i constant_signs_m128i = _mm_set1_epi16(short(0x1FF)); // -1, 1, -1, 1, -1, 1, -1, 1
185 
186  __m128i sumLow_128i = _mm_setzero_si128();
187  __m128i sumHigh_128i = _mm_setzero_si128();
188 
189  uint32_t sumIndividual = 0u;
190 
191  for (unsigned int y = 0u; y < tPatchSize; ++y)
192  {
193  SSE::prefetchT0(patch0 + patch0StrideElements);
194  SSE::prefetchT0(patch1 + patch1StrideElements);
195 
196  for (unsigned int n = 0u; n < blocks16; ++n)
197  {
198  const __m128i buffer0_128i = _mm_lddqu_si128((const __m128i*)patch0);
199  const __m128i buffer1_128i = _mm_lddqu_si128((const __m128i*)patch1);
200 
201  const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
202  const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
203 
204  sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
205  sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
206 
207  patch0 += 16;
208  patch1 += 16;
209  }
210 
211  if constexpr (fullBlock8)
212  {
213  const __m128i buffer0_128i = _mm_loadl_epi64((const __m128i*)patch0); // load for unaligned 64 bit memory
214  const __m128i buffer1_128i = _mm_loadl_epi64((const __m128i*)patch1);
215 
216  const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
217 
218  sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
219 
220  patch0 += 8;
221  patch1 += 8;
222  }
223 
224  if constexpr (partialBlock16)
225  {
226  constexpr unsigned int overlapElements = partialBlock16 ? 16u - remainingAfterBlocks16 : 0u;
227 
228  static_assert(overlapElements < 8u, "Invalid value!");
229 
230  if (y < tPatchSize - 1u)
231  {
232  const __m128i buffer0_128i = _mm_slli_si128(_mm_lddqu_si128((const __m128i*)patch0), overlapElements); // loading 16 elements, but shifting `overlapElements` zeros to the left
233  const __m128i buffer1_128i = _mm_slli_si128(_mm_lddqu_si128((const __m128i*)patch1), overlapElements);
234 
235  const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
236  const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
237 
238  sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
239  sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
240  }
241  else
242  {
243  const __m128i buffer0_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(patch0 - overlapElements)), overlapElements); // loading 16 elements, but shifting `overlapElements` zeros to the right
244  const __m128i buffer1_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(patch1 - overlapElements)), overlapElements);
245 
246  const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
247  const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
248 
249  sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
250  sumHigh_128i = _mm_add_epi32(sumHigh_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
251  }
252 
253  patch0 += remainingAfterBlocks16;
254  patch1 += remainingAfterBlocks16;
255  }
256 
257  if constexpr (partialBlock8)
258  {
259  constexpr unsigned int overlapElements = partialBlock8 ? 8u - remainingAfterBlocks16 : 0u;
260 
261  static_assert(overlapElements < 8u, "Invalid value!");
262 
263  if (y < tPatchSize - 1u)
264  {
265  const __m128i buffer0_128i = _mm_slli_si128(_mm_loadl_epi64((const __m128i*)patch0), overlapElements + 8); // loading 8 elements, but shifting `overlapElements` zeros to the left
266  const __m128i buffer1_128i = _mm_slli_si128(_mm_loadl_epi64((const __m128i*)patch1), overlapElements + 8);
267 
268  const __m128i absDifferencesHigh_128i = _mm_maddubs_epi16(_mm_unpackhi_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
269 
270  sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesHigh_128i, absDifferencesHigh_128i));
271  }
272  else
273  {
274  const __m128i buffer0_128i = _mm_srli_si128(_mm_loadl_epi64((const __m128i*)(patch0 - overlapElements)), overlapElements); // loading 8 elements, but shifting `overlapElements` zeros to the right
275  const __m128i buffer1_128i = _mm_srli_si128(_mm_loadl_epi64((const __m128i*)(patch1 - overlapElements)), overlapElements);
276 
277  const __m128i absDifferencesLow_128i = _mm_maddubs_epi16(_mm_unpacklo_epi8(buffer0_128i, buffer1_128i), constant_signs_m128i);
278 
279  sumLow_128i = _mm_add_epi32(sumLow_128i, _mm_madd_epi16(absDifferencesLow_128i, absDifferencesLow_128i));
280  }
281 
282  patch0 += remainingAfterBlocks16;
283  patch1 += remainingAfterBlocks16;
284  }
285 
286  if constexpr (blocks1 != 0u)
287  {
288  for (unsigned int n = 0u; n < blocks1; ++n)
289  {
290  sumIndividual += sqrDistance(patch0[n], patch1[n]);
291  }
292 
293  patch0 += blocks1;
294  patch1 += blocks1;
295  }
296 
297  patch0 += patch0StrideElements - patchWidthElements;
298  patch1 += patch1StrideElements - patchWidthElements;
299  }
300 
301  const __m128i sum_128i = _mm_add_epi32(sumLow_128i, sumHigh_128i);
302 
303  return SSE::sum_u32_4(sum_128i) + sumIndividual;
304 }
305 
306 template <unsigned int tChannels, unsigned int tPatchSize>
307 inline uint32_t SumSquareDifferencesSSE::patchBuffer8BitPerChannel(const uint8_t* patch0, const uint8_t* buffer1, const unsigned int patch0StrideElements)
308 {
309  return patch8BitPerChannel<tChannels, tPatchSize>(patch0, buffer1, patch0StrideElements, tChannels * tPatchSize);
310 }
311 
312 }
313 
314 }
315 
316 #endif // OCEAN_HARDWARE_SSE_VERSION >= 41
317 
318 #endif // META_OCEAN_CV_SUM_SQUARE_DIFFERENCES_SSE_H
static void prefetchT0(const void *const data)
Prefetches a block of temporal memory into all cache levels.
Definition: SSE.h:1255
static OCEAN_FORCE_INLINE unsigned int sum_u32_4(const __m128i &value)
Adds all four individual 32 bit unsigned integer values of an m128i value and returns the result.
Definition: SSE.h:1322
This class implements functions to calculate sum square differences using SSE instructions.
Definition: SumSquareDifferencesSSE.h:30
static uint32_t patchBuffer8BitPerChannel(const uint8_t *patch0, const uint8_t *buffer1, const unsigned int patch0StrideElements)
Returns the sum of square differences between an image patch and a buffer.
Definition: SumSquareDifferencesSSE.h:307
static uint32_t buffer8BitPerChannel(const uint8_t *buffer0, const uint8_t *buffer1)
Returns the sum of square differences between two memory buffers.
Definition: SumSquareDifferencesSSE.h:70
static uint32_t patch8BitPerChannel(const uint8_t *patch0, const uint8_t *patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements)
Returns the sum of square differences between two patches within an image.
Definition: SumSquareDifferencesSSE.h:157
unsigned int sqrDistance(const char first, const char second)
Returns the square distance between two values.
Definition: base/Utilities.h:1089
The namespace covering the entire Ocean framework.
Definition: Accessor.h:15