nlib
SimdAlgorithm.h
[詳解]
1 
2 #pragma once
3 #ifndef INCLUDE_NN_NLIB_SIMD_SIMDALGORITHM_H_
4 #define INCLUDE_NN_NLIB_SIMD_SIMDALGORITHM_H_
5 
6 #include <algorithm>
7 #include <functional>
8 #include "nn/nlib/simd/SimdInt.h"
10 
11 #if defined(NLIB_SIMD)
12 
13 NLIB_NAMESPACE_BEGIN
14 namespace simd {
15 namespace detail {
16 
17 NLIB_VIS_PUBLIC void sortUint32A16(uint32_t* buf, uint32_t* src, size_t N) NLIB_NOEXCEPT;
18 NLIB_VIS_PUBLIC void sortUint32A16_1(uint32_t* buf, uint32_t* src, size_t N) NLIB_NOEXCEPT;
19 
20 template <size_t NumElem, bool flag>
21 struct MergeSortHelper {
22  static NLIB_ALWAYS_INLINE void Sort(uint32_t* data) NLIB_NOEXCEPT {
23  NLIB_ALIGNAS(16) uint32_t tmp[NumElem];
24  sortUint32A16(tmp, data, NumElem);
25  }
26 };
27 
28 template <size_t NumElem>
29 struct MergeSortHelper<NumElem, false> {
30  static NLIB_ALWAYS_INLINE void Sort(uint32_t* data) NLIB_NOEXCEPT {
31  NLIB_ALIGNAS(16) uint32_t tmp[NumElem];
32  sortUint32A16_1(tmp, data, NumElem);
33  }
34 };
35 
36 } // namespace detail
37 
38 template <size_t NumElem>
39 // data must be 16 bytes aligned, NumElem must be multiple of 16
40 NLIB_ALWAYS_INLINE void MergeSortUint32A16(uint32_t* data) NLIB_NOEXCEPT {
41  NLIB_STATIC_ASSERT((NumElem > 0) && (NumElem % 16 == 0));
42  NLIB_ASSERT(!(reinterpret_cast<uintptr_t>(data) & 15));
43  detail::MergeSortHelper<NumElem, ((NumElem & (NumElem - 1)) == 0)>::Sort(data);
44 }
45 
47 
48 template <class PRED>
49 // i128 PRED(i128 c) { mask(c.u8[i]) if matched }
50 const void* nlib_memchr_pred(const void* s, PRED pred, size_t n) NLIB_NOEXCEPT {
51  if (!s) return NULL;
52 
53  const unsigned char* p = reinterpret_cast<const unsigned char*>(s);
54  i128 a1, a2;
55  i128 cmp1, cmp2;
56  uint32_t mask;
57  if (reinterpret_cast<uintptr_t>(p) & 15) {
58  size_t r = reinterpret_cast<uintptr_t>(p) & 15;
59  a1 = I128::LoadA16(p - r);
60  cmp1 = pred(a1);
61  mask = I128::MoveMask8(cmp1);
62  mask = mask >> r;
63  size_t rr = 16 - r;
64  if (n < rr) {
65  mask &= (1 << n) - 1;
66  if (mask) return p + nlib_ctz(mask);
67  return NULL;
68  }
69  if (mask) return p + nlib_ctz(mask);
70  p += rr;
71  n -= rr;
72  }
73  if (n >= 16) {
74  if ((reinterpret_cast<uintptr_t>(p) & 32)) {
75  a1 = I128::LoadA16(p);
76  cmp1 = pred(a1);
77  if (!I128::IsZero(cmp1)) {
78  mask = I128::MoveMask8(cmp1);
79  return p + nlib_ctz(mask);
80  }
81  p += 16;
82  n -= 16;
83  }
84  while (n >= 32) {
85  a1 = I128::LoadA16(p);
86  a2 = I128::LoadA16(p + 16);
87  cmp1 = I128::SetZero();
88  cmp2 = I128::SetZero();
89  cmp1 = pred(a1);
90  cmp2 = pred(a2);
91  if (!I128::IsZero(I128::Or(cmp1, cmp2))) {
92  mask = I128::MoveMask8(cmp1) | (I128::MoveMask8(cmp2) << 16);
93  return p + nlib_ctz(mask);
94  }
95  p += 32;
96  n -= 32;
97  }
98  if (n >= 16) {
99  a1 = I128::LoadA16(p);
100  cmp1 = pred(a1);
101  if (!I128::IsZero(cmp1)) {
102  mask = I128::MoveMask8(cmp1);
103  return p + nlib_ctz(mask);
104  }
105  p += 16;
106  n -= 16;
107  }
108  }
109  if (n > 0) {
110  a1 = I128::LoadA16(p);
111  cmp1 = pred(a1);
112  mask = I128::MoveMask8(cmp1);
113  mask &= (1 << n) - 1;
114  if (mask) return p + nlib_ctz(mask);
115  }
116  return NULL;
117 }
118 
119 template <class PRED>
120 // i128 PRED(i128 c) { mask(c.u8[i]) if matched }
121 const void* nlib_memchr_pred_not(const void* s, PRED pred, size_t n) NLIB_NOEXCEPT {
122  if (!s) return NULL;
123 
124  const unsigned char* p = reinterpret_cast<const unsigned char*>(s);
125  i128 a1, a2;
126  i128 cmp1, cmp2;
127  uint32_t mask;
128  if (reinterpret_cast<uintptr_t>(p) & 15) {
129  size_t r = reinterpret_cast<uintptr_t>(p) & 15;
130  a1 = I128::LoadA16(p - r);
131  cmp1 = I128::Not(pred(a1));
132  mask = I128::MoveMask8(cmp1);
133  mask = mask >> r;
134  size_t rr = 16 - r;
135  if (n < rr) {
136  mask &= (1 << n) - 1;
137  if (mask) return p + nlib_ctz(mask);
138  return NULL;
139  }
140  if (mask) return p + nlib_ctz(mask);
141  p += rr;
142  n -= rr;
143  }
144  if (n >= 16) {
145  if ((reinterpret_cast<uintptr_t>(p) & 32)) {
146  a1 = I128::LoadA16(p);
147  cmp1 = pred(a1);
148  if (!I128::IsFull(cmp1)) {
149  mask = I128::MoveMask8(I128::Not(cmp1));
150  return p + nlib_ctz(mask);
151  }
152  p += 16;
153  n -= 16;
154  }
155  while (n >= 32) {
156  a1 = I128::LoadA16(p);
157  a2 = I128::LoadA16(p + 16);
158  cmp1 = pred(a1);
159  cmp2 = pred(a2);
160  if (!I128::IsFull(I128::And(cmp1, cmp2))) {
161  mask = I128::MoveMask8(I128::Not(cmp1)) | (I128::MoveMask8(I128::Not(cmp2)) << 16);
162  return p + nlib_ctz(mask);
163  }
164  p += 32;
165  n -= 32;
166  }
167  if (n >= 16) {
168  a1 = I128::LoadA16(p);
169  cmp1 = pred(a1);
170  if (!I128::IsFull(cmp1)) {
171  mask = I128::MoveMask8(I128::Not(cmp1));
172  return p + nlib_ctz(mask);
173  }
174  p += 16;
175  n -= 16;
176  }
177  }
178  if (n > 0) {
179  a1 = I128::LoadA16(p);
180  cmp1 = I128::Not(pred(a1));
181  mask = I128::MoveMask8(cmp1);
182  mask &= (1 << n) - 1;
183  if (mask) return p + nlib_ctz(mask);
184  }
185  return NULL;
186 }
187 
188 // mask if c.u8[i] is a-zA-Z
189 inline i128 __vectorcall IsAlpha(i128 c) NLIB_NOEXCEPT {
190  i128 result = I128::CmpLtInt8(c, I128::SetValue('{', each_int8));
191  result = I128::And(result, I128::CmpGtInt8(c, I128::SetValue('`', each_int8)));
192 
193  i128 tmp = I128::CmpLtInt8(c, I128::SetValue('[', each_int8));
194  tmp = I128::And(tmp, I128::CmpGtInt8(c, I128::SetValue('@', each_int8)));
195  result = I128::Or(result, tmp);
196 
197  return result;
198 }
199 
200 // mask if c.u8[i] is 0-9
201 inline i128 __vectorcall IsDigit(i128 c) NLIB_NOEXCEPT {
202  i128 result = I128::CmpLtInt8(c, I128::SetValue(':', each_int8));
203  result = I128::And(result, I128::CmpGtInt8(c, I128::SetValue('/', each_int8)));
204 
205  return result;
206 }
207 
208 // mask if c.u8[i] is a-zA-Z0-9
209 inline i128 __vectorcall IsAlnum(i128 c) NLIB_NOEXCEPT { return I128::Or(IsDigit(c), IsAlpha(c)); }
210 
211 // mask if c.u8[i] is space, CR, LF, or tab
212 inline i128 __vectorcall IsSpace(i128 c) NLIB_NOEXCEPT {
213  i128 result = I128::CmpEq8(c, I128::SetValue(' ', each_int8));
214  result = I128::Or(result, I128::CmpEq8(c, I128::SetValue('\r', each_int8)));
215  result = I128::Or(result, I128::CmpEq8(c, I128::SetValue('\n', each_int8)));
216  result = I128::Or(result, I128::CmpEq8(c, I128::SetValue('\t', each_int8)));
217 
218  return result;
219 }
220 
221 // mask if c.u8[i] is 0-9A-Fa-f
222 inline i128 __vectorcall IsXdigit(i128 c) NLIB_NOEXCEPT {
223  i128 tmp;
224  i128 result = I128::CmpLtInt8(c, I128::SetValue(':', each_int8));
225  result = I128::And(result, I128::CmpGtInt8(c, I128::SetValue('/', each_int8)));
226 
227  tmp = I128::CmpLtInt8(c, I128::SetValue('G', each_int8));
228  tmp = I128::And(tmp, I128::CmpGtInt8(c, I128::SetValue('@', each_int8)));
229  result = I128::Or(result, tmp);
230 
231  tmp = I128::CmpLtInt8(c, I128::SetValue('g', each_int8));
232  tmp = I128::And(tmp, I128::CmpGtInt8(c, I128::SetValue('`', each_int8)));
233  result = I128::Or(result, tmp);
234 
235  return result;
236 }
237 
238 namespace detail {
239 template<class T, class Compare>
240 class KeyIdxSortLess {
241  public:
242  KeyIdxSortLess(T* const* src, uint32_t mask,
243  Compare comp) NLIB_NOEXCEPT : src_(src), mask_(mask), comp_(comp) {
244  }
245  NLIB_ALWAYS_INLINE bool operator()(uint32_t lhs, uint32_t rhs) const NLIB_NOEXCEPT {
246  return comp_(*src_[lhs & mask_], *src_[rhs & mask_]);
247  }
248 
249  private:
250  T* const* src_;
251  uint32_t mask_;
252  Compare comp_;
253 };
254 
255 NLIB_ALIGNAS(16) extern NLIB_VIS_PUBLIC const uint32_t keyidxsort_0123[4];
256 } // namespace detail
257 
258 template<class T, class Compare>
259 errno_t KeyIdxSortN(T** dst, T* const* src, size_t n, Compare comp) NLIB_NOEXCEPT {
260  // T must have a member function such that
261  // uint32_t GetKey32() const;
262 #ifdef NLIB_64BIT
263  if (n == 0 || n > 0x7FFFFFFFU) return EINVAL;
264 #else
265  if (n == 0 || n > RSIZE_MAX) return EINVAL;
266 #endif
267  errno_t e;
268  size_t n16 = (n + 15) & ~15;
269 
271  e = mem_.Init(n16 * sizeof(uint32_t), 16); // NOLINT
272  if (NLIB_UNLIKELY(e != 0)) return e;
273  uint32_t* mem = reinterpret_cast<uint32_t*>(mem_.Get());
274  int idx_width = 32 - nlib_clz(static_cast<uint32_t>(n16));
275  uint32_t idx_mask = (1U << idx_width) - 1;
276  uint32_t idx_mask_inv = ~idx_mask;
277 
278  uint32_t max_key = 0U;
279  uint32_t min_key = 0xFFFFFFFFU;
280  for (size_t i = 0; i < n; ++i) {
281  uint32_t key = src[i]->GetKey32();
282  mem[i] = key;
283  if (min_key > key)
284  min_key = key;
285  else if (max_key < key)
286  max_key = key;
287  }
288  // several number of the bits could be omitted because they are the same.
289  int left_shift = nlib_clz(min_key ^ max_key);
290 
291  // embed indices
292 #if 1
293  i128 vecmask = I128::SetValue(idx_mask_inv, each_uint32);
294  i128 vecidx0 = I128::LoadA16(&detail::keyidxsort_0123[0]);
295  i128 vec_i = I128::SetZero();
296  i128 four = I128::SetValue(4, each_uint32);
297  i128 vecidx4 = I128::Add32(vecidx0, four);
298  i128 vecidx8 = I128::Add32(vecidx4, four);
299  i128 vecidx12 = I128::Add32(vecidx8, four);
300  i128 d16 = I128::Mult32(four, four);
301 
302  for (size_t i = 0; i < n16; i += 16) {
303  i128 m0 = I128::LoadA16(&mem[i]);
304  i128 m1 = I128::LoadA16(&mem[i + 4]);
305  i128 m2 = I128::LoadA16(&mem[i + 8]);
306  i128 m3 = I128::LoadA16(&mem[i + 12]);
307 
308  m0 = I128::And(I128::ShiftLeftLogical32(m0, left_shift), vecmask);
309  m1 = I128::And(I128::ShiftLeftLogical32(m1, left_shift), vecmask);
310  m2 = I128::And(I128::ShiftLeftLogical32(m2, left_shift), vecmask);
311  m3 = I128::And(I128::ShiftLeftLogical32(m3, left_shift), vecmask);
312  m0 = I128::Or(m0, I128::Add32(vec_i, vecidx0));
313  m1 = I128::Or(m1, I128::Add32(vec_i, vecidx4));
314  m2 = I128::Or(m2, I128::Add32(vec_i, vecidx8));
315  m3 = I128::Or(m3, I128::Add32(vec_i, vecidx12));
316 
317  I128::StoreA16(&mem[i], m0);
318  I128::StoreA16(&mem[i + 4], m1);
319  I128::StoreA16(&mem[i + 8], m2);
320  I128::StoreA16(&mem[i + 12], m3);
321  vec_i = I128::Add16(vec_i, d16);
322  }
323 #else
324  for (size_t i = 0; i < n; ++i) {
325  mem[i] = ((mem[i] << left_shift) & idx_mask_inv) | static_cast<uint32_t>(i);
326  }
327 #endif
328  for (size_t i = n; i < n16; ++i) {
329  // It is ok because MergeSortUint32A16 is a stable sort algorithm.
330  mem[i] = idx_mask_inv | static_cast<uint32_t>(i);
331  }
332 
333  e = MergeSortUint32A16(mem, n16);
334  if (NLIB_UNLIKELY(e != 0)) return e;
335 
336  // they must be sorted if the keys are the same.
337  uint32_t prev_key = mem[0] & idx_mask_inv;
338  size_t i = 1;
339  detail::KeyIdxSortLess<T, Compare> myless(src, idx_mask, comp);
340  while (i < n) {
341  uint32_t key = mem[i] & idx_mask_inv;
342  if (NLIB_UNLIKELY(key == prev_key)) {
343  size_t from = i - 1;
344  do {
345  ++i;
346  if (i == n) break;
347  key = mem[i] & idx_mask_inv;
348  } while (key == prev_key);
349  // if the sort algorithm is not stable, mem[i -> n16] might be swapped.
350  std::stable_sort(&mem[from], &mem[i], myless);
351  }
352  prev_key = key;
353  ++i;
354  }
355 
356  for (i = 0; i < n; ++i) {
357  dst[i] = src[mem[i] & idx_mask];
358  }
359  return 0;
360 }
361 
362 template<class T>
363 NLIB_ALWAYS_INLINE errno_t KeyIdxSortN(T** dst, T* const* src, size_t n) NLIB_NOEXCEPT {
364  return KeyIdxSortN(dst, src, n, std::less<T>());
365 }
366 
367 template<class T, class Compare>
368 inline errno_t KeyIdxSort(T** first, T** last, Compare comp) NLIB_NOEXCEPT {
369  size_t n = last - first;
370  T** tmp = reinterpret_cast<T**>(nlib_malloc(n * sizeof(*first)));
371  if (NLIB_UNLIKELY(!tmp)) return ENOMEM;
372  errno_t e = KeyIdxSortN(tmp, first, n, comp);
373  if (NLIB_LIKELY(e == 0)) {
374  nlib_memcpy(first, n * sizeof(*first), tmp, n * sizeof(*tmp));
375  }
376  nlib_free(tmp);
377  return e;
378 }
379 
380 template<class T>
381 NLIB_ALWAYS_INLINE errno_t KeyIdxSort(T** first, T** last) NLIB_NOEXCEPT {
382  return KeyIdxSort(first, last, std::less<T>());
383 }
384 
385 } // namespace simd
386 NLIB_NAMESPACE_END
387 
388 #endif
389 
390 #endif // INCLUDE_NN_NLIB_SIMD_SIMDALGORITHM_H_
i128 IsXdigit(i128 c) noexcept
c 内の16進数の文字をマスクします。
#define NLIB_ALWAYS_INLINE
コンパイラに関数をインライン展開するように強く示します。
Definition: Platform_unix.h:69
errno_t Init(size_t size, size_t align) noexcept
メモリの割り当てを行います。
i128 IsDigit(i128 c) noexcept
c 内の&#39;0&#39;-&#39;9&#39;の文字をマスクします。
errno_t KeyIdxSort(T **first, T **last) noexcept
KeyIdxSort(first, last, std::less<T>())を実行します。
整数のSIMD演算を行うためのクラスや関数が実装されています。
#define NLIB_UNLIKELY(x)
条件xが偽になる傾向が高いことをコンパイラに示します。
Definition: Platform_unix.h:72
アラインされたメモリを得るためのクラスです。
#define RSIZE_MAX
size_tの最大値よりいくらか小さい値が定義されています。
Definition: Platform.h:541
i128 IsAlpha(i128 c) noexcept
c 内のアルファベットをマスクします。
i128 IsSpace(i128 c) noexcept
c 内の空白文字(0x20, 0x09, 0x0A, 0x0D)をマスクします。
#define NLIB_VIS_PUBLIC
関数やクラス等のシンボルをライブラリの外部に公開します。
Definition: Platform_unix.h:61
i128 IsAlnum(i128 c) noexcept
c 内のアルファベットか&#39;0&#39;-&#39;9&#39;の文字をマスクします。
NLIB_CHECK_RESULT void * nlib_malloc(size_t size)
C標準関数のmalloc()を呼び出すweak関数です。nlibはこの関数を経由してmalloc()を呼び出します。 ...
nlib_i128_t i128
nlib_i128_tがtypedefされています。
Definition: SimdInt.h:63
#define NLIB_LIKELY(x)
条件xが真になる傾向が高いことをコンパイラに示します。
Definition: Platform_unix.h:71
static errno_t nlib_memcpy(void *s1, size_t s1max, const void *s2, size_t n)
N1078のmemcpy_sに相当する実装です。
Definition: Platform.h:3170
constexpr const each_uint32_tag each_uint32
each_uint32_tag型の定数オブジェクトで、32bitの符号なし整数を示すためのタグです。
Definition: SimdInt.h:40
#define NLIB_NOEXCEPT
環境に合わせてnoexcept 又は同等の定義がされます。
Definition: Config.h:86
void * Get() noexcept
割り当てられた領域へのポインタを返します。
errno_t KeyIdxSortN(T **dst, T *const *src, size_t n) noexcept
KeyIdxSortN(dst, src, n, std::less<T>())を実行します。
アラインされたメモリを得たい場合に利用します。
#define NLIB_ALIGNAS(x)
alignas(x)又は同等の定義がされます。
Definition: Config.h:221
constexpr const each_int8_tag each_int8
each_int8_tag型の定数オブジェクトで、8bitの符号付き整数を示すためのタグです。
Definition: SimdInt.h:34
const void * nlib_memchr_pred_not(const void *s, PRED pred, size_t n) noexcept
バイト列内のバイトの検査をSIMD命令を使って行うための関数テンプレートです。
#define NLIB_STATIC_ASSERT(exp)
静的アサートが定義されます。利用可能であればstatic_assertを利用します。
Definition: Config.h:136
void nlib_free(void *ptr)
C標準関数のfree()を呼び出すweak関数です。nlibはこの関数を経由してfree()を呼び出します。 ...
errno_t MergeSortUint32A16(uint32_t *data, size_t n) noexcept
SIMDを利用して32bit符号なし整数の並びを昇順にマージソートします。
const void * nlib_memchr_pred(const void *s, PRED pred, size_t n) noexcept
バイト列内のバイトの検査をSIMD命令を使って行うための関数テンプレートです。
Definition: SimdAlgorithm.h:50
int errno_t
intのtypedefで、戻り値としてPOSIXのエラー値を返すことを示します。
Definition: NMalloc.h:24