algo-inl.h
1 // Copyright 2021 Google LLC
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 // Normal include guard for target-independent parts
17 #ifndef HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
18 #define HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
19 
20 #include <stdint.h>
21 #include <string.h> // memcpy
22 
23 #include <algorithm>
24 #include <cmath> // std::abs
25 #include <vector>
26 
27 #include "hwy/base.h"
28 #include "hwy/contrib/sort/vqsort.h"
29 
30 // Third-party algorithms
31 #define HAVE_AVX2SORT 0
32 #define HAVE_IPS4O 0
33 // When enabling, consider changing max_threads (required for Table 1a)
34 #define HAVE_PARALLEL_IPS4O (HAVE_IPS4O && 1)
35 #define HAVE_PDQSORT 0
36 #define HAVE_SORT512 0
37 #define HAVE_VXSORT 0
38 
39 #if HAVE_AVX2SORT
40 HWY_PUSH_ATTRIBUTES("avx2,avx")
41 #include "avx2sort.h"
42 HWY_POP_ATTRIBUTES
43 #endif
44 #if HAVE_IPS4O || HAVE_PARALLEL_IPS4O
45 #include "third_party/ips4o/include/ips4o.hpp"
46 #include "third_party/ips4o/include/ips4o/thread_pool.hpp"
47 #endif
48 #if HAVE_PDQSORT
49 #include "third_party/boost/allowed/sort/sort.hpp"
50 #endif
51 #if HAVE_SORT512
52 #include "sort512.h"
53 #endif
54 
55 // vxsort is difficult to compile for multiple targets because it also uses
56 // .cpp files, and we'd also have to #undef its include guards. Instead, compile
57 // only for AVX2 or AVX3 depending on this macro.
58 #define VXSORT_AVX3 1
59 #if HAVE_VXSORT
60 // inlined from vxsort_targets_enable_avx512 (must close before end of header)
61 #ifdef __GNUC__
62 #ifdef __clang__
63 #if VXSORT_AVX3
64 #pragma clang attribute push(__attribute__((target("avx512f,avx512dq"))), \
65  apply_to = any(function))
66 #else
67 #pragma clang attribute push(__attribute__((target("avx2"))), \
68  apply_to = any(function))
69 #endif // VXSORT_AVX3
70 
71 #else
72 #pragma GCC push_options
73 #if VXSORT_AVX3
74 #pragma GCC target("avx512f,avx512dq")
75 #else
76 #pragma GCC target("avx2")
77 #endif // VXSORT_AVX3
78 #endif
79 #endif
80 
81 #if VXSORT_AVX3
82 #include "vxsort/machine_traits.avx512.h"
83 #else
84 #include "vxsort/machine_traits.avx2.h"
85 #endif // VXSORT_AVX3
86 #include "vxsort/vxsort.h"
87 #ifdef __GNUC__
88 #ifdef __clang__
89 #pragma clang attribute pop
90 #else
91 #pragma GCC pop_options
92 #endif
93 #endif
94 #endif // HAVE_VXSORT
95 
96 namespace hwy {
97 
98 enum class Dist { kUniform8, kUniform16, kUniform32 };
99 
100 static inline std::vector<Dist> AllDist() {
101  return {/*Dist::kUniform8, Dist::kUniform16,*/ Dist::kUniform32};
102 }
103 
104 static inline const char* DistName(Dist dist) {
105  switch (dist) {
106  case Dist::kUniform8:
107  return "uniform8";
108  case Dist::kUniform16:
109  return "uniform16";
110  case Dist::kUniform32:
111  return "uniform32";
112  }
113  return "unreachable";
114 }
115 
116 template <typename T>
117 class InputStats {
118  public:
119  void Notify(T value) {
120  min_ = std::min(min_, value);
121  max_ = std::max(max_, value);
122  // Converting to integer would truncate floats, multiplying to save digits
123  // risks overflow especially when casting, so instead take the sum of the
124  // bit representations as the checksum.
125  uint64_t bits = 0;
126  static_assert(sizeof(T) <= 8, "Expected a built-in type");
127  CopyBytes<sizeof(T)>(&value, &bits);
128  sum_ += bits;
129  count_ += 1;
130  }
131 
132  bool operator==(const InputStats& other) const {
133  if (count_ != other.count_) {
134  HWY_ABORT("count %d vs %d\n", static_cast<int>(count_),
135  static_cast<int>(other.count_));
136  }
137 
138  if (min_ != other.min_ || max_ != other.max_) {
139  HWY_ABORT("minmax %f/%f vs %f/%f\n", static_cast<double>(min_),
140  static_cast<double>(max_), static_cast<double>(other.min_),
141  static_cast<double>(other.max_));
142  }
143 
144  // Sum helps detect duplicated/lost values
145  if (sum_ != other.sum_) {
146  HWY_ABORT("Sum mismatch %g %g; min %g max %g\n",
147  static_cast<double>(sum_), static_cast<double>(other.sum_),
148  static_cast<double>(min_), static_cast<double>(max_));
149  }
150 
151  return true;
152  }
153 
154  private:
155  T min_ = hwy::HighestValue<T>();
156  T max_ = hwy::LowestValue<T>();
157  uint64_t sum_ = 0;
158  size_t count_ = 0;
159 };
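A minimal standalone sketch (not part of this header; the function name is illustrative) of the checksum idea used by Notify(): the value's bytes are reinterpreted as an integer and summed, so floating-point inputs are neither truncated nor exposed to overflow from multiplication.

  #include <cstdint>
  #include <cstring>

  uint64_t BitChecksum(float value, uint64_t sum) {
    uint64_t bits = 0;
    std::memcpy(&bits, &value, sizeof(value));  // same idea as CopyBytes<sizeof(T)>
    return sum + bits;  // order-independent accumulation of bit patterns
  }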
160 
161 enum class Algo {
162 #if HAVE_AVX2SORT
163  kSEA,
164 #endif
165 #if HAVE_IPS4O
166  kIPS4O,
167 #endif
168 #if HAVE_PARALLEL_IPS4O
169  kParallelIPS4O,
170 #endif
171 #if HAVE_PDQSORT
172  kPDQ,
173 #endif
174 #if HAVE_SORT512
175  kSort512,
176 #endif
177 #if HAVE_VXSORT
178  kVXSort,
179 #endif
180  kStd,
181  kVQSort,
182  kHeap,
183 };
184 
185 const char* AlgoName(Algo algo) {
186  switch (algo) {
187 #if HAVE_AVX2SORT
188  case Algo::kSEA:
189  return "sea";
190 #endif
191 #if HAVE_IPS4O
192  case Algo::kIPS4O:
193  return "ips4o";
194 #endif
195 #if HAVE_PARALLEL_IPS4O
196  case Algo::kParallelIPS4O:
197  return "par_ips4o";
198 #endif
199 #if HAVE_PDQSORT
200  case Algo::kPDQ:
201  return "pdq";
202 #endif
203 #if HAVE_SORT512
204  case Algo::kSort512:
205  return "sort512";
206 #endif
207 #if HAVE_VXSORT
208  case Algo::kVXSort:
209  return "vxsort";
210 #endif
211  case Algo::kStd:
212  return "std";
213  case Algo::kVQSort:
214  return "vq";
215  case Algo::kHeap:
216  return "heap";
217  }
218  return "unreachable";
219 }
220 
221 } // namespace hwy
222 #endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_INL_H_
223 
224 // Per-target
225 #if defined(HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE) == \
226  defined(HWY_TARGET_TOGGLE)
227 #ifdef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
228 #undef HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
229 #else
230 #define HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE
231 #endif
232 
233 #include "hwy/contrib/sort/traits-inl.h"
234 #include "hwy/contrib/sort/traits128-inl.h"
235 #include "hwy/contrib/sort/vqsort-inl.h" // HeapSort
236 #include "hwy/tests/test_util-inl.h"
237 
238 HWY_BEFORE_NAMESPACE();
239 namespace hwy {
240 namespace HWY_NAMESPACE {
241 
242 class Xorshift128Plus {
243  static HWY_INLINE uint64_t SplitMix64(uint64_t z) {
244  z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ull;
245  z = (z ^ (z >> 27)) * 0x94D049BB133111EBull;
246  return z ^ (z >> 31);
247  }
248 
249  public:
250  // Generates two vectors of 64-bit seeds via SplitMix64 and stores into
251  // `seeds`. Generating these afresh in each ChoosePivot is too expensive.
252  template <class DU64>
253  static void GenerateSeeds(DU64 du64, TFromD<DU64>* HWY_RESTRICT seeds) {
254  seeds[0] = SplitMix64(0x9E3779B97F4A7C15ull);
255  for (size_t i = 1; i < 2 * Lanes(du64); ++i) {
256  seeds[i] = SplitMix64(seeds[i - 1]);
257  }
258  }
259 
260  // Need to pass in the state because vectors cannot be class members.
261  template <class DU64>
262  static Vec<DU64> RandomBits(DU64 /* tag */, Vec<DU64>& state0,
263  Vec<DU64>& state1) {
264  Vec<DU64> s1 = state0;
265  Vec<DU64> s0 = state1;
266  const Vec<DU64> bits = Add(s1, s0);
267  state0 = s0;
268  s1 = Xor(s1, ShiftLeft<23>(s1));
269  state1 = Xor(s1, Xor(s0, Xor(ShiftRight<18>(s1), ShiftRight<5>(s0))));
270  return bits;
271  }
272 };
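For reference, a scalar sketch (illustrative only, not part of this header) of the same xorshift128+ step that RandomBits applies independently to each 64-bit lane:

  #include <stdint.h>

  uint64_t Xorshift128PlusStep(uint64_t& state0, uint64_t& state1) {
    uint64_t s1 = state0;
    const uint64_t s0 = state1;
    const uint64_t bits = s1 + s0;  // returned random bits
    state0 = s0;
    s1 ^= s1 << 23;
    state1 = s1 ^ s0 ^ (s1 >> 18) ^ (s0 >> 5);
    return bits;
  }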
273 
274 template <typename T, class DU64, HWY_IF_NOT_FLOAT(T)>
275 Vec<DU64> RandomValues(DU64 du64, Vec<DU64>& s0, Vec<DU64>& s1,
276  const Vec<DU64> mask) {
277  const Vec<DU64> bits = Xorshift128Plus::RandomBits(du64, s0, s1);
278  return And(bits, mask);
279 }
280 
281 // Important to avoid denormals, which are flushed to zero by SIMD but not
282 // scalar sorts, and NaN, which may be ordered differently in scalar vs. SIMD.
283 template <typename T, class DU64, HWY_IF_FLOAT(T)>
284 Vec<DU64> RandomValues(DU64 du64, Vec<DU64>& s0, Vec<DU64>& s1,
285  const Vec<DU64> mask) {
286  const Vec<DU64> bits = Xorshift128Plus::RandomBits(du64, s0, s1);
287  const Vec<DU64> values = And(bits, mask);
288 #if HWY_TARGET == HWY_SCALAR // Cannot repartition u64 to i32
289  const RebindToSigned<DU64> di;
290 #else
291  const Repartition<MakeSigned<T>, DU64> di;
292 #endif
293  const RebindToFloat<decltype(di)> df;
294  const RebindToUnsigned<decltype(di)> du;
295  const auto k1 = BitCast(du64, Set(df, T{1.0}));
296  const auto mantissa = BitCast(du64, Set(du, MantissaMask<T>()));
297  // Avoid NaN/denormal by converting from (range-limited) integer.
298  const Vec<DU64> no_nan = OrAnd(k1, values, mantissa);
299  return BitCast(du64, ConvertTo(df, BitCast(di, no_nan)));
300 }
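A scalar sketch of the float path above (assumes float keys; the function name is illustrative): keep only random mantissa bits, OR in the bit pattern of 1.0f so the integer is range-limited, then convert that integer to float, which can yield neither NaN nor a denormal.

  #include <cstdint>
  #include <cstring>

  float RandomFloatSketch(uint32_t random_bits) {
    uint32_t one_bits;
    const float one = 1.0f;
    std::memcpy(&one_bits, &one, sizeof(one));      // 0x3F800000, like k1
    const uint32_t mantissa_mask = 0x007FFFFFu;     // like MantissaMask<float>()
    const int32_t no_nan =
        static_cast<int32_t>(one_bits | (random_bits & mantissa_mask));
    return static_cast<float>(no_nan);              // int -> float conversion
  }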
301 
302 template <class DU64>
303 Vec<DU64> MaskForDist(DU64 du64, const Dist dist, size_t sizeof_t) {
304  switch (sizeof_t) {
305  case 2:
306  return Set(du64, (dist == Dist::kUniform8) ? 0x00FF00FF00FF00FFull
307  : 0xFFFFFFFFFFFFFFFFull);
308  case 4:
309  return Set(du64, (dist == Dist::kUniform8) ? 0x000000FF000000FFull
310  : (dist == Dist::kUniform16) ? 0x0000FFFF0000FFFFull
311  : 0xFFFFFFFFFFFFFFFFull);
312  case 8:
313  return Set(du64, (dist == Dist::kUniform8) ? 0x00000000000000FFull
314  : (dist == Dist::kUniform16) ? 0x000000000000FFFFull
315  : 0x00000000FFFFFFFFull);
316  default:
317  HWY_ABORT("Logic error");
318  return Zero(du64);
319  }
320 }
321 
322 template <typename T>
323 InputStats<T> GenerateInput(const Dist dist, T* v, size_t num) {
324  SortTag<uint64_t> du64;
325  using VU64 = Vec<decltype(du64)>;
326  const size_t N64 = Lanes(du64);
327  auto buf = hwy::AllocateAligned<uint64_t>(2 * N64);
328  Xorshift128Plus::GenerateSeeds(du64, buf.get());
329  auto s0 = Load(du64, buf.get());
330  auto s1 = Load(du64, buf.get() + N64);
331 
332  const VU64 mask = MaskForDist(du64, dist, sizeof(T));
333 
334  const Repartition<T, decltype(du64)> d;
335  const size_t N = Lanes(d);
336  size_t i = 0;
337  for (; i + N <= num; i += N) {
338  const VU64 bits = RandomValues<T>(du64, s0, s1, mask);
339 #if HWY_ARCH_RVV || (HWY_TARGET == HWY_NEON && HWY_ARCH_ARM_V7)
340  // v may not be 64-bit aligned
341  StoreU(bits, du64, buf.get());
342  memcpy(v + i, buf.get(), N64 * sizeof(uint64_t));
343 #else
344  StoreU(bits, du64, reinterpret_cast<uint64_t*>(v + i));
345 #endif
346  }
347  if (i < num) {
348  const VU64 bits = RandomValues<T>(du64, s0, s1, mask);
349  StoreU(bits, du64, buf.get());
350  memcpy(v + i, buf.get(), (num - i) * sizeof(T));
351  }
352 
353  InputStats<T> input_stats;
354  for (size_t i = 0; i < num; ++i) {
355  input_stats.Notify(v[i]);
356  }
357  return input_stats;
358 }
359 
360 struct ThreadLocal {
361  Sorter sorter;
362 };
363 
364 struct SharedState {
365 #if HAVE_PARALLEL_IPS4O
366  const unsigned max_threads = hwy::LimitsMax<unsigned>(); // 16 for Table 1a
367  ips4o::StdThreadPool pool{static_cast<int>(
368  HWY_MIN(max_threads, std::thread::hardware_concurrency() / 2))};
369 #endif
370  std::vector<ThreadLocal> tls{1};
371 };
372 
373 // Bridge from keys (passed to Run) to lanes as expected by HeapSort. For
374 // non-128-bit keys they are the same:
375 template <class Order, typename KeyType, HWY_IF_NOT_LANE_SIZE(KeyType, 16)>
376 void CallHeapSort(KeyType* HWY_RESTRICT keys, const size_t num_keys) {
377  using detail::TraitsLane;
378  using detail::SharedTraits;
379  if (Order().IsAscending()) {
380  const SharedTraits<TraitsLane<detail::OrderAscending<KeyType>>> st;
381  return detail::HeapSort(st, keys, num_keys);
382  } else {
383  const SharedTraits<TraitsLane<detail::OrderDescending<KeyType>>> st;
384  return detail::HeapSort(st, keys, num_keys);
385  }
386 }
387 
388 #if VQSORT_ENABLED
389 template <class Order>
390 void CallHeapSort(hwy::uint128_t* HWY_RESTRICT keys, const size_t num_keys) {
391  using detail::SharedTraits;
392  using detail::Traits128;
393  uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
394  const size_t num_lanes = num_keys * 2;
395  if (Order().IsAscending()) {
396  const SharedTraits<Traits128<detail::OrderAscending128>> st;
397  return detail::HeapSort(st, lanes, num_lanes);
398  } else {
399  const SharedTraits<Traits128<detail::OrderDescending128>> st;
400  return detail::HeapSort(st, lanes, num_lanes);
401  }
402 }
403 
404 template <class Order>
405 void CallHeapSort(K64V64* HWY_RESTRICT keys, const size_t num_keys) {
406  using detail::SharedTraits;
407  using detail::Traits128;
408  uint64_t* lanes = reinterpret_cast<uint64_t*>(keys);
409  const size_t num_lanes = num_keys * 2;
410  if (Order().IsAscending()) {
411  const SharedTraits<Traits128<detail::OrderAscendingKV128>> st;
412  return detail::HeapSort(st, lanes, num_lanes);
413  } else {
414  const SharedTraits<Traits128<detail::OrderDescendingKV128>> st;
415  return detail::HeapSort(st, lanes, num_lanes);
416  }
417 }
418 #endif // VQSORT_ENABLED
419 
420 template <class Order, typename KeyType>
421 void Run(Algo algo, KeyType* HWY_RESTRICT inout, size_t num,
422  SharedState& shared, size_t thread) {
423  const std::less<KeyType> less;
424  const std::greater<KeyType> greater;
425 
426  switch (algo) {
427 #if HAVE_AVX2SORT
428  case Algo::kSEA:
429  return avx2::quicksort(inout, static_cast<int>(num));
430 #endif
431 
432 #if HAVE_IPS4O
433  case Algo::kIPS4O:
434  if (Order().IsAscending()) {
435  return ips4o::sort(inout, inout + num, less);
436  } else {
437  return ips4o::sort(inout, inout + num, greater);
438  }
439 #endif
440 
441 #if HAVE_PARALLEL_IPS4O
442  case Algo::kParallelIPS4O:
443  if (Order().IsAscending()) {
444  return ips4o::parallel::sort(inout, inout + num, less, shared.pool);
445  } else {
446  return ips4o::parallel::sort(inout, inout + num, greater, shared.pool);
447  }
448 #endif
449 
450 #if HAVE_SORT512
451  case Algo::kSort512:
452  HWY_ABORT("not supported");
453  // return Sort512::Sort(inout, num);
454 #endif
455 
456 #if HAVE_PDQSORT
457  case Algo::kPDQ:
458  if (Order().IsAscending()) {
459  return boost::sort::pdqsort_branchless(inout, inout + num, less);
460  } else {
461  return boost::sort::pdqsort_branchless(inout, inout + num, greater);
462  }
463 #endif
464 
465 #if HAVE_VXSORT
466  case Algo::kVXSort: {
467 #if (VXSORT_AVX3 && HWY_TARGET != HWY_AVX3) || \
468  (!VXSORT_AVX3 && HWY_TARGET != HWY_AVX2)
469  fprintf(stderr, "Do not call for target %s\n",
470  hwy::TargetName(HWY_TARGET));
471  return;
472 #else
473 #if VXSORT_AVX3
474  vxsort::vxsort<KeyType, vxsort::AVX512> vx;
475 #else
476  vxsort::vxsort<KeyType, vxsort::AVX2> vx;
477 #endif
478  if (Order().IsAscending()) {
479  return vx.sort(inout, inout + num - 1);
480  } else {
481  fprintf(stderr, "Skipping VX - does not support descending order\n");
482  return;
483  }
484 #endif // enabled for this target
485  }
486 #endif // HAVE_VXSORT
487 
488  case Algo::kStd:
489  if (Order().IsAscending()) {
490  return std::sort(inout, inout + num, less);
491  } else {
492  return std::sort(inout, inout + num, greater);
493  }
494 
495  case Algo::kVQSort:
496  return shared.tls[thread].sorter(inout, num, Order());
497 
498  case Algo::kHeap:
499  return CallHeapSort<Order>(inout, num);
500 
501  default:
502  HWY_ABORT("Not implemented");
503  }
504 }
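A hedged usage sketch (not part of this header; the function name and key type are illustrative, and SortAscending is assumed from vqsort.h) showing how callers typically combine GenerateInput, Run and InputStats:

  template <typename KeyType>
  void ExampleGenerateSortVerify(size_t num) {
    auto keys = AllocateAligned<KeyType>(num);
    const InputStats<KeyType> before =
        GenerateInput(Dist::kUniform32, keys.get(), num);

    SharedState shared;
    Run<SortAscending>(Algo::kVQSort, keys.get(), num, shared, /*thread=*/0);

    InputStats<KeyType> after;
    for (size_t i = 0; i < num; ++i) after.Notify(keys[i]);
    (void)(after == before);  // operator== aborts with a message on mismatch
  }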
505 
506 // NOLINTNEXTLINE(google-readability-namespace-comments)
507 } // namespace HWY_NAMESPACE
508 } // namespace hwy
509 HWY_AFTER_NAMESPACE();
510 
511 #endif // HIGHWAY_HWY_CONTRIB_SORT_ALGO_TOGGLE