Ocean
SumAbsoluteDifferencesSSE.h
Go to the documentation of this file.
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #ifndef META_OCEAN_CV_SUM_ABSOLUTE_DIFFERENCES_SSE_H
9 #define META_OCEAN_CV_SUM_ABSOLUTE_DIFFERENCES_SSE_H
10 
11 #include "ocean/cv/CV.h"
12 
13 #if defined(OCEAN_HARDWARE_SSE_VERSION) && OCEAN_HARDWARE_SSE_VERSION >= 41
14 
15 #include "ocean/cv/SSE.h"
16 
17 namespace Ocean
18 {
19 
20 namespace CV
21 {
22 
23 /**
24  * This class implements functions calculating the sum of absolute differences using SSE instructions.
25  * @ingroup cv
26  */
28 {
29  public:
30 
31  /**
32  * Returns the sum of absolute differences between two memory buffers.
33  * @param buffer0 The first memory buffer, must be valid
34  * @param buffer1 The second memory buffer, must be valid
35  * @return The resulting sum of absolute differences
36  * @tparam tSize The size of the buffers in elements, with range [1, infinity)
37  */
38  template <unsigned int tSize>
39  static inline uint32_t buffer8BitPerChannel(const uint8_t* buffer0, const uint8_t* buffer1);
40 
41  /**
42  * Returns the sum of absolute differences between two patches within an image.
43  * @param patch0 The top left start position of the first image patch, must be valid
44  * @param patch1 The top left start position of the second image patch, must be valid
45  * @param patch0StrideElements The number of elements between two rows for the first patch, in elements, with range [tChannels * tPatchSize, infinity)
46  * @param patch1StrideElements The number of elements between two rows for the second patch, in elements, with range [tChannels * tPatchSize, infinity)
47  * @return The resulting sum of absolute differences
48  * @tparam tChannels The number of channels for the given frames, with range [1, infinity)
49  * @tparam tPatchSize The size of the square patch (the edge length) in pixel, with range [5, infinity), must be odd
50  */
51  template <unsigned int tChannels, unsigned int tPatchSize>
52  static inline uint32_t patch8BitPerChannel(const uint8_t* patch0, const uint8_t* patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements);
53 
54  /**
55  * Returns the sum of absolute differences between an image patch and a buffer.
56  * @param patch0 The top left start position of the image patch, must be valid
57  * @param buffer1 The memory buffer holding the second patch with rows stored back to back (tChannels * tPatchSize elements per row), must be valid
58  * @param patch0StrideElements The number of elements between two rows for the image patch, in elements, with range [tChannels * tPatchSize, infinity)
59  * @return The resulting sum of absolute differences
60  * @tparam tChannels The number of channels for the given frames, with range [1, infinity)
61  * @tparam tPatchSize The size of the square patch (the edge length) in pixel, with range [5, infinity), must be odd
62  */
63  template <unsigned int tChannels, unsigned int tPatchSize>
64  static inline uint32_t patchBuffer8BitPerChannel(const uint8_t* patch0, const uint8_t* buffer1, const unsigned int patch0StrideElements);
65 };
66 
67 template <unsigned int tSize>
68 inline uint32_t SumAbsoluteDifferencesSSE::buffer8BitPerChannel(const uint8_t* buffer0, const uint8_t* buffer1)
69 { // sums |buffer0[i] - buffer1[i]| over tSize elements: 16-element SSE blocks first, then either one overlapping 16-element load or an 8-element block plus a scalar tail
70  static_assert(tSize >= 1u, "Invalid buffer size!");
71 
72  __m128i sum_128i = _mm_setzero_si128(); // accumulates the partial sums which _mm_sad_epu8 provides per 64 bit half
73 
74  // first, we handle all full blocks with 16 elements
75 
76  constexpr unsigned int blocks16 = tSize / 16u;
77 
78  for (unsigned int n = 0u; n < blocks16; ++n)
79  {
80  const __m128i buffer0_128i = _mm_lddqu_si128((const __m128i*)buffer0);
81  const __m128i buffer1_128i = _mm_lddqu_si128((const __m128i*)buffer1);
82 
83  sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
84 
85  buffer0 += 16;
86  buffer1 += 16;
87  }
88 
89  if constexpr (blocks16 >= 1u && (tSize % 16u) >= 10u) // a tail of 10..15 elements with at least 16 elements in front of it is covered by one overlapping 16-element load
90  {
91  constexpr unsigned int remainingElements = tSize % 16u;
92  constexpr unsigned int overlappingElements = 16u - remainingElements; // elements of the load which the loop above has summed already
93 
94  const __m128i buffer0_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(buffer0 - overlappingElements)), overlappingElements); // the load is moved back so that it ends with the last element, the byte shift removes the already handled elements and fills with zeros
95  const __m128i buffer1_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(buffer1 - overlappingElements)), overlappingElements); // same zero fill as for buffer0, so the filled positions add |0 - 0| = 0
96 
97  sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
98 
99  return SSE::sum_u32_first_third(sum_128i); // the two partial sums are located in the first and third 32 bit element
100  }
101  else
102  {
103  // we may handle at most one block with 8 elements
104 
105  constexpr unsigned int blocks8 = (tSize % 16u) / 8u;
106  static_assert(blocks8 <= 1u, "Invalid number of blocks!");
107 
108  if constexpr (blocks8 == 1u)
109  {
110  const __m128i buffer0_128i = _mm_loadl_epi64((const __m128i*)buffer0); // 64 bit load without alignment requirement, the upper 64 bits of the register are zero
111  const __m128i buffer1_128i = _mm_loadl_epi64((const __m128i*)buffer1);
112 
113  sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
114 
115  buffer0 += 8;
116  buffer1 += 8;
117  }
118 
119  constexpr unsigned int remainingElements = tSize - blocks16 * 16u - blocks8 * 8u;
120  static_assert(remainingElements < 8u, "Invalid number of remaining elements!");
121 
122  uint32_t result = SSE::sum_u32_first_third(sum_128i); // reduces the SIMD accumulator before the scalar tail is added
123 
124  // we apply the remaining elements (at most 7) one by one
125 
126  for (unsigned int n = 0u; n < remainingElements; ++n)
127  {
128  result += uint32_t(abs(buffer0[n] - buffer1[n])); // the uint8_t operands are promoted to int, so the difference can be negative before abs()
129  }
130 
131  return result;
132  }
133 }
134 
135 template <unsigned int tChannels, unsigned int tPatchSize>
136 inline uint32_t SumAbsoluteDifferencesSSE::patch8BitPerChannel(const uint8_t* patch0, const uint8_t* patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements)
137 { // sums the absolute differences row by row; each row holds tChannels * tPatchSize elements and is split into 16-element blocks plus one kind of tail selected at compile time
138  static_assert(tChannels >= 1u, "Invalid channel number!");
139  static_assert(tPatchSize >= 5u, "Invalid patch size!"); // patches with an edge length below 5 pixels are rejected at compile time
140 
141  ocean_assert(patch0 != nullptr && patch1 != nullptr);
142 
143  ocean_assert(patch0StrideElements >= tChannels * tPatchSize);
144  ocean_assert(patch1StrideElements >= tChannels * tPatchSize);
145 
146  constexpr unsigned int patchWidthElements = tChannels * tPatchSize;
147 
148  constexpr unsigned int blocks16 = patchWidthElements / 16u;
149  constexpr unsigned int remainingAfterBlocks16 = patchWidthElements % 16u;
150 
151  constexpr bool partialBlock16 = remainingAfterBlocks16 > 8u; // tail of 9..15 elements: handled with one 16-element load of which the surplus elements are shifted out
152 
153  constexpr bool fullBlock8 = !partialBlock16 && remainingAfterBlocks16 == 8u; // tail of exactly 8 elements: handled with one 64 bit load
154 
155  constexpr bool partialBlock8 = !partialBlock16 && !fullBlock8 && remainingAfterBlocks16 >= 3u; // tail of 3..7 elements: handled with one 64 bit load of which the surplus elements are shifted out
156 
157  constexpr unsigned int blocks1 = (!partialBlock16 && !fullBlock8 && !partialBlock8) ? remainingAfterBlocks16 : 0u; // tail of 0..2 elements: handled with scalar code
158 
159  static_assert(blocks1 <= 2u, "Invalid block size!");
160 
161  __m128i sum_128i = _mm_setzero_si128(); // accumulates the partial sums which _mm_sad_epu8 provides per 64 bit half
162 
163  uint32_t sumIndividual = 0u; // accumulates the scalar tail elements (blocks1)
164 
165  for (unsigned int y = 0u; y < tPatchSize; ++y)
166  {
167  SSE::prefetchT0(patch0 + patch0StrideElements); // patch0 points to the start of the current row here, so this addresses the start of the following row
168  SSE::prefetchT0(patch1 + patch1StrideElements);
169 
170  for (unsigned int n = 0u; n < blocks16; ++n)
171  {
172  const __m128i buffer0_128i = _mm_lddqu_si128((const __m128i*)patch0);
173  const __m128i buffer1_128i = _mm_lddqu_si128((const __m128i*)patch1);
174 
175  sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
176 
177  patch0 += 16;
178  patch1 += 16;
179  }
180 
181  if constexpr (fullBlock8)
182  {
183  const __m128i buffer0_128i = _mm_loadl_epi64((const __m128i*)patch0); // 64 bit load without alignment requirement, the upper 64 bits of the register are zero
184  const __m128i buffer1_128i = _mm_loadl_epi64((const __m128i*)patch1);
185 
186  sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
187 
188  patch0 += 8;
189  patch1 += 8;
190  }
191 
192  if constexpr (partialBlock16)
193  {
194  constexpr unsigned int overlapElements = partialBlock16 ? 16u - remainingAfterBlocks16 : 0u; // elements of the 16-element load which do not belong to the tail
195 
196  static_assert(overlapElements < 8u, "Invalid value!");
197 
198  if (y < tPatchSize - 1u) // not the last row: the load starts at the tail and extends past the end of the row
199  {
200  const __m128i buffer0_128i = _mm_slli_si128(_mm_lddqu_si128((const __m128i*)patch0), overlapElements); // loading 16 elements, the left shift drops the `overlapElements` elements behind the tail and fills with zeros
201  const __m128i buffer1_128i = _mm_slli_si128(_mm_lddqu_si128((const __m128i*)patch1), overlapElements);
202 
203  sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
204  }
205  else // last row: the load is moved back so that it ends with the last element of the row, presumably to avoid reading memory behind the patch
206  {
207  const __m128i buffer0_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(patch0 - overlapElements)), overlapElements); // loading 16 elements, the right shift drops the `overlapElements` elements in front of the tail and fills with zeros
208  const __m128i buffer1_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(patch1 - overlapElements)), overlapElements);
209 
210  sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
211  }
212 
213  patch0 += remainingAfterBlocks16;
214  patch1 += remainingAfterBlocks16;
215  }
216 
217  if constexpr (partialBlock8)
218  {
219  constexpr unsigned int overlapElements = partialBlock8 ? 8u - remainingAfterBlocks16 : 0u; // elements of the 8-element load which do not belong to the tail
220 
221  static_assert(overlapElements < 8u, "Invalid value!");
222 
223  if (y < tPatchSize - 1u) // not the last row: the load starts at the tail and extends past the end of the row
224  {
225  const __m128i buffer0_128i = _mm_slli_si128(_mm_loadl_epi64((const __m128i*)patch0), overlapElements + 8); // loading 8 elements into the lower half, the left shift by 8 + `overlapElements` keeps the tail elements only (in the upper half)
226  const __m128i buffer1_128i = _mm_slli_si128(_mm_loadl_epi64((const __m128i*)patch1), overlapElements + 8);
227 
228  sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
229  }
230  else // last row: the load is moved back so that it ends with the last element of the row, presumably to avoid reading memory behind the patch
231  {
232  const __m128i buffer0_128i = _mm_srli_si128(_mm_loadl_epi64((const __m128i*)(patch0 - overlapElements)), overlapElements); // loading 8 elements, the right shift drops the `overlapElements` elements in front of the tail and fills with zeros
233  const __m128i buffer1_128i = _mm_srli_si128(_mm_loadl_epi64((const __m128i*)(patch1 - overlapElements)), overlapElements);
234 
235  sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
236  }
237 
238  patch0 += remainingAfterBlocks16;
239  patch1 += remainingAfterBlocks16;
240  }
241 
242  if constexpr (blocks1 != 0u)
243  {
244  for (unsigned int n = 0u; n < blocks1; ++n)
245  {
246  sumIndividual += uint32_t(abs(patch0[n] - patch1[n])); // the uint8_t operands are promoted to int, so the difference can be negative before abs()
247  }
248 
249  patch0 += blocks1;
250  patch1 += blocks1;
251  }
252 
253  patch0 += patch0StrideElements - patchWidthElements; // patchWidthElements were consumed within this row, so this moves to the start of the next row
254  patch1 += patch1StrideElements - patchWidthElements;
255  }
256 
257  return SSE::sum_u32_first_third(sum_128i) + sumIndividual; // reduces the SIMD accumulator and adds the scalar part
258 }
259 
260 template <unsigned int tChannels, unsigned int tPatchSize>
261 inline uint32_t SumAbsoluteDifferencesSSE::patchBuffer8BitPerChannel(const uint8_t* patch0, const uint8_t* buffer1, const unsigned int patch0StrideElements)
262 { // the buffer is handled like a patch whose rows are stored back to back, so its row stride is tChannels * tPatchSize elements
263  return patch8BitPerChannel<tChannels, tPatchSize>(patch0, buffer1, patch0StrideElements, tChannels * tPatchSize);
264 }
265 
266 }
267 
268 }
269 
270 #endif // OCEAN_HARDWARE_SSE_VERSION >= 41
271 
272 #endif // META_OCEAN_CV_SUM_ABSOLUTE_DIFFERENCES_SSE_H
static unsigned int sum_u32_first_third(const __m128i &value)
Adds the first and the third 32 bit unsigned integer values of a m128i value and returns the result.
Definition: SSE.h:1340
static void prefetchT0(const void *const data)
Prefetches a block of temporal memory into all cache levels.
Definition: SSE.h:1255
This class implements functions calculating the sum of absolute differences.
Definition: SumAbsoluteDifferencesSSE.h:28
static uint32_t patch8BitPerChannel(const uint8_t *patch0, const uint8_t *patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements)
Returns the sum of absolute differences between two patches within an image.
Definition: SumAbsoluteDifferencesSSE.h:136
static uint32_t patchBuffer8BitPerChannel(const uint8_t *patch0, const uint8_t *buffer1, const unsigned int patch0StrideElements)
Returns the sum of absolute differences between an image patch and a buffer.
Definition: SumAbsoluteDifferencesSSE.h:261
static uint32_t buffer8BitPerChannel(const uint8_t *buffer0, const uint8_t *buffer1)
Returns the sum of absolute differences between two memory buffers.
Definition: SumAbsoluteDifferencesSSE.h:68
The namespace covering the entire Ocean framework.
Definition: Accessor.h:15