Halide  14.0.0
Halide compiler and libraries
gpu_context_common.h
Go to the documentation of this file.
1 #include "printer.h"
2 #include "scoped_mutex_lock.h"
3 
4 namespace Halide {
5 namespace Internal {
6 
7 template<typename ContextT, typename ModuleStateT>
 // One slot of the open-addressed cache: a (context, kernel_id) key plus
 // the compiled module and a hold count.
 // kernel_id doubles as the slot state: 0 (kInvalidId) marks a never-used
 // slot, 1 (kDeletedId) a tombstone; live ids start at 2 (see unique_id).
 struct CachedCompilation {
     ContextT context{};
     ModuleStateT module_state{};
     uint32_t kernel_id{};
     // Number of outstanding holds; entries are only freed at use_count == 0.
     uint32_t use_count{0};

     CachedCompilation(ContextT context, ModuleStateT module_state,
                       uint32_t kernel_id, uint32_t use_count)
         : context(context), module_state(module_state),
           kernel_id(kernel_id), use_count(use_count) {
     }
 };
21 
 // Guards all cache state below; public entry points lock it before
 // calling the internal helpers.
 halide_mutex mutex;

 // Grow the table once it becomes half full, keeping probe chains short.
 static constexpr float kLoadFactor{.5f};
 // First allocation is 1 << kInitialTableBits slots.
 static constexpr int kInitialTableBits{7};
 int log2_compilations_size{0}; // number of bits in index into compilations table.
 // Linear-probed hash table of 1 << log2_compilations_size slots;
 // nullptr until the first insert.
 CachedCompilation *compilations{nullptr};
 int count{0}; // number of cached entries (maintained by insert/release_context).

 // Sentinel kernel_id values — see CachedCompilation.
 static constexpr uint32_t kInvalidId{0};
 static constexpr uint32_t kDeletedId{1};

 uint32_t unique_id{2}; // zero is an invalid id
35 public:
36  static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uint32_t id, uint32_t bits) {
37  uintptr_t addr = (uintptr_t)context + id;
38  // Fibonacci hashing. The golden ratio is 1.9E3779B97F4A7C15F39...
39  // in hexadecimal.
40  if (sizeof(uintptr_t) >= 8) {
41  return (addr * (uintptr_t)0x9E3779B97F4A7C15) >> (64 - bits);
42  } else {
43  return (addr * (uintptr_t)0x9E3779B9) >> (32 - bits);
44  }
45  }
46 
47  HALIDE_MUST_USE_RESULT bool insert(const CachedCompilation &entry) {
48  if (log2_compilations_size == 0) {
49  if (!resize_table(kInitialTableBits)) {
50  return false;
51  }
52  }
53  if ((count + 1) > (1 << log2_compilations_size) * kLoadFactor) {
54  if (!resize_table(log2_compilations_size + 1)) {
55  return false;
56  }
57  }
58  count += 1;
59  uintptr_t index = kernel_hash(entry.context, entry.kernel_id, log2_compilations_size);
60  for (int i = 0; i < (1 << log2_compilations_size); i++) {
61  uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
62  if (compilations[effective_index].kernel_id <= kDeletedId) {
63  compilations[effective_index] = entry;
64  return true;
65  }
66  }
67  // This is a logic error that should never occur. It means the table is
68  // full, but it should have been resized.
69  halide_debug_assert(nullptr, false);
70  return false;
71  }
72 
73  HALIDE_MUST_USE_RESULT bool find_internal(ContextT context, uint32_t id,
74  ModuleStateT *&module_state, int increment) {
75  if (log2_compilations_size == 0) {
76  return false;
77  }
78  uintptr_t index = kernel_hash(context, id, log2_compilations_size);
79  for (int i = 0; i < (1 << log2_compilations_size); i++) {
80  uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
81 
82  if (compilations[effective_index].kernel_id == kInvalidId) {
83  return false;
84  }
85  if (compilations[effective_index].context == context &&
86  compilations[effective_index].kernel_id == id) {
87  module_state = &compilations[effective_index].module_state;
88  if (increment != 0) {
89  compilations[effective_index].use_count += increment;
90  }
91  return true;
92  }
93  }
94  return false;
95  }
96 
97  HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state) {
98  ScopedMutexLock lock_guard(&mutex);
99  uint32_t id = (uint32_t)(uintptr_t)state_ptr;
100  ModuleStateT *mod_ptr;
101  if (find_internal(context, id, mod_ptr, 0)) {
102  module_state = *mod_ptr;
103  return true;
104  }
105  return false;
106  }
107 
108  HALIDE_MUST_USE_RESULT bool resize_table(int size_bits) {
109  if (size_bits != log2_compilations_size) {
110  int new_size = (1 << size_bits);
111  int old_size = (1 << log2_compilations_size);
112  CachedCompilation *new_table = (CachedCompilation *)malloc(new_size * sizeof(CachedCompilation));
113  if (new_table == nullptr) {
114  // signal error.
115  return false;
116  }
117  memset(new_table, 0, new_size * sizeof(CachedCompilation));
118  CachedCompilation *old_table = compilations;
119  compilations = new_table;
120  log2_compilations_size = size_bits;
121 
122  if (count > 0) { // Mainly to catch empty initial table case
123  for (int32_t i = 0; i < old_size; i++) {
124  if (old_table[i].kernel_id != kInvalidId &&
125  old_table[i].kernel_id != kDeletedId) {
126  bool result = insert(old_table[i]);
127  halide_debug_assert(nullptr, result); // Resizing the table while resizing the table is a logic error.
128  (void)result;
129  }
130  }
131  }
132  free(old_table);
133  }
134  return true;
135  }
136 
137  template<typename FreeModuleT>
138  void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f) {
139  if (count == 0) {
140  return;
141  }
142 
143  for (int i = 0; i < (1 << log2_compilations_size); i++) {
144  if (compilations[i].kernel_id > kInvalidId &&
145  (all || (compilations[i].context == context)) &&
146  compilations[i].use_count == 0) {
147  debug(user_context) << "Releasing cached compilation: " << compilations[i].module_state
148  << " id " << compilations[i].kernel_id
149  << " context " << compilations[i].context << "\n";
150  f(compilations[i].module_state);
151  compilations[i].module_state = nullptr;
152  compilations[i].kernel_id = kDeletedId;
153  count--;
154  }
155  }
156  }
157 
158  template<typename FreeModuleT>
159  void delete_context(void *user_context, ContextT context, FreeModuleT &f) {
160  ScopedMutexLock lock_guard(&mutex);
161 
162  release_context(user_context, false, context, f);
163  }
164 
165  template<typename FreeModuleT>
166  void release_all(void *user_context, FreeModuleT &f) {
167  ScopedMutexLock lock_guard(&mutex);
168 
169  release_context(user_context, true, nullptr, f);
170  // Some items may have been in use, so can't free.
171  if (count == 0) {
172  free(compilations);
173  compilations = nullptr;
174  log2_compilations_size = 0;
175  }
176  }
177 
178  template<typename CompileModuleT, typename... Args>
179  HALIDE_MUST_USE_RESULT bool kernel_state_setup(void *user_context, void **state_ptr,
180  ContextT context, ModuleStateT &result,
181  CompileModuleT f,
182  Args... args) {
183  ScopedMutexLock lock_guard(&mutex);
184 
185  uint32_t *id_ptr = (uint32_t *)state_ptr;
186  if (*id_ptr == 0) {
187  *id_ptr = unique_id++;
188  }
189 
190  ModuleStateT *mod;
191  if (find_internal(context, *id_ptr, mod, 1)) {
192  result = *mod;
193  return true;
194  }
195 
196  // TODO(zvookin): figure out the calling signature here...
197  ModuleStateT compiled_module = f(args...);
198  debug(user_context) << "Caching compiled kernel: " << compiled_module
199  << " id " << *id_ptr << " context " << context << "\n";
200  if (compiled_module == nullptr) {
201  return false;
202  }
203 
204  if (!insert({context, compiled_module, *id_ptr, 1})) {
205  return false;
206  }
207  result = compiled_module;
208 
209  return true;
210  }
211 
212  void release_hold(void *user_context, ContextT context, void *state_ptr) {
213  ModuleStateT *mod;
214  uint32_t id = (uint32_t)(uintptr_t)state_ptr;
215  bool result = find_internal(context, id, mod, -1);
216  halide_debug_assert(user_context, result); // Value must be in cache to be released
217  (void)result;
218  }
219 };
220 
221 } // namespace Internal
222 } // namespace Halide
#define HALIDE_MUST_USE_RESULT
Definition: HalideRuntime.h:54
HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state)
static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uint32_t id, uint32_t bits)
void release_hold(void *user_context, ContextT context, void *state_ptr)
HALIDE_MUST_USE_RESULT bool insert(const CachedCompilation &entry)
HALIDE_MUST_USE_RESULT bool kernel_state_setup(void *user_context, void **state_ptr, ContextT context, ModuleStateT &result, CompileModuleT f, Args... args)
void release_all(void *user_context, FreeModuleT &f)
void delete_context(void *user_context, ContextT context, FreeModuleT &f)
HALIDE_MUST_USE_RESULT bool resize_table(int size_bits)
void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f)
HALIDE_MUST_USE_RESULT bool find_internal(ContextT context, uint32_t id, ModuleStateT *&module_state, int increment)
For optional debugging during codegen, use the debug class as follows:
Definition: Debug.h:49
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1066
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
void * malloc(size_t)
#define halide_debug_assert(user_context, cond)
halide_debug_assert() is like halide_assert(), but only expands into a check when DEBUG_RUNTIME is de...
signed __INT32_TYPE__ int32_t
void * memset(void *s, int val, size_t n)
#define ALWAYS_INLINE
unsigned __INT32_TYPE__ uint32_t
void free(void *)
Cross-platform mutex.