dot-inl.h
// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>

// Include guard (still compiled once per target)
#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == \
    defined(HWY_TARGET_TOGGLE)
#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#else
#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#endif

#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

struct Dot {
  // Specify zero or more of these, ORed together, as the kAssumptions template
  // argument to Compute. Each one may improve performance or reduce code size,
  // at the cost of additional requirements on the arguments.
  enum Assumptions {
    // num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T).
    kAtLeastOneVector = 1,
    // num_elements is divisible by N (a power of two, so this can be used if
    // the problem size is known to be a power of two >= HWY_MAX_BYTES /
    // sizeof(T)).
    kMultipleOfVector = 2,
    // RoundUpTo(num_elements, N) elements are accessible; their value does not
    // matter (will be treated as if they were zero).
    kPaddedToVector = 4,
  };
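
  // Example usage (illustrative, not from the original file): a caller that
  // guarantees num_elements is a nonzero multiple of the vector size would
  // combine flags as
  //   Dot::Compute<Dot::kAtLeastOneVector | Dot::kMultipleOfVector>(
  //       d, pa, pb, num_elements);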

  // Returns sum{pa[i] * pb[i]} for float or double inputs. Aligning the
  // pointers to a multiple of N elements is helpful but not required.
  template <int kAssumptions, class D, typename T = TFromD<D>,
            HWY_IF_NOT_LANE_SIZE_D(D, 2)>
  static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa,
                              const T* const HWY_RESTRICT pb,
                              const size_t num_elements) {
    static_assert(IsFloat<T>(), "MulAdd requires float type");
    using V = decltype(Zero(d));

    const size_t N = Lanes(d);
    size_t i = 0;

    constexpr bool kIsAtLeastOneVector =
        (kAssumptions & kAtLeastOneVector) != 0;
    constexpr bool kIsMultipleOfVector =
        (kAssumptions & kMultipleOfVector) != 0;
    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;

    // Won't be able to do a full vector load without padding => scalar loop.
    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
        HWY_UNLIKELY(num_elements < N)) {
      // Only 2x unroll to avoid excessive code size.
      T sum0 = T(0);
      T sum1 = T(0);
      for (; i + 2 <= num_elements; i += 2) {
        sum0 += pa[i + 0] * pb[i + 0];
        sum1 += pa[i + 1] * pb[i + 1];
      }
      if (i < num_elements) {
        sum1 += pa[i] * pb[i];
      }
      return sum0 + sum1;
    }

    // Compiler doesn't make independent sum* accumulators, so unroll manually.
    // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
    // for unaligned inputs (each unaligned pointer halves the throughput
    // because it occupies both L1 load ports for a cycle). We cannot have
    // arrays of vectors on RVV/SVE, so always unroll 4x.
    V sum0 = Zero(d);
    V sum1 = Zero(d);
    V sum2 = Zero(d);
    V sum3 = Zero(d);

    // Main loop: unrolled
    for (; i + 4 * N <= num_elements; /* i += 4 * N */) {  // incr in loop
      const auto a0 = LoadU(d, pa + i);
      const auto b0 = LoadU(d, pb + i);
      i += N;
      sum0 = MulAdd(a0, b0, sum0);
      const auto a1 = LoadU(d, pa + i);
      const auto b1 = LoadU(d, pb + i);
      i += N;
      sum1 = MulAdd(a1, b1, sum1);
      const auto a2 = LoadU(d, pa + i);
      const auto b2 = LoadU(d, pb + i);
      i += N;
      sum2 = MulAdd(a2, b2, sum2);
      const auto a3 = LoadU(d, pa + i);
      const auto b3 = LoadU(d, pb + i);
      i += N;
      sum3 = MulAdd(a3, b3, sum3);
    }

    // Up to 3 iterations of whole vectors
    for (; i + N <= num_elements; i += N) {
      const auto a = LoadU(d, pa + i);
      const auto b = LoadU(d, pb + i);
      sum0 = MulAdd(a, b, sum0);
    }

    if (!kIsMultipleOfVector) {
      const size_t remaining = num_elements - i;
      if (remaining != 0) {
        if (kIsPaddedToVector) {
          const auto mask = FirstN(d, remaining);
          const auto a = LoadU(d, pa + i);
          const auto b = LoadU(d, pb + i);
          sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
        } else {
          // Unaligned load such that the last element is in the highest lane -
          // ensures we do not touch any elements outside the valid range.
          // If we get here, then num_elements >= N.
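          // Example: with N = 4 and remaining = 3, i backs up by one lane so
          // the load ends at the last valid element, and `skip` masks off the
          // first lane, which was already accumulated.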
          HWY_DASSERT(i >= N);
          i += remaining - N;
          const auto skip = FirstN(d, N - remaining);
          const auto a = LoadU(d, pa + i);  // always unaligned
          const auto b = LoadU(d, pb + i);
          sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
        }
      }
    }  // kMultipleOfVector

    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = Add(sum0, sum1);
    sum2 = Add(sum2, sum3);
    sum0 = Add(sum0, sum2);
    return GetLane(SumOfLanes(d, sum0));
  }

  // Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a
  // multiple of N elements is helpful but not required.
  template <int kAssumptions, class D>
  static HWY_INLINE float Compute(const D d,
                                  const bfloat16_t* const HWY_RESTRICT pa,
                                  const bfloat16_t* const HWY_RESTRICT pb,
                                  const size_t num_elements) {
    const RebindToUnsigned<D> du16;
    const Repartition<float, D> df32;

    using V = decltype(Zero(df32));
    const size_t N = Lanes(d);
    size_t i = 0;

    constexpr bool kIsAtLeastOneVector =
        (kAssumptions & kAtLeastOneVector) != 0;
    constexpr bool kIsMultipleOfVector =
        (kAssumptions & kMultipleOfVector) != 0;
    constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;

    // Won't be able to do a full vector load without padding => scalar loop.
    if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
        HWY_UNLIKELY(num_elements < N)) {
      // Only 2x unroll to avoid excessive code size for this unlikely case.
      float sum0 = 0.0f;
      float sum1 = 0.0f;
      for (; i + 2 <= num_elements; i += 2) {
        sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]);
        sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]);
      }
      if (i < num_elements) {
        sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
      }
      return sum0 + sum1;
    }

    // See comment in the other Compute() overload. Unroll 2x, but we need
    // twice as many sums for ReorderWidenMulAccumulate.
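    // ReorderWidenMulAccumulate widens the bf16 lanes to f32 and accumulates
    // the products into sum0 and sum1 in a possibly reordered fashion; this is
    // fine here because only the total sum0 + sum1 is used, in the reduction
    // at the end.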
    V sum0 = Zero(df32);
    V sum1 = Zero(df32);
    V sum2 = Zero(df32);
    V sum3 = Zero(df32);

    // Main loop: unrolled
    for (; i + 2 * N <= num_elements; /* i += 2 * N */) {  // incr in loop
      const auto a0 = LoadU(d, pa + i);
      const auto b0 = LoadU(d, pb + i);
      i += N;
      sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
      const auto a1 = LoadU(d, pa + i);
      const auto b1 = LoadU(d, pb + i);
      i += N;
      sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3);
    }

    // Possibly one more iteration of whole vectors
    if (i + N <= num_elements) {
      const auto a0 = LoadU(d, pa + i);
      const auto b0 = LoadU(d, pb + i);
      i += N;
      sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
    }

    if (!kIsMultipleOfVector) {
      const size_t remaining = num_elements - i;
      if (remaining != 0) {
        if (kIsPaddedToVector) {
          const auto mask = FirstN(du16, remaining);
          const auto va = LoadU(d, pa + i);
          const auto vb = LoadU(d, pb + i);
          const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
          const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
          sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
        } else {
          // Unaligned load such that the last element is in the highest lane -
          // ensures we do not touch any elements outside the valid range.
          // If we get here, then num_elements >= N.
          HWY_DASSERT(i >= N);
          i += remaining - N;
          const auto skip = FirstN(du16, N - remaining);
          const auto va = LoadU(d, pa + i);  // always unaligned
          const auto vb = LoadU(d, pb + i);
          const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
          const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
          sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
        }
      }
    }  // kMultipleOfVector

    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = Add(sum0, sum1);
    sum2 = Add(sum2, sum3);
    sum0 = Add(sum0, sum2);
    return GetLane(SumOfLanes(df32, sum0));
  }
};

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();

#endif  // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
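
A minimal usage sketch (not part of the original file; the function name and
the padding guarantee are illustrative). It assumes calling code that is
itself compiled for a single, statically chosen target; dynamic dispatch would
additionally go through foreach_target.h and HWY_DYNAMIC_DISPATCH.

  #include "hwy/highway.h"
  #include "hwy/contrib/dot/dot-inl.h"

  namespace hn = hwy::HWY_NAMESPACE;

  // Dot product of two float arrays whose backing storage is padded to a
  // whole vector; the values of the padding lanes are ignored.
  float DotF32(const float* HWY_RESTRICT pa, const float* HWY_RESTRICT pb,
               size_t num_elements) {
    const hn::ScalableTag<float> d;
    return hn::Dot::Compute<hn::Dot::kPaddedToVector>(d, pa, pb, num_elements);
  }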