Grok  10.0.3
traits128-inl.h
Go to the documentation of this file.
1 // Copyright 2021 Google LLC
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 // Per-target
17 #if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE) == \
18  defined(HWY_TARGET_TOGGLE)
19 #ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
20 #undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
21 #else
22 #define HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
23 #endif
24 
25 #include <string>
26 
28 #include "hwy/contrib/sort/vqsort.h" // SortDescending
29 #include "hwy/highway.h"
30 
32 namespace hwy {
33 namespace HWY_NAMESPACE {
34 namespace detail {
35 
36 #if VQSORT_ENABLED || HWY_IDE
37 
38 // Highway does not provide a lane type for 128-bit keys, so we use uint64_t
39 // along with an abstraction layer for single-lane vs. lane-pair, which is
40 // independent of the order.
41 struct KeyAny128 {
42  constexpr bool Is128() const { return true; }
43  constexpr size_t LanesPerKey() const { return 2; }
44 
45  // What type bench_sort should allocate for generating inputs.
46  using LaneType = uint64_t;
47  // KeyType and KeyString are defined by derived classes.
48 
49  HWY_INLINE void Swap(LaneType* a, LaneType* b) const {
50  const FixedTag<LaneType, 2> d;
51  const auto temp = LoadU(d, a);
52  StoreU(LoadU(d, b), d, a);
53  StoreU(temp, d, b);
54  }
55 
56  template <class V, class M>
57  HWY_INLINE V CompressKeys(V keys, M mask) const {
58  return CompressBlocksNot(keys, mask);
59  }
60 
61  template <class D>
62  HWY_INLINE Vec<D> SetKey(D d, const TFromD<D>* key) const {
63  return LoadDup128(d, key);
64  }
65 
66  template <class D>
67  HWY_INLINE Vec<D> ReverseKeys(D d, Vec<D> v) const {
68  return ReverseBlocks(d, v);
69  }
70 
71  template <class D>
72  HWY_INLINE Vec<D> ReverseKeys2(D /* tag */, const Vec<D> v) const {
73  return SwapAdjacentBlocks(v);
74  }
75 
76  // Only called for 4 keys because we do not support >512-bit vectors.
77  template <class D>
78  HWY_INLINE Vec<D> ReverseKeys4(D d, const Vec<D> v) const {
79  HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD<D>));
80  return ReverseKeys(d, v);
81  }
82 
83  // Only called for 4 keys because we do not support >512-bit vectors.
84  template <class D>
85  HWY_INLINE Vec<D> OddEvenPairs(D d, const Vec<D> odd,
86  const Vec<D> even) const {
87  HWY_DASSERT(Lanes(d) <= 64 / sizeof(TFromD<D>));
88  return ConcatUpperLower(d, odd, even);
89  }
90 
91  template <class V>
92  HWY_INLINE V OddEvenKeys(const V odd, const V even) const {
93  return OddEvenBlocks(odd, even);
94  }
95 
96  template <class D>
97  HWY_INLINE Vec<D> ReverseKeys8(D, Vec<D>) const {
98  HWY_ASSERT(0); // not supported: would require 1024-bit vectors
99  }
100 
101  template <class D>
102  HWY_INLINE Vec<D> ReverseKeys16(D, Vec<D>) const {
103  HWY_ASSERT(0); // not supported: would require 2048-bit vectors
104  }
105 
106  // This is only called for 8/16 col networks (not supported).
107  template <class D>
108  HWY_INLINE Vec<D> SwapAdjacentPairs(D, Vec<D>) const {
109  HWY_ASSERT(0);
110  }
111 
112  // This is only called for 16 col networks (not supported).
113  template <class D>
114  HWY_INLINE Vec<D> SwapAdjacentQuads(D, Vec<D>) const {
115  HWY_ASSERT(0);
116  }
117 
118  // This is only called for 8 col networks (not supported).
119  template <class D>
120  HWY_INLINE Vec<D> OddEvenQuads(D, Vec<D>, Vec<D>) const {
121  HWY_ASSERT(0);
122  }
123 };
124 
125 // Base class shared between OrderAscending128, OrderDescending128.
126 struct Key128 : public KeyAny128 {
127  // What type to pass to Sorter::operator().
128  using KeyType = hwy::uint128_t;
129 
130  std::string KeyString() const { return "U128"; }
131 };
132 
133 // Anything order-related depends on the key traits *and* the order (see
134 // FirstOfLanes). We cannot implement just one Compare function because Lt128
135 // only compiles if the lane type is u64. Thus we need either overloaded
136 // functions with a tag type, class specializations, or separate classes.
137 // We avoid overloaded functions because we want all functions to be callable
138 // from a SortTraits without per-function wrappers. Specializing would work, but
139 // we are anyway going to specialize at a higher level.
140 struct OrderAscending128 : public Key128 {
141  using Order = SortAscending;
142 
143  HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) {
144  return (a[1] == b[1]) ? a[0] < b[0] : a[1] < b[1];
145  }
146 
147  template <class D>
148  HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
149  return Lt128(d, a, b);
150  }
151 
152  // Used by CompareTop
153  template <class V>
154  HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
155  return Lt(a, b);
156  }
157 
158  template <class D>
159  HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
160  return Min128(d, a, b);
161  }
162 
163  template <class D>
164  HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
165  return Max128(d, a, b);
166  }
167 
168  // Same as for regular lanes because 128-bit lanes are u64.
169  template <class D>
170  HWY_INLINE Vec<D> FirstValue(D d) const {
171  return Set(d, hwy::LowestValue<TFromD<D> >());
172  }
173 
174  template <class D>
175  HWY_INLINE Vec<D> LastValue(D d) const {
176  return Set(d, hwy::HighestValue<TFromD<D> >());
177  }
178 };
179 
180 struct OrderDescending128 : public Key128 {
181  using Order = SortDescending;
182 
183  HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) {
184  return (a[1] == b[1]) ? b[0] < a[0] : b[1] < a[1];
185  }
186 
187  template <class D>
188  HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
189  return Lt128(d, b, a);
190  }
191 
192  // Used by CompareTop
193  template <class V>
194  HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
195  return Lt(b, a);
196  }
197 
198  template <class D>
199  HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
200  return Max128(d, a, b);
201  }
202 
203  template <class D>
204  HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
205  return Min128(d, a, b);
206  }
207 
208  // Same as for regular lanes because 128-bit lanes are u64.
209  template <class D>
210  HWY_INLINE Vec<D> FirstValue(D d) const {
211  return Set(d, hwy::HighestValue<TFromD<D> >());
212  }
213 
214  template <class D>
215  HWY_INLINE Vec<D> LastValue(D d) const {
216  return Set(d, hwy::LowestValue<TFromD<D> >());
217  }
218 };
219 
220 // Base class shared between OrderAscendingKV128, OrderDescendingKV128.
221 struct KeyValue128 : public KeyAny128 {
222  // What type to pass to Sorter::operator().
223  using KeyType = K64V64;
224 
225  std::string KeyString() const { return "KV128"; }
226 };
227 
228 struct OrderAscendingKV128 : public KeyValue128 {
229  using Order = SortAscending;
230 
231  HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) {
232  return a[1] < b[1];
233  }
234 
235  template <class D>
236  HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
237  return Lt128Upper(d, a, b);
238  }
239 
240  // Used by CompareTop
241  template <class V>
242  HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
243  return Lt(a, b);
244  }
245 
246  template <class D>
247  HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
248  return Min128Upper(d, a, b);
249  }
250 
251  template <class D>
252  HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
253  return Max128Upper(d, a, b);
254  }
255 
256  // Same as for regular lanes because 128-bit lanes are u64.
257  template <class D>
258  HWY_INLINE Vec<D> FirstValue(D d) const {
259  return Set(d, hwy::LowestValue<TFromD<D> >());
260  }
261 
262  template <class D>
263  HWY_INLINE Vec<D> LastValue(D d) const {
264  return Set(d, hwy::HighestValue<TFromD<D> >());
265  }
266 };
267 
268 struct OrderDescendingKV128 : public KeyValue128 {
269  using Order = SortDescending;
270 
271  HWY_INLINE bool Compare1(const LaneType* a, const LaneType* b) {
272  return b[1] < a[1];
273  }
274 
275  template <class D>
276  HWY_INLINE Mask<D> Compare(D d, Vec<D> a, Vec<D> b) const {
277  return Lt128Upper(d, b, a);
278  }
279 
280  // Used by CompareTop
281  template <class V>
282  HWY_INLINE Mask<DFromV<V> > CompareLanes(V a, V b) const {
283  return Lt(b, a);
284  }
285 
286  template <class D>
287  HWY_INLINE Vec<D> First(D d, const Vec<D> a, const Vec<D> b) const {
288  return Max128Upper(d, a, b);
289  }
290 
291  template <class D>
292  HWY_INLINE Vec<D> Last(D d, const Vec<D> a, const Vec<D> b) const {
293  return Min128Upper(d, a, b);
294  }
295 
296  // Same as for regular lanes because 128-bit lanes are u64.
297  template <class D>
298  HWY_INLINE Vec<D> FirstValue(D d) const {
299  return Set(d, hwy::HighestValue<TFromD<D> >());
300  }
301 
302  template <class D>
303  HWY_INLINE Vec<D> LastValue(D d) const {
304  return Set(d, hwy::LowestValue<TFromD<D> >());
305  }
306 };
307 
308 // Shared code that depends on Order.
309 template <class Base>
310 class Traits128 : public Base {
311  // Special case for >= 256 bit vectors
312 #if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SVE_256
313  // Returns vector with only the top u64 lane valid. Useful when the next step
314  // is to replicate the mask anyway.
315  template <class D>
316  HWY_INLINE HWY_MAYBE_UNUSED Vec<D> CompareTop(D d, Vec<D> a, Vec<D> b) const {
317  const Base* base = static_cast<const Base*>(this);
318  const Mask<D> eqHL = Eq(a, b);
319  const Vec<D> ltHL = VecFromMask(d, base->CompareLanes(a, b));
320 #if HWY_TARGET == HWY_SVE_256
321  return IfThenElse(eqHL, DupEven(ltHL), ltHL);
322 #else
323  const Vec<D> ltLX = ShiftLeftLanes<1>(ltHL);
324  return OrAnd(ltHL, VecFromMask(d, eqHL), ltLX);
325 #endif
326  }
327 
328  // We want to swap 2 u128, i.e. 4 u64 lanes, based on the 0 or FF..FF mask in
329  // the most-significant of those lanes (the result of CompareTop), so
330  // replicate it 4x. Only called for >= 256-bit vectors.
331  template <class V>
332  HWY_INLINE V ReplicateTop4x(V v) const {
333 #if HWY_TARGET == HWY_SVE_256
334  return svdup_lane_u64(v, 3);
335 #elif HWY_TARGET <= HWY_AVX3
336  return V{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))};
337 #else // AVX2
338  return V{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(3, 3, 3, 3))};
339 #endif
340  }
341 #endif // HWY_TARGET
342 
343  public:
344  template <class D>
345  HWY_INLINE Vec<D> FirstOfLanes(D d, Vec<D> v,
346  TFromD<D>* HWY_RESTRICT buf) const {
347  const Base* base = static_cast<const Base*>(this);
348  const size_t N = Lanes(d);
349  Store(v, d, buf);
350  v = base->SetKey(d, buf + 0); // result must be broadcasted
351  for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) {
352  v = base->First(d, v, base->SetKey(d, buf + i));
353  }
354  return v;
355  }
356 
357  template <class D>
358  HWY_INLINE Vec<D> LastOfLanes(D d, Vec<D> v,
359  TFromD<D>* HWY_RESTRICT buf) const {
360  const Base* base = static_cast<const Base*>(this);
361  const size_t N = Lanes(d);
362  Store(v, d, buf);
363  v = base->SetKey(d, buf + 0); // result must be broadcasted
364  for (size_t i = base->LanesPerKey(); i < N; i += base->LanesPerKey()) {
365  v = base->Last(d, v, base->SetKey(d, buf + i));
366  }
367  return v;
368  }
369 
370  template <class D>
371  HWY_INLINE void Sort2(D d, Vec<D>& a, Vec<D>& b) const {
372  const Base* base = static_cast<const Base*>(this);
373 
374  const Vec<D> a_copy = a;
375  const auto lt = base->Compare(d, a, b);
376  a = IfThenElse(lt, a, b);
377  b = IfThenElse(lt, b, a_copy);
378  }
379 
380  // Conditionally swaps even-numbered lanes with their odd-numbered neighbor.
381  template <class D>
382  HWY_INLINE Vec<D> SortPairsDistance1(D d, Vec<D> v) const {
383  const Base* base = static_cast<const Base*>(this);
384  Vec<D> swapped = base->ReverseKeys2(d, v);
385 
386 #if HWY_TARGET <= HWY_AVX2 || HWY_TARGET == HWY_SVE_256
387  const Vec<D> select = ReplicateTop4x(CompareTop(d, v, swapped));
388  return IfVecThenElse(select, swapped, v);
389 #else
390  Sort2(d, v, swapped);
391  return base->OddEvenKeys(swapped, v);
392 #endif
393  }
394 
395  // Swaps with the vector formed by reversing contiguous groups of 4 keys.
396  template <class D>
397  HWY_INLINE Vec<D> SortPairsReverse4(D d, Vec<D> v) const {
398  const Base* base = static_cast<const Base*>(this);
399  Vec<D> swapped = base->ReverseKeys4(d, v);
400 
401  // Only specialize for AVX3 because this requires 512-bit vectors.
402 #if HWY_TARGET <= HWY_AVX3
403  const Vec512<uint64_t> outHx = CompareTop(d, v, swapped);
404  // Similar to ReplicateTop4x, we want to gang together 2 comparison results
405  // (4 lanes). They are not contiguous, so use permute to replicate 4x.
406  alignas(64) uint64_t kIndices[8] = {7, 7, 5, 5, 5, 5, 7, 7};
407  const Vec512<uint64_t> select =
408  TableLookupLanes(outHx, SetTableIndices(d, kIndices));
409  return IfVecThenElse(select, swapped, v);
410 #else
411  Sort2(d, v, swapped);
412  return base->OddEvenPairs(d, swapped, v);
413 #endif
414  }
415 
416  // Conditionally swaps lane 0 with 4, 1 with 5 etc.
417  template <class D>
418  HWY_INLINE Vec<D> SortPairsDistance4(D, Vec<D>) const {
419  // Only used by Merge16, which would require 2048 bit vectors (unsupported).
420  HWY_ASSERT(0);
421  }
422 };
423 
424 #endif // VQSORT_ENABLED
425 
426 } // namespace detail
427 // NOLINTNEXTLINE(google-readability-namespace-comments)
428 } // namespace HWY_NAMESPACE
429 } // namespace hwy
431 
432 #endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS128_TOGGLE
#define HWY_RESTRICT
Definition: base.h:61
#define HWY_INLINE
Definition: base.h:62
#define HWY_DASSERT(condition)
Definition: base.h:191
#define HWY_MAYBE_UNUSED
Definition: base.h:73
#define HWY_ASSERT(condition)
Definition: base.h:145
HWY_INLINE Vec128< T, N > IfThenElse(hwy::SizeTag< 1 >, Mask128< T, N > mask, Vec128< T, N > yes, Vec128< T, N > no)
Definition: x86_128-inl.h:673
d
Definition: rvv-inl.h:1742
HWY_API Vec128< T, N > OddEvenBlocks(Vec128< T, N >, Vec128< T, N > even)
Definition: arm_neon-inl.h:4533
HWY_API auto Lt(V a, V b) -> decltype(a==b)
Definition: arm_neon-inl.h:6309
HWY_API auto Eq(V a, V b) -> decltype(a==b)
Definition: arm_neon-inl.h:6301
constexpr HWY_API size_t Lanes(Simd< T, N, kPow2 >)
Definition: arm_sve-inl.h:236
HWY_API Vec128< uint64_t > CompressBlocksNot(Vec128< uint64_t > v, Mask128< uint64_t >)
Definition: arm_neon-inl.h:5815
HWY_API Vec128< T, N > IfVecThenElse(Vec128< T, N > mask, Vec128< T, N > yes, Vec128< T, N > no)
Definition: arm_neon-inl.h:2006
HWY_API Vec128< T, N > VecFromMask(Simd< T, N, 0 > d, const Mask128< T, N > v)
Definition: arm_neon-inl.h:2182
HWY_API Vec128< T, N > DupEven(Vec128< T, N > v)
Definition: arm_neon-inl.h:4482
HWY_API Vec128< T, N > TableLookupLanes(Vec128< T, N > v, Indices128< T, N > idx)
Definition: arm_neon-inl.h:3934
HWY_API void StoreU(const Vec128< uint8_t > v, Full128< uint8_t >, uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2725
HWY_INLINE VFromD< D > Min128Upper(D d, const VFromD< D > a, const VFromD< D > b)
Definition: arm_neon-inl.h:6260
HWY_API Vec128< T, N > SwapAdjacentBlocks(Vec128< T, N > v)
Definition: arm_neon-inl.h:4540
HWY_INLINE VFromD< D > Min128(D d, const VFromD< D > a, const VFromD< D > b)
Definition: arm_neon-inl.h:6250
svuint16_t Set(Simd< bfloat16_t, N, kPow2 > d, bfloat16_t arg)
Definition: arm_sve-inl.h:312
HWY_API Vec128< uint8_t > LoadU(Full128< uint8_t >, const uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2544
HWY_INLINE VFromD< D > Max128Upper(D d, const VFromD< D > a, const VFromD< D > b)
Definition: arm_neon-inl.h:6265
HWY_INLINE Mask128< T, N > Lt128(Simd< T, N, 0 > d, Vec128< T, N > a, Vec128< T, N > b)
Definition: arm_neon-inl.h:6212
decltype(GetLane(V())) LaneType
Definition: generic_ops-inl.h:25
HWY_API Vec128< T, N > OrAnd(Vec128< T, N > o, Vec128< T, N > a1, Vec128< T, N > a2)
Definition: arm_neon-inl.h:1999
HWY_API Vec128< T, N > ConcatUpperLower(Simd< T, N, 0 > d, Vec128< T, N > hi, Vec128< T, N > lo)
Definition: arm_neon-inl.h:4406
HWY_INLINE VFromD< D > Max128(D d, const VFromD< D > a, const VFromD< D > b)
Definition: arm_neon-inl.h:6255
HWY_API Indices128< T, N > SetTableIndices(Simd< T, N, 0 > d, const TI *idx)
Definition: arm_neon-inl.h:3928
HWY_API Vec128< T, N > LoadDup128(Simd< T, N, 0 > d, const T *const HWY_RESTRICT p)
Definition: arm_neon-inl.h:2718
N
Definition: rvv-inl.h:1742
HWY_API Vec128< T > ReverseBlocks(Full128< T >, const Vec128< T > v)
Definition: arm_neon-inl.h:4548
HWY_API void Store(Vec128< T, N > v, Simd< T, N, 0 > d, T *HWY_RESTRICT aligned)
Definition: arm_neon-inl.h:2882
HWY_INLINE Mask128< T, N > Lt128Upper(Simd< T, N, 0 > d, Vec128< T, N > a, Vec128< T, N > b)
Definition: arm_neon-inl.h:6240
const vfloat64m1_t v
Definition: rvv-inl.h:1742
Definition: aligned_allocator.h:27
constexpr HWY_API T LowestValue()
Definition: base.h:563
constexpr HWY_API T HighestValue()
Definition: base.h:576
#define HWY_NAMESPACE
Definition: set_macros-inl.h:82
Definition: base.h:264
HWY_AFTER_NAMESPACE()
HWY_BEFORE_NAMESPACE()