vqsort-inl.h
1 // Copyright 2021 Google LLC
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 // Normal include guard for target-independent parts
17 #ifndef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
18 #define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
19 
20 // Makes it harder for adversaries to predict our sampling locations, at the
21 // cost of 1-2% increased runtime.
22 #ifndef VQSORT_SECURE_RNG
23 #define VQSORT_SECURE_RNG 0
24 #endif
25 
26 #if VQSORT_SECURE_RNG
27 #include "third_party/absl/random/random.h"
28 #endif
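The macro above is an opt-in, compile-time switch and must be defined before this header is first included; the default of 0 keeps the fast built-in generator defined further below. A minimal sketch of enabling the hardened sampling (assuming the Abseil random dependency is available to the build):

  #define VQSORT_SECURE_RNG 1   // define before the first include of this header
  #include "hwy/contrib/sort/vqsort-inl.h"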
29 
30 #include <string.h> // memcpy
31 
32 #include "hwy/cache_control.h" // Prefetch
33 #include "hwy/contrib/sort/vqsort.h" // Fill24Bytes
34 
35 #if HWY_IS_MSAN
36 #include <sanitizer/msan_interface.h>
37 #endif
38 
39 #endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_INL_H_
40 
41 // Per-target
42 #if defined(HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE) == \
43  defined(HWY_TARGET_TOGGLE)
44 #ifdef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
45 #undef HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
46 #else
47 #define HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE
48 #endif
49 
50 #include "hwy/contrib/sort/shared-inl.h"
51 #include "hwy/contrib/sort/sorting_networks-inl.h"
52 #include "hwy/highway.h"
53 
54 HWY_BEFORE_NAMESPACE();
55 namespace hwy {
56 namespace HWY_NAMESPACE {
57 namespace detail {
58 
59 using Constants = hwy::SortConstants;
60 
61 // ------------------------------ HeapSort
62 
63 template <class Traits, typename T>
64 void SiftDown(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes,
65  size_t start) {
66  constexpr size_t N1 = st.LanesPerKey();
67  const FixedTag<T, N1> d;
68 
69  while (start < num_lanes) {
70  const size_t left = 2 * start + N1;
71  const size_t right = 2 * start + 2 * N1;
72  if (left >= num_lanes) break;
73  size_t idx_larger = start;
74  const auto key_j = st.SetKey(d, lanes + start);
75  if (AllTrue(d, st.Compare(d, key_j, st.SetKey(d, lanes + left)))) {
76  idx_larger = left;
77  }
78  if (right < num_lanes &&
79  AllTrue(d, st.Compare(d, st.SetKey(d, lanes + idx_larger),
80  st.SetKey(d, lanes + right)))) {
81  idx_larger = right;
82  }
83  if (idx_larger == start) break;
84  st.Swap(lanes + start, lanes + idx_larger);
85  start = idx_larger;
86  }
87 }
88 
89 // Heapsort: O(1) space, O(N*logN) worst-case comparisons.
90 // Based on LLVM sanitizer_common.h, licensed under Apache-2.0.
91 template <class Traits, typename T>
92 void HeapSort(Traits st, T* HWY_RESTRICT lanes, const size_t num_lanes) {
93  constexpr size_t N1 = st.LanesPerKey();
94 
95  if (num_lanes < 2 * N1) return;
96 
97  // Build heap.
98  for (size_t i = ((num_lanes - N1) / N1 / 2) * N1; i != (size_t)-N1; i -= N1) {
99  SiftDown(st, lanes, num_lanes, i);
100  }
101 
102  for (size_t i = num_lanes - N1; i != 0; i -= N1) {
103  // Swap root with last
104  st.Swap(lanes + 0, lanes + i);
105 
106  // Sift down the new root.
107  SiftDown(st, lanes, i, 0);
108  }
109 }
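The only unusual thing about this otherwise textbook heapsort is that one heap node spans st.LanesPerKey() contiguous lanes (two for 128-bit keys), so every index advances in steps of N1. A minimal scalar sketch of the same structure, with N1 = 1 and plain < in place of st.Compare (illustrative only; SiftDownScalar/HeapSortScalar are hypothetical names, not part of the library):

  #include <cstddef>
  #include <utility>

  // Max-heap sift-down, then repeatedly swap the root with the last element.
  inline void SiftDownScalar(int* a, size_t num, size_t start) {
    while (true) {
      const size_t left = 2 * start + 1, right = 2 * start + 2;
      if (left >= num) break;
      size_t largest = start;
      if (a[largest] < a[left]) largest = left;
      if (right < num && a[largest] < a[right]) largest = right;
      if (largest == start) break;
      std::swap(a[start], a[largest]);
      start = largest;
    }
  }

  inline void HeapSortScalar(int* a, size_t num) {
    if (num < 2) return;
    for (size_t i = num / 2; i-- > 0;) SiftDownScalar(a, num, i);  // build heap
    for (size_t i = num - 1; i != 0; --i) {
      std::swap(a[0], a[i]);    // move the current maximum behind the heap
      SiftDownScalar(a, i, 0);  // restore the heap property in a[0, i)
    }
  }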
110 
111 #if VQSORT_ENABLED || HWY_IDE
112 
113 // ------------------------------ BaseCase
114 
115 // Sorts `keys` within the range [0, num) via sorting network.
116 template <class D, class Traits, typename T>
117 HWY_NOINLINE void BaseCase(D d, Traits st, T* HWY_RESTRICT keys,
118  T* HWY_RESTRICT keys_end, size_t num,
119  T* HWY_RESTRICT buf) {
120  const size_t N = Lanes(d);
121  using V = decltype(Zero(d));
122 
123  // _Nonzero32 requires num - 1 != 0.
124  if (HWY_UNLIKELY(num <= 1)) return;
125 
126  // Reshape into a matrix with kMaxRows rows, and columns limited by the
127  // 1D `num`, which is upper-bounded by the vector width (see BaseCaseNum).
128  const size_t num_pow2 = size_t{1}
129  << (32 - Num0BitsAboveMS1Bit_Nonzero32(
130  static_cast<uint32_t>(num - 1)));
131  HWY_DASSERT(num <= num_pow2 && num_pow2 <= Constants::BaseCaseNum(N));
132  const size_t cols =
133  HWY_MAX(st.LanesPerKey(), num_pow2 >> Constants::kMaxRowsLog2);
134  HWY_DASSERT(cols <= N);
135 
136  // We can avoid padding and load/store directly to `keys` after checking the
137  // original input array has enough space. Except at the right border, it's OK
138  // to sort more than the current sub-array. Even if we sort across a previous
139  // partition point, we know that keys will not migrate across it. However, we
140  // must use the maximum size of the sorting network, because the StoreU of its
141  // last vector would otherwise write invalid data starting at kMaxRows * cols.
142  const size_t N_sn = Lanes(CappedTag<T, Constants::kMaxCols>());
143  if (HWY_LIKELY(keys + N_sn * Constants::kMaxRows <= keys_end)) {
144  SortingNetwork(st, keys, N_sn);
145  return;
146  }
147 
148  // Copy `keys` to `buf`.
149  size_t i;
150  for (i = 0; i + N <= num; i += N) {
151  Store(LoadU(d, keys + i), d, buf + i);
152  }
153  SafeCopyN(num - i, d, keys + i, buf + i);
154  i = num;
155 
156  // Fill with padding - last in sort order, not copied to keys.
157  const V kPadding = st.LastValue(d);
158  // Initialize an extra vector because SortingNetwork loads full vectors,
159  // which may exceed cols*kMaxRows.
160  for (; i < (cols * Constants::kMaxRows + N); i += N) {
161  StoreU(kPadding, d, buf + i);
162  }
163 
164  SortingNetwork(st, buf, cols);
165 
166  for (i = 0; i + N <= num; i += N) {
167  StoreU(Load(d, buf + i), d, keys + i);
168  }
169  SafeCopyN(num - i, d, buf + i, keys + i);
170 }
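The power-of-two rounding at the top of BaseCase can be checked by hand (values chosen for illustration): for num = 45, num - 1 = 44 = 0b101100, so Num0BitsAboveMS1Bit_Nonzero32(44) = 26 and num_pow2 = 1 << (32 - 26) = 64, satisfying num <= num_pow2. The 45 keys are then treated as a matrix with kMaxRows rows and cols = HWY_MAX(LanesPerKey(), num_pow2 >> kMaxRowsLog2) columns, padded with LastValue, which is what the sorting network operates on.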
171 
172 // ------------------------------ Partition
173 
174 // Consumes from `left` until a multiple of kUnroll*N remains.
175 // Temporarily stores the right side into `buf`, then moves behind `right`.
176 template <class D, class Traits, class T>
177 HWY_NOINLINE void PartitionToMultipleOfUnroll(D d, Traits st,
178  T* HWY_RESTRICT keys,
179  size_t& left, size_t& right,
180  const Vec<D> pivot,
181  T* HWY_RESTRICT buf) {
182  constexpr size_t kUnroll = Constants::kPartitionUnroll;
183  const size_t N = Lanes(d);
184  size_t readL = left;
185  size_t bufR = 0;
186  const size_t num = right - left;
187  // Partition requires both a multiple of kUnroll*N and at least
188  // 2*kUnroll*N for the initial loads. If less, consume all here.
189  const size_t num_rem =
190  (num < 2 * kUnroll * N) ? num : (num & (kUnroll * N - 1));
191  size_t i = 0;
192  for (; i + N <= num_rem; i += N) {
193  const Vec<D> vL = LoadU(d, keys + readL);
194  readL += N;
195 
196  const auto comp = st.Compare(d, pivot, vL);
197  left += CompressBlendedStore(vL, Not(comp), d, keys + left);
198  bufR += CompressStore(vL, comp, d, buf + bufR);
199  }
200  // Last iteration: only use valid lanes.
201  if (HWY_LIKELY(i != num_rem)) {
202  const auto mask = FirstN(d, num_rem - i);
203  const Vec<D> vL = LoadU(d, keys + readL);
204 
205  const auto comp = st.Compare(d, pivot, vL);
206  left += CompressBlendedStore(vL, AndNot(comp, mask), d, keys + left);
207  bufR += CompressStore(vL, And(comp, mask), d, buf + bufR);
208  }
209 
210  // MSAN seems not to understand CompressStore. buf[0, bufR) are valid.
211 #if HWY_IS_MSAN
212  __msan_unpoison(buf, bufR * sizeof(T));
213 #endif
214 
215  // Everything we loaded was put into buf, or behind the new `left`, after
216  // which there is space for bufR items. First move items from `right` to
217  // `left` to free up space, then copy `buf` into the vacated `right`.
218  // A loop with masked loads from `buf` is insufficient - we would also need to
219  // mask from `right`. Combining a loop with memcpy for the remainders is
220  // slower than just memcpy, so we use that for simplicity.
221  right -= bufR;
222  memcpy(keys + left, keys + right, bufR * sizeof(T));
223  memcpy(keys + right, buf, bufR * sizeof(T));
224 }
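To make the remainder computation concrete, assume (hypothetically) kUnroll = 4 and N = 8 lanes, so one unrolled block is 32 keys: for num = 100, this function partitions num & 31 = 4 leftover keys, leaving 96 keys (three full blocks) as a multiple of kUnroll * N for Partition's unrolled loop; for num = 50 (< 2 * 32), it partitions all 50 keys here and the unrolled loop below is skipped entirely.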
225 
226 template <class D, class Traits, typename T>
227 HWY_INLINE void StoreLeftRight(D d, Traits st, const Vec<D> v,
228  const Vec<D> pivot, T* HWY_RESTRICT keys,
229  size_t& writeL, size_t& writeR) {
230  const size_t N = Lanes(d);
231 
232  const auto comp = st.Compare(d, pivot, v);
233 
234  if (CompressIsPartition<T>::value ||
235  (HWY_MAX_BYTES == 16 && st.Is128())) {
236  // Non-native Compress (e.g. AVX2): we are able to partition a vector using
237  // a single Compress+two StoreU instead of two Compress[Blended]Store. The
238  // latter are more expensive. Because we store entire vectors, the contents
239  // between the updated writeL and writeR are ignored and will be overwritten
240  // by subsequent calls. This works because writeL and writeR are at least
241  // two vectors apart.
242  const auto lr = st.CompressKeys(v, comp);
243  const size_t num_right = CountTrue(d, comp);
244  const size_t num_left = N - num_right;
245  StoreU(lr, d, keys + writeL);
246  writeL += num_left;
247  // Now write the right-side elements (if any), such that the previous writeR
248  // is one past the end of the newly written right elements, then advance.
249  StoreU(lr, d, keys + writeR - N);
250  writeR -= num_right;
251  } else {
252  // Native Compress[Store] (e.g. AVX3), which only keep the left or right
253  // side, not both, hence we require two calls.
254  const size_t num_left = CompressStore(v, Not(comp), d, keys + writeL);
255  writeL += num_left;
256 
257  writeR -= (N - num_left);
258  (void)CompressBlendedStore(v, comp, d, keys + writeR);
259  }
260 }
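As a concrete illustration of the non-native path, suppose N = 8 and three lanes of `v` compare greater than the pivot: then num_right = 3 and num_left = 5. The single compressed vector is stored twice, once at writeL (which then advances by 5) and once ending at the old writeR (which then retreats by 3); the lanes left between the new writeL and writeR are don't-care values that subsequent StoreLeftRight calls overwrite, which is safe because the two cursors stay at least two vectors apart.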
261 
262 template <class D, class Traits, typename T>
263 HWY_INLINE void StoreLeftRight4(D d, Traits st, const Vec<D> v0,
264  const Vec<D> v1, const Vec<D> v2,
265  const Vec<D> v3, const Vec<D> pivot,
266  T* HWY_RESTRICT keys, size_t& writeL,
267  size_t& writeR) {
268  StoreLeftRight(d, st, v0, pivot, keys, writeL, writeR);
269  StoreLeftRight(d, st, v1, pivot, keys, writeL, writeR);
270  StoreLeftRight(d, st, v2, pivot, keys, writeL, writeR);
271  StoreLeftRight(d, st, v3, pivot, keys, writeL, writeR);
272 }
273 
274 // Moves "<= pivot" keys to the front, and others to the back. pivot is
275 // broadcasted. Time-critical!
276 //
277 // Aligned loads do not seem to be worthwhile (not bottlenecked by load ports).
278 template <class D, class Traits, typename T>
279 HWY_NOINLINE size_t Partition(D d, Traits st, T* HWY_RESTRICT keys, size_t left,
280  size_t right, const Vec<D> pivot,
281  T* HWY_RESTRICT buf) {
282  using V = decltype(Zero(d));
283  const size_t N = Lanes(d);
284 
285  // StoreLeftRight will CompressBlendedStore ending at `writeR`. Unless all
286  // lanes happen to be in the right-side partition, this will overrun `keys`,
287  // which triggers asan errors. Avoid by special-casing the last vector.
288  HWY_DASSERT(right - left > 2 * N); // ensured by HandleSpecialCases
289  right -= N;
290  const size_t last = right;
291  const V vlast = LoadU(d, keys + last);
292 
293  PartitionToMultipleOfUnroll(d, st, keys, left, right, pivot, buf);
294  constexpr size_t kUnroll = Constants::kPartitionUnroll;
295 
296  // Invariant: [left, writeL) and [writeR, right) are already partitioned.
297  size_t writeL = left;
298  size_t writeR = right;
299 
300  const size_t num = right - left;
301  // Cannot load if there were fewer than 2 * kUnroll * N.
302  if (HWY_LIKELY(num != 0)) {
303  HWY_DASSERT(num >= 2 * kUnroll * N);
304  HWY_DASSERT((num & (kUnroll * N - 1)) == 0);
305 
306  // Make space for writing in-place by reading from left and right.
307  const V vL0 = LoadU(d, keys + left + 0 * N);
308  const V vL1 = LoadU(d, keys + left + 1 * N);
309  const V vL2 = LoadU(d, keys + left + 2 * N);
310  const V vL3 = LoadU(d, keys + left + 3 * N);
311  left += kUnroll * N;
312  right -= kUnroll * N;
313  const V vR0 = LoadU(d, keys + right + 0 * N);
314  const V vR1 = LoadU(d, keys + right + 1 * N);
315  const V vR2 = LoadU(d, keys + right + 2 * N);
316  const V vR3 = LoadU(d, keys + right + 3 * N);
317 
318  // The left/right updates may consume all inputs, so check before the loop.
319  while (left != right) {
320  V v0, v1, v2, v3;
321 
322  // Free up capacity for writing by loading from the side that has less.
323  // Data-dependent but branching is faster than forcing branch-free.
324  const size_t capacityL = left - writeL;
325  const size_t capacityR = writeR - right;
326  HWY_DASSERT(capacityL <= num && capacityR <= num); // >= 0
327  if (capacityR < capacityL) {
328  right -= kUnroll * N;
329  v0 = LoadU(d, keys + right + 0 * N);
330  v1 = LoadU(d, keys + right + 1 * N);
331  v2 = LoadU(d, keys + right + 2 * N);
332  v3 = LoadU(d, keys + right + 3 * N);
333  hwy::Prefetch(keys + right - 3 * kUnroll * N);
334  } else {
335  v0 = LoadU(d, keys + left + 0 * N);
336  v1 = LoadU(d, keys + left + 1 * N);
337  v2 = LoadU(d, keys + left + 2 * N);
338  v3 = LoadU(d, keys + left + 3 * N);
339  left += kUnroll * N;
340  hwy::Prefetch(keys + left + 3 * kUnroll * N);
341  }
342 
343  StoreLeftRight4(d, st, v0, v1, v2, v3, pivot, keys, writeL, writeR);
344  }
345 
346  // Now finish writing the initial left/right to the middle.
347  StoreLeftRight4(d, st, vL0, vL1, vL2, vL3, pivot, keys, writeL, writeR);
348  StoreLeftRight4(d, st, vR0, vR1, vR2, vR3, pivot, keys, writeL, writeR);
349  }
350 
351  // We have partitioned [left, right) such that writeL is the boundary.
352  HWY_DASSERT(writeL == writeR);
353  // Make space for inserting vlast: move up to N of the first right-side keys
354  // into the unused space starting at last. If we have fewer, ensure they are
355  // the last items in that vector by subtracting from the *load* address,
356  // which is safe because we have at least two vectors (checked above).
357  const size_t totalR = last - writeL;
358  const size_t startR = totalR < N ? writeL + totalR - N : writeL;
359  StoreU(LoadU(d, keys + startR), d, keys + last);
360 
361  // Partition vlast: write L, then R, into the single-vector gap at writeL.
362  const auto comp = st.Compare(d, pivot, vlast);
363  writeL += CompressBlendedStore(vlast, Not(comp), d, keys + writeL);
364  (void)CompressBlendedStore(vlast, comp, d, keys + writeL);
365 
366  return writeL;
367 }
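The double-ended, vacancy-based control flow above is easier to follow in a scalar model: pre-reading a block from each end creates write slack, and always refilling from the side with less slack guarantees that no store can overwrite keys that have not been read yet. A simplified, element-at-a-time sketch (illustrative only, not part of the library; PartitionScalar and emit are hypothetical names, plain <= stands in for st.Compare, and the vlast special case is unnecessary at this granularity):

  #include <cstddef>

  // Scalar model of the Partition loop; requires right - left >= 2.
  inline size_t PartitionScalar(int* keys, size_t left, size_t right, int pivot) {
    size_t writeL = left, writeR = right;
    int tL = keys[left++];   // pre-read one key per side to create write slack
    int tR = keys[--right];
    auto emit = [&](int v) {  // write the key to whichever side it belongs
      if (v <= pivot) keys[writeL++] = v; else keys[--writeR] = v;
    };
    while (left != right) {
      // Refill from the side with less remaining slack so that the following
      // emit() cannot clobber an unread key.
      const int v = (writeR - right < left - writeL) ? keys[--right] : keys[left++];
      emit(v);
    }
    emit(tL);  // finally flush the two pre-read keys into the remaining gap
    emit(tR);
    return writeL;  // == writeR: first index of the "> pivot" partition
  }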
368 
369 // ------------------------------ Pivot
370 
371 template <class Traits, class V>
372 HWY_INLINE V MedianOf3(Traits st, V v0, V v1, V v2) {
373  const DFromV<V> d;
374  // Slightly faster for 128-bit, apparently because not serially dependent.
375  if (st.Is128()) {
376  // Median = XOR-sum 'minus' the first and last. Calling First twice is
377  // slightly faster than Compare + 2 IfThenElse or even IfThenElse + XOR.
378  const auto sum = Xor(Xor(v0, v1), v2);
379  const auto first = st.First(d, st.First(d, v0, v1), v2);
380  const auto last = st.Last(d, st.Last(d, v0, v1), v2);
381  return Xor(Xor(sum, first), last);
382  }
383  st.Sort2(d, v0, v2);
384  v1 = st.Last(d, v0, v1);
385  v1 = st.First(d, v1, v2);
386  return v1;
387 }
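The 128-bit branch relies on the identity median(a, b, c) = (a XOR b XOR c) XOR min(a, b, c) XOR max(a, b, c): XOR-ing the two extremes back out of the XOR-sum leaves the middle key. For example, with keys 3, 9 and 5: 3 ^ 9 ^ 5 = 15, and 15 ^ 3 ^ 9 = 5, the median.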
388 
389 // Replaces triplets with their median and recurses until less than 3 keys
390 // remain. Ignores leftover values (non-whole triplets)!
391 template <class D, class Traits, typename T>
392 Vec<D> RecursiveMedianOf3(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
393  T* HWY_RESTRICT buf) {
394  const size_t N = Lanes(d);
395  constexpr size_t N1 = st.LanesPerKey();
396 
397  if (num < 3 * N1) return st.SetKey(d, keys);
398 
399  size_t read = 0;
400  size_t written = 0;
401 
402  // Triplets of vectors
403  for (; read + 3 * N <= num; read += 3 * N) {
404  const auto v0 = Load(d, keys + read + 0 * N);
405  const auto v1 = Load(d, keys + read + 1 * N);
406  const auto v2 = Load(d, keys + read + 2 * N);
407  Store(MedianOf3(st, v0, v1, v2), d, buf + written);
408  written += N;
409  }
410 
411  // Triplets of keys
412  for (; read + 3 * N1 <= num; read += 3 * N1) {
413  const auto v0 = st.SetKey(d, keys + read + 0 * N1);
414  const auto v1 = st.SetKey(d, keys + read + 1 * N1);
415  const auto v2 = st.SetKey(d, keys + read + 2 * N1);
416  StoreU(MedianOf3(st, v0, v1, v2), d, buf + written);
417  written += N1;
418  }
419 
420  // Tail recursion; swap buffers
421  return RecursiveMedianOf3(d, st, buf, written, keys);
422 }
423 
424 #if VQSORT_SECURE_RNG
425 using Generator = absl::BitGen;
426 #else
427 // Based on https://github.com/numpy/numpy/issues/16313#issuecomment-641897028
428 #pragma pack(push, 1)
429 class Generator {
430  public:
431  Generator(const void* heap, size_t num) {
432  Sorter::Fill24Bytes(heap, num, &a_);
433  k_ = 1; // stream index: must be odd
434  }
435 
436  explicit Generator(uint64_t seed) {
437  a_ = b_ = w_ = seed;
438  k_ = 1;
439  }
440 
441  uint64_t operator()() {
442  const uint64_t b = b_;
443  w_ += k_;
444  const uint64_t next = a_ ^ w_;
445  a_ = (b + (b << 3)) ^ (b >> 11);
446  const uint64_t rot = (b << 24) | (b >> 40);
447  b_ = rot + next;
448  return next;
449  }
450 
451  private:
452  uint64_t a_;
453  uint64_t b_;
454  uint64_t w_;
455  uint64_t k_; // increment
456 };
457 #pragma pack(pop)
458 
459 #endif // !VQSORT_SECURE_RNG
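Whichever generator is selected, the sorter drives it the same way (mirroring Sort() further below): construct it once per call, seeded from the keys pointer and count, then draw 64-bit values. A minimal sketch:

  detail::Generator rng(keys, num);  // seeds a_, b_, w_ via Sorter::Fill24Bytes
  const uint64_t bits = rng();       // one 64-bit draw, reused as two u32 samples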
460 
461 // Returns slightly biased random index of a chunk in [0, num_chunks).
462 // See https://www.pcg-random.org/posts/bounded-rands.html.
463 HWY_INLINE size_t RandomChunkIndex(const uint32_t num_chunks, uint32_t bits) {
464  const uint64_t chunk_index = (static_cast<uint64_t>(bits) * num_chunks) >> 32;
465  HWY_DASSERT(chunk_index < num_chunks);
466  return static_cast<size_t>(chunk_index);
467 }
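This is the multiply-shift trick for mapping a uniform 32-bit value into [0, num_chunks) without a division: for example, with num_chunks = 10 and bits = 0x80000000 (about half of 2^32), the 64-bit product is 10 * 2^31 = 5 * 2^32, so the shift by 32 yields chunk index 5. The result is slightly biased whenever num_chunks does not divide 2^32, as the linked post explains.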
468 
469 template <class D, class Traits, typename T>
470 HWY_NOINLINE Vec<D> ChoosePivot(D d, Traits st, T* HWY_RESTRICT keys,
471  const size_t begin, const size_t end,
472  T* HWY_RESTRICT buf, Generator& rng) {
473  using V = decltype(Zero(d));
474  const size_t N = Lanes(d);
475 
476  // Power of two
477  const size_t lanes_per_chunk = Constants::LanesPerChunk(sizeof(T), N);
478 
479  keys += begin;
480  size_t num = end - begin;
481 
482  // Align start of keys to chunks. We always have at least 2 chunks because the
483  // base case would have handled anything up to 16 vectors, i.e. >= 4 chunks.
484  HWY_DASSERT(num >= 2 * lanes_per_chunk);
485  const size_t misalign =
486  (reinterpret_cast<uintptr_t>(keys) / sizeof(T)) & (lanes_per_chunk - 1);
487  if (misalign != 0) {
488  const size_t consume = lanes_per_chunk - misalign;
489  keys += consume;
490  num -= consume;
491  }
492 
493  // Generate enough random bits for 9 uint32
494  uint64_t* bits64 = reinterpret_cast<uint64_t*>(buf);
495  for (size_t i = 0; i < 5; ++i) {
496  bits64[i] = rng();
497  }
498  const uint32_t* bits = reinterpret_cast<const uint32_t*>(buf);
499 
500  const uint32_t lpc32 = static_cast<uint32_t>(lanes_per_chunk);
501  // Avoid division
502  const size_t log2_lpc = Num0BitsBelowLS1Bit_Nonzero32(lpc32);
503  const size_t num_chunks64 = num >> log2_lpc;
504  // Clamp to uint32 for RandomChunkIndex
505  const uint32_t num_chunks =
506  static_cast<uint32_t>(HWY_MIN(num_chunks64, 0xFFFFFFFFull));
507 
508  const size_t offset0 = RandomChunkIndex(num_chunks, bits[0]) << log2_lpc;
509  const size_t offset1 = RandomChunkIndex(num_chunks, bits[1]) << log2_lpc;
510  const size_t offset2 = RandomChunkIndex(num_chunks, bits[2]) << log2_lpc;
511  const size_t offset3 = RandomChunkIndex(num_chunks, bits[3]) << log2_lpc;
512  const size_t offset4 = RandomChunkIndex(num_chunks, bits[4]) << log2_lpc;
513  const size_t offset5 = RandomChunkIndex(num_chunks, bits[5]) << log2_lpc;
514  const size_t offset6 = RandomChunkIndex(num_chunks, bits[6]) << log2_lpc;
515  const size_t offset7 = RandomChunkIndex(num_chunks, bits[7]) << log2_lpc;
516  const size_t offset8 = RandomChunkIndex(num_chunks, bits[8]) << log2_lpc;
517  for (size_t i = 0; i < lanes_per_chunk; i += N) {
518  const V v0 = Load(d, keys + offset0 + i);
519  const V v1 = Load(d, keys + offset1 + i);
520  const V v2 = Load(d, keys + offset2 + i);
521  const V medians0 = MedianOf3(st, v0, v1, v2);
522  Store(medians0, d, buf + i);
523 
524  const V v3 = Load(d, keys + offset3 + i);
525  const V v4 = Load(d, keys + offset4 + i);
526  const V v5 = Load(d, keys + offset5 + i);
527  const V medians1 = MedianOf3(st, v3, v4, v5);
528  Store(medians1, d, buf + i + lanes_per_chunk);
529 
530  const V v6 = Load(d, keys + offset6 + i);
531  const V v7 = Load(d, keys + offset7 + i);
532  const V v8 = Load(d, keys + offset8 + i);
533  const V medians2 = MedianOf3(st, v6, v7, v8);
534  Store(medians2, d, buf + i + lanes_per_chunk * 2);
535  }
536 
537  return RecursiveMedianOf3(d, st, buf, 3 * lanes_per_chunk,
538  buf + 3 * lanes_per_chunk);
539 }
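In total the sampler reads 9 randomly chosen chunks of lanes_per_chunk keys each (three independent triplets per loop iteration), reduces them to 3 * lanes_per_chunk per-lane medians in buf, and RecursiveMedianOf3 then collapses those medians to a single pivot; e.g. with lanes_per_chunk = 64 that is 9 * 64 = 576 sampled keys per pivot.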
540 
541 // Compute exact min/max to detect all-equal partitions. Only called after a
542 // degenerate Partition (none in the right partition).
543 template <class D, class Traits, typename T>
544 HWY_NOINLINE void ScanMinMax(D d, Traits st, const T* HWY_RESTRICT keys,
545  size_t num, T* HWY_RESTRICT buf, Vec<D>& first,
546  Vec<D>& last) {
547  const size_t N = Lanes(d);
548 
549  first = st.LastValue(d);
550  last = st.FirstValue(d);
551 
552  size_t i = 0;
553  for (; i + N <= num; i += N) {
554  const Vec<D> v = LoadU(d, keys + i);
555  first = st.First(d, v, first);
556  last = st.Last(d, v, last);
557  }
558  if (HWY_LIKELY(i != num)) {
559  HWY_DASSERT(num >= N); // See HandleSpecialCases
560  const Vec<D> v = LoadU(d, keys + num - N);
561  first = st.First(d, v, first);
562  last = st.Last(d, v, last);
563  }
564 
565  first = st.FirstOfLanes(d, first, buf);
566  last = st.LastOfLanes(d, last, buf);
567 }
568 
569 template <class D, class Traits, typename T>
570 void Recurse(D d, Traits st, T* HWY_RESTRICT keys, T* HWY_RESTRICT keys_end,
571  const size_t begin, const size_t end, const Vec<D> pivot,
572  T* HWY_RESTRICT buf, Generator& rng, size_t remaining_levels) {
573  HWY_DASSERT(begin + 1 < end);
574  const size_t num = end - begin; // >= 2
575 
576  // Too many degenerate partitions. This is extremely unlikely to happen
577  // because we select pivots from large (though still O(1)) samples.
578  if (HWY_UNLIKELY(remaining_levels == 0)) {
579  HeapSort(st, keys + begin, num); // Slow but N*logN.
580  return;
581  }
582 
583  const ptrdiff_t base_case_num =
584  static_cast<ptrdiff_t>(Constants::BaseCaseNum(Lanes(d)));
585  const size_t bound = Partition(d, st, keys, begin, end, pivot, buf);
586 
587  const ptrdiff_t num_left =
588  static_cast<ptrdiff_t>(bound) - static_cast<ptrdiff_t>(begin);
589  const ptrdiff_t num_right =
590  static_cast<ptrdiff_t>(end) - static_cast<ptrdiff_t>(bound);
591 
592  // Check for degenerate partitions (i.e. Partition did not move any keys):
593  if (HWY_UNLIKELY(num_right == 0)) {
594  // Because the pivot is one of the keys, it must have been equal to the
595  // first or last key in sort order. Scan for the actual min/max:
596  // passing the current pivot as the new bound is insufficient because one of
597  // the partitions might not actually include that key.
598  Vec<D> first, last;
599  ScanMinMax(d, st, keys + begin, num, buf, first, last);
600  if (AllTrue(d, Eq(first, last))) return;
601 
602  // Separate recursion to make sure that we don't pick `last` as the
603  // pivot - that would again lead to a degenerate partition.
604  Recurse(d, st, keys, keys_end, begin, end, first, buf, rng,
605  remaining_levels - 1);
606  return;
607  }
608 
609  if (HWY_UNLIKELY(num_left <= base_case_num)) {
610  BaseCase(d, st, keys + begin, keys_end, static_cast<size_t>(num_left), buf);
611  } else {
612  const Vec<D> next_pivot = ChoosePivot(d, st, keys, begin, bound, buf, rng);
613  Recurse(d, st, keys, keys_end, begin, bound, next_pivot, buf, rng,
614  remaining_levels - 1);
615  }
616  if (HWY_UNLIKELY(num_right <= base_case_num)) {
617  BaseCase(d, st, keys + bound, keys_end, static_cast<size_t>(num_right),
618  buf);
619  } else {
620  const Vec<D> next_pivot = ChoosePivot(d, st, keys, bound, end, buf, rng);
621  Recurse(d, st, keys, keys_end, bound, end, next_pivot, buf, rng,
622  remaining_levels - 1);
623  }
624 }
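As set up in Sort() below, remaining_levels starts at 2 * CeilLog2(num) + 4, so for num = 1,000,000 keys the budget is 2 * 20 + 4 = 44 nested partitioning levels; only an adversarial or astronomically unlucky pivot sequence exhausts it, at which point the O(N*logN) HeapSort above takes over for that sub-range.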
625 
626 // Returns true if sorting is finished.
627 template <class D, class Traits, typename T>
628 bool HandleSpecialCases(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
629  T* HWY_RESTRICT buf) {
630  const size_t N = Lanes(d);
631  const size_t base_case_num = Constants::BaseCaseNum(N);
632 
633  // 128-bit keys require vectors with at least two u64 lanes, which is always
634  // the case unless `d` requests partial vectors (e.g. fraction = 1/2) AND the
635  // hardware vector width is less than 128bit / fraction.
636  const bool partial_128 = !IsFull(d) && N < 2 && st.Is128();
637  // Partition assumes its input is at least two vectors. If vectors are huge,
638  // base_case_num may actually be smaller. If so, which is only possible on
639  // RVV, pass a capped or partial d (LMUL < 1). Use HWY_MAX_BYTES instead of
640  // HWY_LANES to account for the largest possible LMUL.
641  constexpr bool kPotentiallyHuge =
642  2 * (HWY_MAX_BYTES / sizeof(T)) > Constants::BaseCaseNum(HWY_MAX_BYTES / sizeof(T));
643  const bool huge_vec = kPotentiallyHuge && (2 * N > base_case_num);
644  if (partial_128 || huge_vec) {
645  // PERFORMANCE WARNING: falling back to HeapSort.
646  HeapSort(st, keys, num);
647  return true;
648  }
649 
650  // Small arrays: use sorting network, no need for other checks.
651  if (HWY_UNLIKELY(num <= base_case_num)) {
652  BaseCase(d, st, keys, keys + num, num, buf);
653  return true;
654  }
655 
656  // We could also check for already sorted/reverse/equal, but that's probably
657  // counterproductive if vqsort is used as a base case.
658 
659  return false; // not finished sorting
660 }
661 
662 #endif // VQSORT_ENABLED
663 } // namespace detail
664 
665 // Sorts `keys[0..num-1]` according to the order defined by `st.Compare`.
666 // In-place i.e. O(1) additional storage. Worst-case N*logN comparisons.
667 // Non-stable (order of equal keys may change), except for the common case where
668 // the upper bits of T are the key, and the lower bits are a sequential or at
669 // least unique ID.
670 // There is no upper limit on `num`, but note that pivots may be chosen by
671 // sampling only from the first 256 GiB.
672 //
673 // `d` is typically SortTag<T> (chooses between full and partial vectors).
674 // `st` is SharedTraits<Traits*<Order*>>. This abstraction layer bridges
675 // differences in sort order and single-lane vs 128-bit keys.
676 template <class D, class Traits, typename T>
677 void Sort(D d, Traits st, T* HWY_RESTRICT keys, size_t num,
678  T* HWY_RESTRICT buf) {
679 #if VQSORT_ENABLED || HWY_IDE
680 #if !HWY_HAVE_SCALABLE
681  // On targets with fixed-size vectors, avoid _using_ the allocated memory.
682  // We avoid (potentially expensive for small input sizes) allocations on
683  // platforms where no targets are scalable. For 512-bit vectors, this fits on
684  // the stack (several KiB).
685  HWY_ALIGN T storage[SortConstants::BufNum<T>(HWY_LANES(T))] = {};
686  static_assert(sizeof(storage) <= 8192, "Unexpectedly large, check size");
687  buf = storage;
688 #endif // !HWY_HAVE_SCALABLE
689 
690  if (detail::HandleSpecialCases(d, st, keys, num, buf)) return;
691 
692 #if HWY_MAX_BYTES > 64
693  // sorting_networks-inl and traits assume no more than 512 bit vectors.
694  if (Lanes(d) > 64 / sizeof(T)) {
695  return Sort(CappedTag<T, 64 / sizeof(T)>(), st, keys, num, buf);
696  }
697 #endif // HWY_MAX_BYTES > 64
698 
699  // Pulled out of the recursion so we can special-case degenerate partitions.
700  detail::Generator rng(keys, num);
701  const Vec<D> pivot = detail::ChoosePivot(d, st, keys, 0, num, buf, rng);
702 
703  // Introspection: switch to worst-case N*logN heapsort after this many.
704  const size_t max_levels = 2 * hwy::CeilLog2(num) + 4;
705 
706  detail::Recurse(d, st, keys, keys + num, 0, num, pivot, buf, rng, max_levels);
707 #else
708  (void)d;
709  (void)buf;
710  // PERFORMANCE WARNING: vqsort is not enabled for the non-SIMD target
711  return detail::HeapSort(st, keys, num);
712 #endif // VQSORT_ENABLED
713 }
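Applications normally reach this function through the dispatching wrapper declared in hwy/contrib/sort/vqsort.h rather than by calling it directly. A minimal usage sketch, assuming the hwy::Sorter / hwy::SortAscending API that vqsort.h exposes in this version of Highway:

  #include <cstdint>
  #include <vector>

  #include "hwy/contrib/sort/vqsort.h"

  int main() {
    std::vector<uint64_t> keys = {5, 1, 4, 2, 3};
    hwy::Sorter sorter;  // reusable across calls; owns the scratch memory that reaches Sort() as `buf`
    sorter(keys.data(), keys.size(), hwy::SortAscending());
    return 0;
  }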
714 
715 // NOLINTNEXTLINE(google-readability-namespace-comments)
716 } // namespace HWY_NAMESPACE
717 } // namespace hwy
718 HWY_AFTER_NAMESPACE();
719 
720 #endif // HIGHWAY_HWY_CONTRIB_SORT_VQSORT_TOGGLE