Ocean
Loading...
Searching...
No Matches
SumAbsoluteDifferencesSSE.h
Go to the documentation of this file.
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8#ifndef META_OCEAN_CV_SUM_ABSOLUTE_DIFFERENCES_SSE_H
9#define META_OCEAN_CV_SUM_ABSOLUTE_DIFFERENCES_SSE_H
10
11#include "ocean/cv/CV.h"
12
13#if defined(OCEAN_HARDWARE_SSE_VERSION) && OCEAN_HARDWARE_SSE_VERSION >= 41
14
15#include "ocean/cv/SSE.h"
16
17namespace Ocean
18{
19
20namespace CV
21{
22
23/**
24 * This class implements functions calculating the sum of absolute differences.
25 * @ingroup cv
26 */
class SumAbsoluteDifferencesSSE
{
	public:

		/**
		 * Returns the sum of absolute differences between two memory buffers.
		 * @param buffer0 The first memory buffer, must be valid
		 * @param buffer1 The second memory buffer, must be valid
		 * @return The resulting sum of absolute differences
		 * @tparam tSize The size of the buffers in elements, with range [1, infinity)
		 */
		template <unsigned int tSize>
		static inline uint32_t buffer8BitPerChannel(const uint8_t* buffer0, const uint8_t* buffer1);

		/**
		 * Returns the sum of absolute differences between two patches within an image.
		 * The implementation uses 16-byte and 8-byte loads which can reach a few elements beyond the end of a patch row
		 * (for all rows but the last one) or before the remaining elements of the last row; the surrounding memory must therefore be readable.
		 * @param patch0 The top left start position of the first image patch, must be valid
		 * @param patch1 The top left start position of the second image patch, must be valid
		 * @param patch0StrideElements The number of elements between two rows for the first patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @param patch1StrideElements The number of elements between two rows for the second patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @return The resulting sum of absolute differences
		 * @tparam tChannels The number of channels for the given frames, with range [1, infinity)
		 * @tparam tPatchSize The size of the square patch (the edge length) in pixel, with range [5, infinity), must be odd
		 */
		template <unsigned int tChannels, unsigned int tPatchSize>
		static inline uint32_t patch8BitPerChannel(const uint8_t* patch0, const uint8_t* patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements);

		/**
		 * Returns the sum of absolute differences between an image patch and a buffer.
		 * The buffer is interpreted as a patch without padding, its rows are `tChannels * tPatchSize` elements apart.
		 * @param patch0 The top left start position of the image patch, must be valid
		 * @param buffer1 The memory buffer, must be valid
		 * @param patch0StrideElements The number of elements between two rows for the image patch, in elements, with range [tChannels * tPatchSize, infinity)
		 * @return The resulting sum of absolute differences
		 * @tparam tChannels The number of channels for the given frames, with range [1, infinity)
		 * @tparam tPatchSize The size of the square patch (the edge length) in pixel, with range [5, infinity), must be odd
		 */
		template <unsigned int tChannels, unsigned int tPatchSize>
		static inline uint32_t patchBuffer8BitPerChannel(const uint8_t* patch0, const uint8_t* buffer1, const unsigned int patch0StrideElements);
};
66
67template <unsigned int tSize>
68inline uint32_t SumAbsoluteDifferencesSSE::buffer8BitPerChannel(const uint8_t* buffer0, const uint8_t* buffer1)
69{
70 static_assert(tSize >= 1u, "Invalid buffer size!");
71
72 __m128i sum_128i = _mm_setzero_si128();
73
74 // first, we handle blocks with 16 elements
75
76 constexpr unsigned int blocks16 = tSize / 16u;
77
78 for (unsigned int n = 0u; n < blocks16; ++n)
79 {
80 const __m128i buffer0_128i = _mm_lddqu_si128((const __m128i*)buffer0);
81 const __m128i buffer1_128i = _mm_lddqu_si128((const __m128i*)buffer1);
82
83 sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
84
85 buffer0 += 16;
86 buffer1 += 16;
87 }
88
89 if constexpr (blocks16 >= 1u && (tSize % 16u) >= 10u)
90 {
91 constexpr unsigned int remainingElements = tSize % 16u;
92 constexpr unsigned int overlappingElements = 16u - remainingElements;
93
94 const __m128i buffer0_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(buffer0 - overlappingElements)), overlappingElements);
95 const __m128i buffer1_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(buffer1 - overlappingElements)), overlappingElements);
96
97 sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
98
99 return SSE::sum_u32_first_third(sum_128i);
100 }
101 else
102 {
103 // we may handle at most one block with 8 elements
104
105 constexpr unsigned int blocks8 = (tSize % 16u) / 8u;
106 static_assert(blocks8 <= 1u, "Invalid number of blocks!");
107
108 if constexpr (blocks8 == 1u)
109 {
110 const __m128i buffer0_128i = _mm_loadl_epi64((const __m128i*)buffer0); // load for unaligned 64 bit memory
111 const __m128i buffer1_128i = _mm_loadl_epi64((const __m128i*)buffer1);
112
113 sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
114
115 buffer0 += 8;
116 buffer1 += 8;
117 }
118
119 constexpr unsigned int remainingElements = tSize - blocks16 * 16u - blocks8 * 8u;
120 static_assert(remainingElements < 8u, "Invalid number of remaining elements!");
121
122 uint32_t result = SSE::sum_u32_first_third(sum_128i);
123
124 // we apply the remaining elements (at most 7)
125
126 for (unsigned int n = 0u; n < remainingElements; ++n)
127 {
128 result += uint32_t(abs(buffer0[n] - buffer1[n]));
129 }
130
131 return result;
132 }
133}
134
135template <unsigned int tChannels, unsigned int tPatchSize>
136inline uint32_t SumAbsoluteDifferencesSSE::patch8BitPerChannel(const uint8_t* patch0, const uint8_t* patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements)
137{
138 static_assert(tChannels >= 1u, "Invalid channel number!");
139 static_assert(tPatchSize >= 5u, "Invalid patch size!");
140
141 ocean_assert(patch0 != nullptr && patch1 != nullptr);
142
143 ocean_assert(patch0StrideElements >= tChannels * tPatchSize);
144 ocean_assert(patch1StrideElements >= tChannels * tPatchSize);
145
146 constexpr unsigned int patchWidthElements = tChannels * tPatchSize;
147
148 constexpr unsigned int blocks16 = patchWidthElements / 16u;
149 constexpr unsigned int remainingAfterBlocks16 = patchWidthElements % 16u;
150
151 constexpr bool partialBlock16 = remainingAfterBlocks16 > 8u;
152
153 constexpr bool fullBlock8 = !partialBlock16 && remainingAfterBlocks16 == 8u;
154
155 constexpr bool partialBlock8 = !partialBlock16 && !fullBlock8 && remainingAfterBlocks16 >= 3u;
156
157 constexpr unsigned int blocks1 = (!partialBlock16 && !fullBlock8 && !partialBlock8) ? remainingAfterBlocks16 : 0u;
158
159 static_assert(blocks1 <= 2u, "Invalid block size!");
160
161 __m128i sum_128i = _mm_setzero_si128();
162
163 uint32_t sumIndividual = 0u;
164
165 for (unsigned int y = 0u; y < tPatchSize; ++y)
166 {
167 SSE::prefetchT0(patch0 + patch0StrideElements);
168 SSE::prefetchT0(patch1 + patch1StrideElements);
169
170 for (unsigned int n = 0u; n < blocks16; ++n)
171 {
172 const __m128i buffer0_128i = _mm_lddqu_si128((const __m128i*)patch0);
173 const __m128i buffer1_128i = _mm_lddqu_si128((const __m128i*)patch1);
174
175 sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
176
177 patch0 += 16;
178 patch1 += 16;
179 }
180
181 if constexpr (fullBlock8)
182 {
183 const __m128i buffer0_128i = _mm_loadl_epi64((const __m128i*)patch0); // load for unaligned 64 bit memory
184 const __m128i buffer1_128i = _mm_loadl_epi64((const __m128i*)patch1);
185
186 sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
187
188 patch0 += 8;
189 patch1 += 8;
190 }
191
192 if constexpr (partialBlock16)
193 {
194 constexpr unsigned int overlapElements = partialBlock16 ? 16u - remainingAfterBlocks16 : 0u;
195
196 static_assert(overlapElements < 8u, "Invalid value!");
197
198 if (y < tPatchSize - 1u)
199 {
200 const __m128i buffer0_128i = _mm_slli_si128(_mm_lddqu_si128((const __m128i*)patch0), overlapElements); // loading 16 elements, but shifting `overlapElements` zeros to the left
201 const __m128i buffer1_128i = _mm_slli_si128(_mm_lddqu_si128((const __m128i*)patch1), overlapElements);
202
203 sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
204 }
205 else
206 {
207 const __m128i buffer0_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(patch0 - overlapElements)), overlapElements); // loading 16 elements, but shifting `overlapElements` zeros to the right
208 const __m128i buffer1_128i = _mm_srli_si128(_mm_lddqu_si128((const __m128i*)(patch1 - overlapElements)), overlapElements);
209
210 sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
211 }
212
213 patch0 += remainingAfterBlocks16;
214 patch1 += remainingAfterBlocks16;
215 }
216
217 if constexpr (partialBlock8)
218 {
219 constexpr unsigned int overlapElements = partialBlock8 ? 8u - remainingAfterBlocks16 : 0u;
220
221 static_assert(overlapElements < 8u, "Invalid value!");
222
223 if (y < tPatchSize - 1u)
224 {
225 const __m128i buffer0_128i = _mm_slli_si128(_mm_loadl_epi64((const __m128i*)patch0), overlapElements + 8); // loading 8 elements, but shifting `overlapElements` zeros to the left
226 const __m128i buffer1_128i = _mm_slli_si128(_mm_loadl_epi64((const __m128i*)patch1), overlapElements + 8);
227
228 sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
229 }
230 else
231 {
232 const __m128i buffer0_128i = _mm_srli_si128(_mm_loadl_epi64((const __m128i*)(patch0 - overlapElements)), overlapElements); // loading 8 elements, but shifting `overlapElements` zeros to the right
233 const __m128i buffer1_128i = _mm_srli_si128(_mm_loadl_epi64((const __m128i*)(patch1 - overlapElements)), overlapElements);
234
235 sum_128i = _mm_add_epi32(sum_128i, _mm_sad_epu8(buffer0_128i, buffer1_128i));
236 }
237
238 patch0 += remainingAfterBlocks16;
239 patch1 += remainingAfterBlocks16;
240 }
241
242 if constexpr (blocks1 != 0u)
243 {
244 for (unsigned int n = 0u; n < blocks1; ++n)
245 {
246 sumIndividual += uint32_t(abs(patch0[n] - patch1[n]));
247 }
248
249 patch0 += blocks1;
250 patch1 += blocks1;
251 }
252
253 patch0 += patch0StrideElements - patchWidthElements;
254 patch1 += patch1StrideElements - patchWidthElements;
255 }
256
257 return SSE::sum_u32_first_third(sum_128i) + sumIndividual;
258}
259
260template <unsigned int tChannels, unsigned int tPatchSize>
261inline uint32_t SumAbsoluteDifferencesSSE::patchBuffer8BitPerChannel(const uint8_t* patch0, const uint8_t* buffer1, const unsigned int patch0StrideElements)
262{
263 return patch8BitPerChannel<tChannels, tPatchSize>(patch0, buffer1, patch0StrideElements, tChannels * tPatchSize);
264}
265
266}
267
268}
269
270#endif // OCEAN_HARDWARE_SSE_VERSION >= 41
271
272#endif // META_OCEAN_CV_SUM_ABSOLUTE_DIFFERENCES_SSE_H
static unsigned int sum_u32_first_third(const __m128i &value)
Adds the first and the third 32 bit unsigned integer values of a m128i value and returns the result.
Definition SSE.h:1340
static void prefetchT0(const void *const data)
Prefetches a block of temporal memory into all cache levels.
Definition SSE.h:1255
This class implements functions calculating the sum of absolute differences.
Definition SumAbsoluteDifferencesSSE.h:28
static uint32_t patch8BitPerChannel(const uint8_t *patch0, const uint8_t *patch1, const unsigned int patch0StrideElements, const unsigned int patch1StrideElements)
Returns the sum of absolute differences between two patches within an image.
Definition SumAbsoluteDifferencesSSE.h:136
static uint32_t patchBuffer8BitPerChannel(const uint8_t *patch0, const uint8_t *buffer1, const unsigned int patch0StrideElements)
Returns the sum of absolute differences between an image patch and a buffer.
Definition SumAbsoluteDifferencesSSE.h:261
static uint32_t buffer8BitPerChannel(const uint8_t *buffer0, const uint8_t *buffer1)
Returns the sum of absolute differences between two memory buffers.
Definition SumAbsoluteDifferencesSSE.h:68
The namespace covering the entire Ocean framework.
Definition Accessor.h:15