Halide  14.0.0
Halide compiler and libraries
IRMatch.h
Go to the documentation of this file.
1 #ifndef HALIDE_IR_MATCH_H
2 #define HALIDE_IR_MATCH_H
3 
4 /** \file
5  * Defines a method to match a fragment of IR against a pattern containing wildcards
6  */
7 
8 #include <map>
9 #include <random>
10 #include <set>
11 #include <vector>
12 
13 #include "IR.h"
14 #include "IREquality.h"
15 #include "IROperator.h"
16 
17 namespace Halide {
18 namespace Internal {
19 
20 /** Does the first expression have the same structure as the second?
21  * Variables in the first expression with the name * are interpreted
22  * as wildcards, and their matching equivalent in the second
23  * expression is placed in the vector given as the third argument.
24  * Wildcards require the types to match. For the type bits and width,
25  * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26  * integer vectors of any width (including scalars), and a UInt(0, 0)
27  * will match any unsigned integer type.
28  *
29  * For example:
30  \code
31  Expr x = Variable::make(Int(32), "*");
32  match(x + x, 3 + (2*k), result)
33  \endcode
34  * should return true, and set result[0] to 3 and
35  * result[1] to 2*k.
36  */
37 bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38 
39 /** Does the first expression have the same structure as the second?
40  * Variables are matched consistently. The first time a variable is
41  * matched, it assumes the value of the matching part of the second
42  * expression. Subsequent matches must be equal to the first match.
43  *
44  * For example:
45  \code
46  Var x("x"), y("y");
47  match(x*(x + y), a*(a + b), result)
48  \endcode
49  * should return true, and set result["x"] = a, and result["y"] = b.
50  */
51 bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52 
53 /** Rewrite the expression x to have `lanes` lanes. This is useful
54  * for substituting the results of expr_match into a pattern expression. */
55 Expr with_lanes(const Expr &x, int lanes);
56 
58 
59 /** An alternative template-metaprogramming approach to expression
60  * matching. Potentially more efficient. We lift the expression
61  * pattern into a type, and then use force-inlined functions to
62  * generate efficient matching and reconstruction code for any
63  * pattern. Pattern elements are either one of the classes in the
64  * namespace IRMatcher, or are non-null Exprs (represented as
65  * BaseExprNode &).
66  *
67  * Pattern elements that are fully specified by their pattern can be
68  * built into an expression using the make method. Some patterns,
69  * such as a broadcast that matches any number of lanes, don't have
70  * enough information to recreate an Expr.
71  */
72 namespace IRMatcher {
73 
74 constexpr int max_wild = 6;
75 
76 static const halide_type_t i64_type = {halide_type_int, 64, 1};
77 
78 /** To save stack space, the matcher objects are largely stateless and
79  * immutable. This state object is built up during matching and then
80  * consumed when constructing a replacement Expr.
81  */
82 struct MatcherState {
85 
86  // values of the lanes field with special meaning.
87  static constexpr uint16_t signed_integer_overflow = 0x8000;
88  static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89 
91 
93  void set_binding(int i, const BaseExprNode &n) noexcept {
94  bindings[i] = &n;
95  }
96 
98  const BaseExprNode *get_binding(int i) const noexcept {
99  return bindings[i];
100  }
101 
103  void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104  bound_const[i].u.i64 = s;
105  bound_const_type[i] = t;
106  }
107 
109  void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110  bound_const[i].u.u64 = u;
111  bound_const_type[i] = t;
112  }
113 
115  void set_bound_const(int i, double f, halide_type_t t) noexcept {
116  bound_const[i].u.f64 = f;
117  bound_const_type[i] = t;
118  }
119 
121  void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept {
122  bound_const[i] = val;
123  bound_const_type[i] = t;
124  }
125 
127  void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128  val = bound_const[i];
129  type = bound_const_type[i];
130  }
131 
133  // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
134  MatcherState() noexcept {
135  }
136 };
137 
138 template<typename T,
139  typename = typename std::remove_reference<T>::type::pattern_tag>
141  struct type {};
142 };
143 
/** Compile-time accessor for the bitmask of wildcard slots that a pattern
 * type binds. PatternT may be a reference to a pattern type; the reference
 * is stripped before its static `binds` member is read. */
template<typename PatternT>
struct bindings {
    constexpr static uint32_t mask =
        std::remove_reference<PatternT>::type::binds;
};
148 
151  ty.lanes &= ~MatcherState::special_values_mask;
153  return make_signed_integer_overflow(ty);
154  }
155  // unreachable
156  return Expr();
157 }
158 
161  halide_type_t scalar_type = ty;
162  if (scalar_type.lanes & MatcherState::special_values_mask) {
163  return make_const_special_expr(scalar_type);
164  }
165 
166  const int lanes = scalar_type.lanes;
167  scalar_type.lanes = 1;
168 
169  Expr e;
170  switch (scalar_type.code) {
171  case halide_type_int:
172  e = IntImm::make(scalar_type, val.u.i64);
173  break;
174  case halide_type_uint:
175  e = UIntImm::make(scalar_type, val.u.u64);
176  break;
177  case halide_type_float:
178  case halide_type_bfloat:
179  e = FloatImm::make(scalar_type, val.u.f64);
180  break;
181  default:
182  // Unreachable
183  return Expr();
184  }
185  if (lanes > 1) {
186  e = Broadcast::make(e, lanes);
187  }
188  return e;
189 }
190 
191 bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept;
192 
193 // A fast version of expression equality that assumes a well-typed non-null expression tree.
195 bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept {
196  // Early out
197  return (&a == &b) ||
198  ((a.type == b.type) &&
199  (a.node_type == b.node_type) &&
200  equal_helper(a, b));
201 }
202 
203 // A pattern that matches a specific expression
204 struct SpecificExpr {
205  struct pattern_tag {};
206 
207  constexpr static uint32_t binds = 0;
208 
209  // What is the weakest and strongest IR node this could possibly be
212  constexpr static bool canonical = true;
213 
215 
216  template<uint32_t bound>
217  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
218  return equal(expr, e);
219  }
220 
222  Expr make(MatcherState &state, halide_type_t type_hint) const {
223  return Expr(&expr);
224  }
225 
226  constexpr static bool foldable = false;
227 };
228 
229 inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
230  s << Expr(&e.expr);
231  return s;
232 }
233 
234 template<int i>
235 struct WildConstInt {
236  struct pattern_tag {};
237 
238  constexpr static uint32_t binds = 1 << i;
239 
242  constexpr static bool canonical = true;
243 
244  template<uint32_t bound>
245  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
246  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
247  const BaseExprNode *op = &e;
248  if (op->node_type == IRNodeType::Broadcast) {
249  op = ((const Broadcast *)op)->value.get();
250  }
251  if (op->node_type != IRNodeType::IntImm) {
252  return false;
253  }
254  int64_t value = ((const IntImm *)op)->value;
255  if (bound & binds) {
257  halide_type_t type;
258  state.get_bound_const(i, val, type);
259  return (halide_type_t)e.type == type && value == val.u.i64;
260  }
261  state.set_bound_const(i, value, e.type);
262  return true;
263  }
264 
265  template<uint32_t bound>
266  HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
267  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
268  if (bound & binds) {
270  halide_type_t type;
271  state.get_bound_const(i, val, type);
272  return type == i64_type && value == val.u.i64;
273  }
274  state.set_bound_const(i, value, i64_type);
275  return true;
276  }
277 
279  Expr make(MatcherState &state, halide_type_t type_hint) const {
281  halide_type_t type;
282  state.get_bound_const(i, val, type);
283  return make_const_expr(val, type);
284  }
285 
286  constexpr static bool foldable = true;
287 
290  state.get_bound_const(i, val, ty);
291  }
292 };
293 
294 template<int i>
295 std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
296  s << "ci" << i;
297  return s;
298 }
299 
300 template<int i>
302  struct pattern_tag {};
303 
304  constexpr static uint32_t binds = 1 << i;
305 
308  constexpr static bool canonical = true;
309 
310  template<uint32_t bound>
311  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
312  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
313  const BaseExprNode *op = &e;
314  if (op->node_type == IRNodeType::Broadcast) {
315  op = ((const Broadcast *)op)->value.get();
316  }
317  if (op->node_type != IRNodeType::UIntImm) {
318  return false;
319  }
320  uint64_t value = ((const UIntImm *)op)->value;
321  if (bound & binds) {
323  halide_type_t type;
324  state.get_bound_const(i, val, type);
325  return (halide_type_t)e.type == type && value == val.u.u64;
326  }
327  state.set_bound_const(i, value, e.type);
328  return true;
329  }
330 
332  Expr make(MatcherState &state, halide_type_t type_hint) const {
334  halide_type_t type;
335  state.get_bound_const(i, val, type);
336  return make_const_expr(val, type);
337  }
338 
339  constexpr static bool foldable = true;
340 
342  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
343  state.get_bound_const(i, val, ty);
344  }
345 };
346 
347 template<int i>
348 std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
349  s << "cu" << i;
350  return s;
351 }
352 
353 template<int i>
355  struct pattern_tag {};
356 
357  constexpr static uint32_t binds = 1 << i;
358 
361  constexpr static bool canonical = true;
362 
363  template<uint32_t bound>
364  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
365  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
366  const BaseExprNode *op = &e;
367  if (op->node_type == IRNodeType::Broadcast) {
368  op = ((const Broadcast *)op)->value.get();
369  }
370  if (op->node_type != IRNodeType::FloatImm) {
371  return false;
372  }
373  double value = ((const FloatImm *)op)->value;
374  if (bound & binds) {
376  halide_type_t type;
377  state.get_bound_const(i, val, type);
378  return (halide_type_t)e.type == type && value == val.u.f64;
379  }
380  state.set_bound_const(i, value, e.type);
381  return true;
382  }
383 
385  Expr make(MatcherState &state, halide_type_t type_hint) const {
387  halide_type_t type;
388  state.get_bound_const(i, val, type);
389  return make_const_expr(val, type);
390  }
391 
392  constexpr static bool foldable = true;
393 
395  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
396  state.get_bound_const(i, val, ty);
397  }
398 };
399 
400 template<int i>
401 std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
402  s << "cf" << i;
403  return s;
404 }
405 
406 // Matches and binds to any constant Expr: an IntImm, UIntImm, or FloatImm, or a broadcast of one. Supports constant-folding of the bound value.
407 template<int i>
408 struct WildConst {
409  struct pattern_tag {};
410 
411  constexpr static uint32_t binds = 1 << i;
412 
415  constexpr static bool canonical = true;
416 
417  template<uint32_t bound>
418  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
419  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
420  const BaseExprNode *op = &e;
421  if (op->node_type == IRNodeType::Broadcast) {
422  op = ((const Broadcast *)op)->value.get();
423  }
424  switch (op->node_type) {
425  case IRNodeType::IntImm:
426  return WildConstInt<i>().template match<bound>(e, state);
427  case IRNodeType::UIntImm:
428  return WildConstUInt<i>().template match<bound>(e, state);
430  return WildConstFloat<i>().template match<bound>(e, state);
431  default:
432  return false;
433  }
434  }
435 
436  template<uint32_t bound>
437  HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
438  static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
439  return WildConstInt<i>().template match<bound>(e, state);
440  }
441 
443  Expr make(MatcherState &state, halide_type_t type_hint) const {
445  halide_type_t type;
446  state.get_bound_const(i, val, type);
447  return make_const_expr(val, type);
448  }
449 
450  constexpr static bool foldable = true;
451 
453  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
454  state.get_bound_const(i, val, ty);
455  }
456 };
457 
458 template<int i>
459 std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
460  s << "c" << i;
461  return s;
462 }
463 
464 // Matches and binds to any Expr
465 template<int i>
466 struct Wild {
467  struct pattern_tag {};
468 
469  constexpr static uint32_t binds = 1 << (i + 16);
470 
473  constexpr static bool canonical = true;
474 
475  template<uint32_t bound>
476  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
477  if (bound & binds) {
478  return equal(*state.get_binding(i), e);
479  }
480  state.set_binding(i, e);
481  return true;
482  }
483 
485  Expr make(MatcherState &state, halide_type_t type_hint) const {
486  return state.get_binding(i);
487  }
488 
489  constexpr static bool foldable = false;
490 };
491 
492 template<int i>
493 std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
494  s << "_" << i;
495  return s;
496 }
497 
498 // Matches a specific constant or broadcast of that constant. The
499 // constant must be representable as an int64_t.
500 struct IntLiteral {
501  struct pattern_tag {};
503 
504  constexpr static uint32_t binds = 0;
505 
508  constexpr static bool canonical = true;
509 
511  explicit IntLiteral(int64_t v)
512  : v(v) {
513  }
514 
515  template<uint32_t bound>
516  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
517  const BaseExprNode *op = &e;
518  if (e.node_type == IRNodeType::Broadcast) {
519  op = ((const Broadcast *)op)->value.get();
520  }
521  switch (op->node_type) {
522  case IRNodeType::IntImm:
523  return ((const IntImm *)op)->value == (int64_t)v;
524  case IRNodeType::UIntImm:
525  return ((const UIntImm *)op)->value == (uint64_t)v;
527  return ((const FloatImm *)op)->value == (double)v;
528  default:
529  return false;
530  }
531  }
532 
533  template<uint32_t bound>
534  HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
535  return v == val;
536  }
537 
538  template<uint32_t bound>
539  HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
540  return v == b.v;
541  }
542 
544  Expr make(MatcherState &state, halide_type_t type_hint) const {
545  return make_const(type_hint, v);
546  }
547 
548  constexpr static bool foldable = true;
549 
551  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
552  // Assume type is already correct
553  switch (ty.code) {
554  case halide_type_int:
555  val.u.i64 = v;
556  break;
557  case halide_type_uint:
558  val.u.u64 = (uint64_t)v;
559  break;
560  case halide_type_float:
561  case halide_type_bfloat:
562  val.u.f64 = (double)v;
563  break;
564  default:
565  // Unreachable
566  ;
567  }
568  }
569 };
570 
572  return t.v;
573 }
574 
575 // Convert a provided pattern, expr, or constant int into the internal
576 // representation we use in the matcher trees.
577 template<typename T,
578  typename = typename std::decay<T>::type::pattern_tag>
580  return t;
581 }
584  return IntLiteral{x};
585 }
586 
587 template<typename T>
589  static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
590  "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
591 }
592 
594  return {*e.get()};
595 }
596 
597 // Helpers to deref SpecificExprs to const BaseExprNode & rather than
598 // passing them by value anywhere (incurring lots of refcounting)
599 template<typename T,
600  // T must be a pattern node
601  typename = typename std::decay<T>::type::pattern_tag,
602  // But T may not be SpecificExpr
603  typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
605  return t;
606 }
607 
609 const BaseExprNode &unwrap(const SpecificExpr &e) {
610  return e.expr;
611 }
612 
613 inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
614  s << op.v;
615  return s;
616 }
617 
618 template<typename Op>
620 
621 template<typename Op>
623 
624 template<typename Op>
625 double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
626 
627 constexpr bool commutative(IRNodeType t) {
628  return (t == IRNodeType::Add ||
629  t == IRNodeType::Mul ||
630  t == IRNodeType::And ||
631  t == IRNodeType::Or ||
632  t == IRNodeType::Min ||
633  t == IRNodeType::Max ||
634  t == IRNodeType::EQ ||
635  t == IRNodeType::NE);
636 }
637 
638 // Matches one of the binary operators
639 template<typename Op, typename A, typename B>
640 struct BinOp {
641  struct pattern_tag {};
642  A a;
643  B b;
644 
646 
647  constexpr static IRNodeType min_node_type = Op::_node_type;
648  constexpr static IRNodeType max_node_type = Op::_node_type;
649 
650  // For commutative bin ops, we expect the weaker IR node type on
651  // the right. That is, for the rule to be canonical it must be
652  // possible that A is at least as strong as B.
653  constexpr static bool canonical =
654  A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
655 
656  template<uint32_t bound>
657  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
658  if (e.node_type != Op::_node_type) {
659  return false;
660  }
661  const Op &op = (const Op &)e;
662  return (a.template match<bound>(*op.a.get(), state) &&
663  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
664  }
665 
666  template<uint32_t bound, typename Op2, typename A2, typename B2>
667  HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
668  return (std::is_same<Op, Op2>::value &&
669  a.template match<bound>(unwrap(op.a), state) &&
670  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
671  }
672 
673  constexpr static bool foldable = A::foldable && B::foldable;
674 
676  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
677  halide_scalar_value_t val_a, val_b;
678  if (std::is_same<A, IntLiteral>::value) {
679  b.make_folded_const(val_b, ty, state);
680  if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
681  (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
682  // Short circuit
683  val = val_b;
684  return;
685  }
686  const uint16_t l = ty.lanes;
687  a.make_folded_const(val_a, ty, state);
688  ty.lanes |= l; // Make sure the overflow bits are sticky
689  } else {
690  a.make_folded_const(val_a, ty, state);
691  if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
692  (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
693  // Short circuit
694  val = val_a;
695  return;
696  }
697  const uint16_t l = ty.lanes;
698  b.make_folded_const(val_b, ty, state);
699  ty.lanes |= l;
700  }
701  switch (ty.code) {
702  case halide_type_int:
703  val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
704  break;
705  case halide_type_uint:
706  val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
707  break;
708  case halide_type_float:
709  case halide_type_bfloat:
710  val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
711  break;
712  default:
713  // unreachable
714  ;
715  }
716  }
717 
719  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
720  Expr ea, eb;
721  if (std::is_same<A, IntLiteral>::value) {
722  eb = b.make(state, type_hint);
723  ea = a.make(state, eb.type());
724  } else {
725  ea = a.make(state, type_hint);
726  eb = b.make(state, ea.type());
727  }
728  // We sometimes mix vectors and scalars in the rewrite rules,
729  // so insert a broadcast if necessary.
730  if (ea.type().is_vector() && !eb.type().is_vector()) {
731  eb = Broadcast::make(eb, ea.type().lanes());
732  }
733  if (eb.type().is_vector() && !ea.type().is_vector()) {
734  ea = Broadcast::make(ea, eb.type().lanes());
735  }
736  return Op::make(std::move(ea), std::move(eb));
737  }
738 };
739 
740 template<typename Op>
742 
743 template<typename Op>
745 
746 template<typename Op>
747 uint64_t constant_fold_cmp_op(double, double) noexcept;
748 
749 // Matches one of the comparison operators
750 template<typename Op, typename A, typename B>
751 struct CmpOp {
752  struct pattern_tag {};
753  A a;
754  B b;
755 
757 
758  constexpr static IRNodeType min_node_type = Op::_node_type;
759  constexpr static IRNodeType max_node_type = Op::_node_type;
760  constexpr static bool canonical = (A::canonical &&
761  B::canonical &&
762  (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
763  (Op::_node_type != IRNodeType::GE) &&
764  (Op::_node_type != IRNodeType::GT));
765 
766  template<uint32_t bound>
767  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
768  if (e.node_type != Op::_node_type) {
769  return false;
770  }
771  const Op &op = (const Op &)e;
772  return (a.template match<bound>(*op.a.get(), state) &&
773  b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
774  }
775 
776  template<uint32_t bound, typename Op2, typename A2, typename B2>
777  HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
778  return (std::is_same<Op, Op2>::value &&
779  a.template match<bound>(unwrap(op.a), state) &&
780  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
781  }
782 
783  constexpr static bool foldable = A::foldable && B::foldable;
784 
786  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
787  halide_scalar_value_t val_a, val_b;
788  // If one side is an untyped const, evaluate the other side first to get a type hint.
789  if (std::is_same<A, IntLiteral>::value) {
790  b.make_folded_const(val_b, ty, state);
791  const uint16_t l = ty.lanes;
792  a.make_folded_const(val_a, ty, state);
793  ty.lanes |= l;
794  } else {
795  a.make_folded_const(val_a, ty, state);
796  const uint16_t l = ty.lanes;
797  b.make_folded_const(val_b, ty, state);
798  ty.lanes |= l;
799  }
800  switch (ty.code) {
801  case halide_type_int:
802  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
803  break;
804  case halide_type_uint:
805  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
806  break;
807  case halide_type_float:
808  case halide_type_bfloat:
809  val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
810  break;
811  default:
812  // unreachable
813  ;
814  }
815  ty.code = halide_type_uint;
816  ty.bits = 1;
817  }
818 
820  Expr make(MatcherState &state, halide_type_t type_hint) const {
821  // If one side is an untyped const, evaluate the other side first to get a type hint.
822  Expr ea, eb;
823  if (std::is_same<A, IntLiteral>::value) {
824  eb = b.make(state, {});
825  ea = a.make(state, eb.type());
826  } else {
827  ea = a.make(state, {});
828  eb = b.make(state, ea.type());
829  }
830  // We sometimes mix vectors and scalars in the rewrite rules,
831  // so insert a broadcast if necessary.
832  if (ea.type().is_vector() && !eb.type().is_vector()) {
833  eb = Broadcast::make(eb, ea.type().lanes());
834  }
835  if (eb.type().is_vector() && !ea.type().is_vector()) {
836  ea = Broadcast::make(ea, eb.type().lanes());
837  }
838  return Op::make(std::move(ea), std::move(eb));
839  }
840 };
841 
842 template<typename A, typename B>
843 std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
844  s << "(" << op.a << " + " << op.b << ")";
845  return s;
846 }
847 
848 template<typename A, typename B>
849 std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
850  s << "(" << op.a << " - " << op.b << ")";
851  return s;
852 }
853 
854 template<typename A, typename B>
855 std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
856  s << "(" << op.a << " * " << op.b << ")";
857  return s;
858 }
859 
860 template<typename A, typename B>
861 std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
862  s << "(" << op.a << " / " << op.b << ")";
863  return s;
864 }
865 
866 template<typename A, typename B>
867 std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
868  s << "(" << op.a << " && " << op.b << ")";
869  return s;
870 }
871 
872 template<typename A, typename B>
873 std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
874  s << "(" << op.a << " || " << op.b << ")";
875  return s;
876 }
877 
878 template<typename A, typename B>
879 std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
880  s << "min(" << op.a << ", " << op.b << ")";
881  return s;
882 }
883 
884 template<typename A, typename B>
885 std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
886  s << "max(" << op.a << ", " << op.b << ")";
887  return s;
888 }
889 
890 template<typename A, typename B>
891 std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
892  s << "(" << op.a << " <= " << op.b << ")";
893  return s;
894 }
895 
896 template<typename A, typename B>
897 std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
898  s << "(" << op.a << " < " << op.b << ")";
899  return s;
900 }
901 
902 template<typename A, typename B>
903 std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
904  s << "(" << op.a << " >= " << op.b << ")";
905  return s;
906 }
907 
908 template<typename A, typename B>
909 std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
910  s << "(" << op.a << " > " << op.b << ")";
911  return s;
912 }
913 
914 template<typename A, typename B>
915 std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
916  s << "(" << op.a << " == " << op.b << ")";
917  return s;
918 }
919 
920 template<typename A, typename B>
921 std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
922  s << "(" << op.a << " != " << op.b << ")";
923  return s;
924 }
925 
926 template<typename A, typename B>
927 std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
928  s << "(" << op.a << " % " << op.b << ")";
929  return s;
930 }
931 
932 template<typename A, typename B>
933 HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
934  assert_is_lvalue_if_expr<A>();
935  assert_is_lvalue_if_expr<B>();
936  return {pattern_arg(a), pattern_arg(b)};
937 }
938 
939 template<typename A, typename B>
940 HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
941  assert_is_lvalue_if_expr<A>();
942  assert_is_lvalue_if_expr<B>();
943  return IRMatcher::operator+(a, b);
944 }
945 
946 template<>
948  t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
949  int dead_bits = 64 - t.bits;
950  // Drop the high bits then sign-extend them back
951  return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
952 }
953 
954 template<>
956  uint64_t ones = (uint64_t)(-1);
957  return (a + b) & (ones >> (64 - t.bits));
958 }
959 
960 template<>
961 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
962  return a + b;
963 }
964 
965 template<typename A, typename B>
966 HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
967  assert_is_lvalue_if_expr<A>();
968  assert_is_lvalue_if_expr<B>();
969  return {pattern_arg(a), pattern_arg(b)};
970 }
971 
972 template<typename A, typename B>
973 HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
974  assert_is_lvalue_if_expr<A>();
975  assert_is_lvalue_if_expr<B>();
976  return IRMatcher::operator-(a, b);
977 }
978 
979 template<>
981  t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
982  // Drop the high bits then sign-extend them back
983  int dead_bits = 64 - t.bits;
984  return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
985 }
986 
987 template<>
989  uint64_t ones = (uint64_t)(-1);
990  return (a - b) & (ones >> (64 - t.bits));
991 }
992 
993 template<>
994 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
995  return a - b;
996 }
997 
998 template<typename A, typename B>
999 HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1000  assert_is_lvalue_if_expr<A>();
1001  assert_is_lvalue_if_expr<B>();
1002  return {pattern_arg(a), pattern_arg(b)};
1003 }
1004 
1005 template<typename A, typename B>
1006 HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
1007  assert_is_lvalue_if_expr<A>();
1008  assert_is_lvalue_if_expr<B>();
1009  return IRMatcher::operator*(a, b);
1010 }
1011 
1012 template<>
1014  t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
1015  int dead_bits = 64 - t.bits;
1016  // Drop the high bits then sign-extend them back
1017  return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
1018 }
1019 
1020 template<>
1022  uint64_t ones = (uint64_t)(-1);
1023  return (a * b) & (ones >> (64 - t.bits));
1024 }
1025 
1026 template<>
1027 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1028  return a * b;
1029 }
1030 
1031 template<typename A, typename B>
1032 HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1033  assert_is_lvalue_if_expr<A>();
1034  assert_is_lvalue_if_expr<B>();
1035  return {pattern_arg(a), pattern_arg(b)};
1036 }
1037 
1038 template<typename A, typename B>
1039 HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1040  return IRMatcher::operator/(a, b);
1041 }
1042 
1043 template<>
1045  return div_imp(a, b);
1046 }
1047 
1048 template<>
1050  return div_imp(a, b);
1051 }
1052 
1053 template<>
1054 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1055  return div_imp(a, b);
1056 }
1057 
1058 template<typename A, typename B>
1059 HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1060  assert_is_lvalue_if_expr<A>();
1061  assert_is_lvalue_if_expr<B>();
1062  return {pattern_arg(a), pattern_arg(b)};
1063 }
1064 
1065 template<typename A, typename B>
1066 HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1067  assert_is_lvalue_if_expr<A>();
1068  assert_is_lvalue_if_expr<B>();
1069  return IRMatcher::operator%(a, b);
1070 }
1071 
// Constant-folding specializations of constant_fold_bin_op for Mod; every
// overload delegates to mod_imp (defined elsewhere in the project).
// NOTE(review): the signature lines of the first two specializations (source
// lines 1073 and 1078) are missing from this listing; presumably the signed
// and unsigned 64-bit overloads — confirm.
1072 template<>
1074  return mod_imp(a, b);
1075 }
1076 
1077 template<>
1079  return mod_imp(a, b);
1080 }
1081 
// Floating-point modulus.
1082 template<>
1083 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1084  return mod_imp(a, b);
1085 }
1086 
1087 template<typename A, typename B>
1088 HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1089  assert_is_lvalue_if_expr<A>();
1090  assert_is_lvalue_if_expr<B>();
1091  return {pattern_arg(a), pattern_arg(b)};
1092 }
1093 
// Constant-folding specializations of constant_fold_bin_op for Min, each
// computed with std::min.
// NOTE(review): the signature lines of the first two specializations (source
// lines 1095 and 1100) are missing from this listing; presumably the signed
// and unsigned 64-bit overloads — confirm.
1094 template<>
1096  return std::min(a, b);
1097 }
1098 
1099 template<>
1101  return std::min(a, b);
1102 }
1103 
// Floating-point minimum.
1104 template<>
1105 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1106  return std::min(a, b);
1107 }
1108 
1109 template<typename A, typename B>
1110 HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1111  assert_is_lvalue_if_expr<A>();
1112  assert_is_lvalue_if_expr<B>();
1113  return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1114 }
1115 
// Constant-folding specializations of constant_fold_bin_op for Max, each
// computed with std::max.
// NOTE(review): the signature lines of the first two specializations (source
// lines 1117 and 1122) are missing from this listing; presumably the signed
// and unsigned 64-bit overloads — confirm.
1116 template<>
1118  return std::max(a, b);
1119 }
1120 
1121 template<>
1123  return std::max(a, b);
1124 }
1125 
// Floating-point maximum.
1126 template<>
1127 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1128  return std::max(a, b);
1129 }
1130 
1131 template<typename A, typename B>
1132 HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1133  return {pattern_arg(a), pattern_arg(b)};
1134 }
1135 
1136 template<typename A, typename B>
1137 HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1138  return IRMatcher::operator<(a, b);
1139 }
1140 
// Constant-folding specializations for LT; each returns the truth value of
// a < b.
// NOTE(review): all three signature lines (source lines 1142, 1147, 1152)
// are missing from this listing; presumably one overload per operand kind
// (signed, unsigned, floating point) — confirm against the header.
1141 template<>
1143  return a < b;
1144 }
1145 
1146 template<>
1148  return a < b;
1149 }
1150 
1151 template<>
1153  return a < b;
1154 }
1155 
1156 template<typename A, typename B>
1157 HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1158  return {pattern_arg(a), pattern_arg(b)};
1159 }
1160 
1161 template<typename A, typename B>
1162 HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1163  return IRMatcher::operator>(a, b);
1164 }
1165 
// Constant-folding specializations for GT; each returns the truth value of
// a > b.
// NOTE(review): all three signature lines (source lines 1167, 1172, 1177)
// are missing from this listing; presumably one overload per operand kind
// (signed, unsigned, floating point) — confirm against the header.
1166 template<>
1168  return a > b;
1169 }
1170 
1171 template<>
1173  return a > b;
1174 }
1175 
1176 template<>
1178  return a > b;
1179 }
1180 
1181 template<typename A, typename B>
1182 HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1183  return {pattern_arg(a), pattern_arg(b)};
1184 }
1185 
1186 template<typename A, typename B>
1187 HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1188  return IRMatcher::operator<=(a, b);
1189 }
1190 
// Constant-folding specializations for LE; each returns the truth value of
// a <= b.
// NOTE(review): all three signature lines (source lines 1192, 1197, 1202)
// are missing from this listing; presumably one overload per operand kind
// (signed, unsigned, floating point) — confirm against the header.
1191 template<>
1193  return a <= b;
1194 }
1195 
1196 template<>
1198  return a <= b;
1199 }
1200 
1201 template<>
1203  return a <= b;
1204 }
1205 
1206 template<typename A, typename B>
1207 HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1208  return {pattern_arg(a), pattern_arg(b)};
1209 }
1210 
1211 template<typename A, typename B>
1212 HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1213  return IRMatcher::operator>=(a, b);
1214 }
1215 
// Constant-folding specializations for GE; each returns the truth value of
// a >= b.
// NOTE(review): all three signature lines (source lines 1217, 1222, 1227)
// are missing from this listing; presumably one overload per operand kind
// (signed, unsigned, floating point) — confirm against the header.
1216 template<>
1218  return a >= b;
1219 }
1220 
1221 template<>
1223  return a >= b;
1224 }
1225 
1226 template<>
1228  return a >= b;
1229 }
1230 
1231 template<typename A, typename B>
1232 HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1233  return {pattern_arg(a), pattern_arg(b)};
1234 }
1235 
1236 template<typename A, typename B>
1237 HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1238  return IRMatcher::operator==(a, b);
1239 }
1240 
// Constant-folding specializations for EQ; each returns the truth value of
// a == b.
// NOTE(review): all three signature lines (source lines 1242, 1247, 1252)
// are missing from this listing; presumably one overload per operand kind
// (signed, unsigned, floating point) — confirm against the header.
1241 template<>
1243  return a == b;
1244 }
1245 
1246 template<>
1248  return a == b;
1249 }
1250 
1251 template<>
1253  return a == b;
1254 }
1255 
1256 template<typename A, typename B>
1257 HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1258  return {pattern_arg(a), pattern_arg(b)};
1259 }
1260 
1261 template<typename A, typename B>
1262 HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1263  return IRMatcher::operator!=(a, b);
1264 }
1265 
// Constant-folding specializations for NE; each returns the truth value of
// a != b.
// NOTE(review): all three signature lines (source lines 1267, 1272, 1277)
// are missing from this listing; presumably one overload per operand kind
// (signed, unsigned, floating point) — confirm against the header.
1266 template<>
1268  return a != b;
1269 }
1270 
1271 template<>
1273  return a != b;
1274 }
1275 
1276 template<>
1278  return a != b;
1279 }
1280 
1281 template<typename A, typename B>
1282 HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1283  return {pattern_arg(a), pattern_arg(b)};
1284 }
1285 
1286 template<typename A, typename B>
1287 HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1288  return IRMatcher::operator||(a, b);
1289 }
1290 
// Constant-folding specializations of constant_fold_bin_op for Or. The
// integer overloads OR the operands and keep only the low bit, so the
// result is 0 or 1.
// NOTE(review): the signature lines of the first two specializations (source
// lines 1292 and 1297) are missing from this listing; presumably the signed
// and unsigned 64-bit overloads — confirm.
1291 template<>
1293  return (a | b) & 1;
1294 }
1295 
1296 template<>
1298  return (a | b) & 1;
1299 }
1300 
// Floating-point overload exists only to complete the set; it returns 0.
1301 template<>
1302 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1303  // Unreachable, as it would be a type mismatch.
1304  return 0;
1305 }
1306 
1307 template<typename A, typename B>
1308 HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1309  return {pattern_arg(a), pattern_arg(b)};
1310 }
1311 
1312 template<typename A, typename B>
1313 HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1314  return IRMatcher::operator&&(a, b);
1315 }
1316 
// Constant-folding specializations of constant_fold_bin_op for And. The
// integer overloads AND the operands and keep only the low bit, so the
// result is 0 or 1.
// NOTE(review): the signature lines of the first two specializations (source
// lines 1318 and 1323) are missing from this listing; presumably the signed
// and unsigned 64-bit overloads — confirm.
1317 template<>
1319  return a & b & 1;
1320 }
1321 
1322 template<>
1324  return a & b & 1;
1325 }
1326 
// Floating-point overload exists only to complete the set; it returns 0.
1327 template<>
1328 HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1329  // Unreachable
1330  return 0;
1331 }
1332 
// Compile-time folds over parameter packs. Each function body is a single
// return statement, so they stay valid C++11 constexpr functions.

/** Bitwise OR of all arguments; the empty reduction is 0. */
constexpr inline uint32_t bitwise_or_reduce() {
    return 0;
}

template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
    return bitwise_or_reduce(rest...) | first;
}

/** Logical AND of all arguments; the empty reduction is true. */
constexpr inline bool and_reduce() {
    return true;
}

template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    return first ? and_reduce(rest...) : false;
}

// TODO: this can be replaced with std::min() once we require C++14 or later
/** constexpr minimum of two ints. */
constexpr int const_min(int a, int b) {
    return b < a ? b : a;
}
1355 
// Pattern node for an intrinsic Call; the argument patterns are stored in a
// tuple in call-argument order.
// NOTE(review): source lines 1359, 1362 and 1364-1365 are missing from this
// listing. Judging by the constructor's `intrin(intrin)` initializer and
// uses like `c.is_intrinsic(intrin)`, one of them declares the `intrin`
// member; the others presumably hold the binds mask and node-type bounds —
// confirm against the header.
1356 template<typename... Args>
1357 struct Intrin {
1358  struct pattern_tag {};
1360  std::tuple<Args...> args;
1361 
1363 
// Canonical only if every argument pattern is canonical.
1366  constexpr static bool canonical = and_reduce((Args::canonical)...);
1367 
1368  template<int i,
1369  uint32_t bound,
1370  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1371  HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1372  using T = decltype(std::get<i>(args));
1373  return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1374  match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
1375  }
1376 
1377  template<int i, uint32_t binds>
1378  HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1379  return true;
1380  }
1381 
1382  template<uint32_t bound>
1383  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1384  if (e.node_type != IRNodeType::Call) {
1385  return false;
1386  }
1387  const Call &c = (const Call &)e;
1388  return (c.is_intrinsic(intrin) && match_args<0, bound>(0, c, state));
1389  }
1390 
1391  template<int i,
1392  typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1393  HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1394  s << std::get<i>(args);
1395  if (i + 1 < sizeof...(Args)) {
1396  s << ", ";
1397  }
1398  print_args<i + 1>(0, s);
1399  }
1400 
1401  template<int i>
1402  HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1403  }
1404 
1406  void print_args(std::ostream &s) const {
1407  print_args<0>(0, s);
1408  }
1409 
1411  Expr make(MatcherState &state, halide_type_t type_hint) const {
1412  Expr arg0 = std::get<0>(args).make(state, type_hint);
1413  if (intrin == Call::likely) {
1414  return likely(arg0);
1415  } else if (intrin == Call::likely_if_innermost) {
1416  return likely_if_innermost(arg0);
1417  } else if (intrin == Call::abs) {
1418  return abs(arg0);
1419  }
1420 
1421  Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1422  if (intrin == Call::absd) {
1423  return absd(arg0, arg1);
1424  } else if (intrin == Call::widening_add) {
1425  return widening_add(arg0, arg1);
1426  } else if (intrin == Call::widening_sub) {
1427  return widening_sub(arg0, arg1);
1428  } else if (intrin == Call::widening_mul) {
1429  return widening_mul(arg0, arg1);
1430  } else if (intrin == Call::saturating_add) {
1431  return saturating_add(arg0, arg1);
1432  } else if (intrin == Call::saturating_sub) {
1433  return saturating_sub(arg0, arg1);
1434  } else if (intrin == Call::halving_add) {
1435  return halving_add(arg0, arg1);
1436  } else if (intrin == Call::halving_sub) {
1437  return halving_sub(arg0, arg1);
1438  } else if (intrin == Call::rounding_halving_add) {
1439  return rounding_halving_add(arg0, arg1);
1440  } else if (intrin == Call::rounding_halving_sub) {
1441  return rounding_halving_sub(arg0, arg1);
1442  } else if (intrin == Call::shift_left) {
1443  return arg0 << arg1;
1444  } else if (intrin == Call::shift_right) {
1445  return arg0 >> arg1;
1446  } else if (intrin == Call::rounding_shift_left) {
1447  return rounding_shift_left(arg0, arg1);
1448  } else if (intrin == Call::rounding_shift_right) {
1449  return rounding_shift_right(arg0, arg1);
1450  }
1451 
1452  Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1453  if (intrin == Call::mul_shift_right) {
1454  return mul_shift_right(arg0, arg1, arg2);
1455  } else if (intrin == Call::rounding_mul_shift_right) {
1456  return rounding_mul_shift_right(arg0, arg1, arg2);
1457  }
1458 
1459  internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1460  return Expr();
1461  }
1462 
// Intrin patterns are always reported as foldable; make_folded_const below
// only implements shift_left and shift_right and raises internal_error for
// any other intrinsic.
1463  constexpr static bool foldable = true;
1464 
// Fold a shift of constants into `val`/`ty` (presumably output parameters;
// the signature line, source line 1465, is missing from this listing).
// A negative shift amount shifts in the opposite direction. Signed values
// use an arithmetic right shift; unsigned values use a logical one.
1466  halide_scalar_value_t arg1;
1467  // Assuming the args have the same type as the intrinsic is incorrect in
1468  // general. But for the intrinsics we can fold (just shifts), the LHS
1469  // has the same type as the intrinsic, and we can always treat the RHS
1470  // as a signed int, because we're using 64 bits for it.
1471  std::get<0>(args).make_folded_const(val, ty, state);
1472  halide_type_t signed_ty = ty;
1473  signed_ty.code = halide_type_int;
1474  // We can just directly get the second arg here, because we only want to
1475  // instantiate this method for shifts, which have two args.
1476  std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1477 
// NOTE(review): the shift amount is not range-checked. 64-bit shifts by 64
// or more (in either direction) are undefined behavior, and negating
// INT64_MIN overflows. The result is also not truncated to ty.bits here —
// confirm that callers guarantee in-range shift amounts.
1478  if (intrin == Call::shift_left) {
1479  if (arg1.u.i64 < 0) {
1480  if (ty.code == halide_type_int) {
1481  // Arithmetic shift
1482  val.u.i64 >>= -arg1.u.i64;
1483  } else {
1484  // Logical shift
1485  val.u.u64 >>= -arg1.u.i64;
1486  }
1487  } else {
1488  val.u.u64 <<= arg1.u.i64;
1489  }
1490  } else if (intrin == Call::shift_right) {
1491  if (arg1.u.i64 > 0) {
1492  if (ty.code == halide_type_int) {
1493  // Arithmetic shift
1494  val.u.i64 >>= arg1.u.i64;
1495  } else {
1496  // Logical shift
1497  val.u.u64 >>= arg1.u.i64;
1498  }
1499  } else {
1500  val.u.u64 <<= -arg1.u.i64;
1501  }
1502  } else {
1503  internal_error << "Folding not implemented for intrinsic: " << intrin;
1504  }
1505  }
1506 
// Constructor tail: stores the intrinsic op and the argument patterns.
// NOTE(review): its template/signature lines (source lines 1507-1508) are
// missing from this listing.
1509  : intrin(intrin), args(args...) {
1510  }
1511 };
1512 
1513 template<typename... Args>
1514 std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
1515  s << op.intrin << "(";
1516  op.print_args(s);
1517  s << ")";
1518  return s;
1519 }
1520 
1521 template<typename... Args>
1522 HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
1523  return {intrinsic_op, pattern_arg(args)...};
1524 }
1525 
1526 template<typename A, typename B>
1527 auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1528  return {Call::widening_add, pattern_arg(a), pattern_arg(b)};
1529 }
1530 template<typename A, typename B>
1531 auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1532  return {Call::widening_sub, pattern_arg(a), pattern_arg(b)};
1533 }
1534 template<typename A, typename B>
1535 auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1536  return {Call::widening_mul, pattern_arg(a), pattern_arg(b)};
1537 }
1538 template<typename A, typename B>
1539 auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1540  return {Call::saturating_add, pattern_arg(a), pattern_arg(b)};
1541 }
1542 template<typename A, typename B>
1543 auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1544  return {Call::saturating_sub, pattern_arg(a), pattern_arg(b)};
1545 }
1546 template<typename A, typename B>
1547 auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1548  return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
1549 }
1550 template<typename A, typename B>
1551 auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1552  return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
1553 }
// Pattern builders for rounding_halving_add / rounding_halving_sub.
// NOTE(review): their return statements (source lines 1556 and 1560) are
// missing from this listing; presumably they mirror the other two-operand
// builders, returning {Call::<op>, pattern_arg(a), pattern_arg(b)} — confirm.
1554 template<typename A, typename B>
1555 auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1557 }
1558 template<typename A, typename B>
1559 auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1561 }
1562 template<typename A, typename B>
1563 auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1564  return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
1565 }
1566 template<typename A, typename B>
1567 auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1568  return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
1569 }
// Pattern builders for the rounding shifts and the (rounding) multiply-shift
// intrinsics; the latter take three operands.
// NOTE(review): every return statement here (source lines 1572, 1576, 1580,
// 1584) is missing from this listing; presumably they return
// {Call::<op>, pattern_arg(...)...} like the other builders — confirm.
1570 template<typename A, typename B>
1571 auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1573 }
1574 template<typename A, typename B>
1575 auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1577 }
1578 template<typename A, typename B, typename C>
1579 auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1581 }
1582 template<typename A, typename B, typename C>
1583 auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1585 }
1586 
// Pattern node for a logical Not of a single operand pattern `a`.
// NOTE(review): source lines 1594-1595 are missing from this listing.
1587 template<typename A>
1588 struct NotOp {
1589  struct pattern_tag {};
1590  A a;
1591 
// Wildcards bound by this pattern are exactly those bound by its operand.
1592  constexpr static uint32_t binds = bindings<A>::mask;
1593 
1596  constexpr static bool canonical = A::canonical;
1597 
1598  template<uint32_t bound>
1599  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1600  if (e.node_type != IRNodeType::Not) {
1601  return false;
1602  }
1603  const Not &op = (const Not &)e;
1604  return (a.template match<bound>(*op.a.get(), state));
1605  }
1606 
1607  template<uint32_t bound, typename A2>
1608  HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1609  return a.template match<bound>(unwrap(op.a), state);
1610  }
1611 
1613  Expr make(MatcherState &state, halide_type_t type_hint) const {
1614  return Not::make(a.make(state, type_hint));
1615  }
1616 
// Foldable whenever the operand is.
1617  constexpr static bool foldable = A::foldable;
1618 
// Fold the operand, then take the logical not of its low bit (result 0/1).
// NOTE(review): the signature line (source line 1620) is missing from this
// listing.
1619  template<typename A1 = A>
1621  a.make_folded_const(val, ty, state);
1622  val.u.u64 = ~val.u.u64;
1623  val.u.u64 &= 1;
1624  }
1625 };
1626 
1627 template<typename A>
1628 HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1629  assert_is_lvalue_if_expr<A>();
1630  return {pattern_arg(a)};
1631 }
1632 
1633 template<typename A>
1634 HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a)) {
1635  assert_is_lvalue_if_expr<A>();
1636  return IRMatcher::operator!(a);
1637 }
1638 
1639 template<typename A>
1640 inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1641  s << "!(" << op.a << ")";
1642  return s;
1643 }
1644 
// Pattern node for Select(condition, true_value, false_value).
// NOTE(review): source lines 1652 and 1654-1655 are missing from this
// listing; presumably the binds mask and node-type bounds — confirm.
1645 template<typename C, typename T, typename F>
1646 struct SelectOp {
1647  struct pattern_tag {};
1648  C c;
1649  T t;
1650  F f;
1651 
1653 
1656 
// Canonical only if all three sub-patterns are canonical.
1657  constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1658 
1659  template<uint32_t bound>
1660  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1661  if (e.node_type != Select::_node_type) {
1662  return false;
1663  }
1664  const Select &op = (const Select &)e;
1665  return (c.template match<bound>(*op.condition.get(), state) &&
1666  t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1667  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1668  }
1669  template<uint32_t bound, typename C2, typename T2, typename F2>
1670  HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1671  return (c.template match<bound>(unwrap(instance.c), state) &&
1672  t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1673  f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1674  }
1675 
1677  Expr make(MatcherState &state, halide_type_t type_hint) const {
1678  return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1679  }
1680 
// Foldable only when all three sub-patterns are.
1681  constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1682 
// Fold the condition and then only the selected branch (chosen by the
// condition's low bit). Any special-value flags carried in the condition's
// lanes field are propagated into the result type.
// NOTE(review): the signature line (source line 1684) is missing from this
// listing. t_val and f_val are declared but never used.
1683  template<typename C1 = C>
1685  halide_scalar_value_t c_val, t_val, f_val;
1686  halide_type_t c_ty;
1687  c.make_folded_const(c_val, c_ty, state);
1688  if ((c_val.u.u64 & 1) == 1) {
1689  t.make_folded_const(val, ty, state);
1690  } else {
1691  f.make_folded_const(val, ty, state);
1692  }
1693  ty.lanes |= c_ty.lanes & MatcherState::special_values_mask;
1694  }
1695 };
1696 
1697 template<typename C, typename T, typename F>
1698 std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1699  s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1700  return s;
1701 }
1702 
1703 template<typename C, typename T, typename F>
1704 HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1705  assert_is_lvalue_if_expr<C>();
1706  assert_is_lvalue_if_expr<T>();
1707  assert_is_lvalue_if_expr<F>();
1708  return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1709 }
1710 
// Pattern node for Broadcast(value, lanes).
// NOTE(review): source lines 1715, 1717 and 1719-1720 are missing from this
// listing. Judging by uses of `lanes` in the member functions, one of them
// declares `B lanes;` — confirm against the header.
1711 template<typename A, typename B>
1712 struct BroadcastOp {
1713  struct pattern_tag {};
1714  A a;
1716 
1718 
1721 
1722  constexpr static bool canonical = A::canonical && B::canonical;
1723 
1724  template<uint32_t bound>
1725  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1726  if (e.node_type == Broadcast::_node_type) {
1727  const Broadcast &op = (const Broadcast &)e;
1728  if (a.template match<bound>(*op.value.get(), state) &&
1729  lanes.template match<bound>(op.lanes, state)) {
1730  return true;
1731  }
1732  }
1733  return false;
1734  }
1735 
1736  template<uint32_t bound, typename A2, typename B2>
1737  HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1738  return (a.template match<bound>(unwrap(op.a), state) &&
1739  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1740  }
1741 
1743  Expr make(MatcherState &state, halide_type_t type_hint) const {
1744  halide_scalar_value_t lanes_val;
1745  halide_type_t ty;
1746  lanes.make_folded_const(lanes_val, ty, state);
1747  int32_t l = (int32_t)lanes_val.u.i64;
1748  type_hint.lanes /= l;
1749  Expr val = a.make(state, type_hint);
1750  if (l == 1) {
1751  return val;
1752  } else {
1753  return Broadcast::make(std::move(val), l);
1754  }
1755  }
1756 
// Broadcasts are never reported as foldable.
1757  constexpr static bool foldable = false;
1758 
// Fold the lane count and the value, then set the result's lane count while
// keeping any special-value flags from the value's lanes field.
// NOTE(review): the signature line (source line 1760) is missing from this
// listing.
1759  template<typename A1 = A>
1761  halide_scalar_value_t lanes_val;
1762  halide_type_t lanes_ty;
1763  lanes.make_folded_const(lanes_val, lanes_ty, state);
1764  uint16_t l = (uint16_t)lanes_val.u.i64;
1765  a.make_folded_const(val, ty, state);
1766  ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1767  }
1768 };
1769 
1770 template<typename A, typename B>
1771 inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1772  s << "broadcast(" << op.a << ", " << op.lanes << ")";
1773  return s;
1774 }
1775 
1776 template<typename A, typename B>
1777 HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1778  assert_is_lvalue_if_expr<A>();
1779  return {pattern_arg(a), pattern_arg(lanes)};
1780 }
1781 
// Pattern node for Ramp(base, stride, lanes).
// NOTE(review): source lines 1787, 1789 and 1791-1792 are missing from this
// listing. Judging by uses of `lanes` in the member functions, one of them
// declares `C lanes;` — confirm against the header.
1782 template<typename A, typename B, typename C>
1783 struct RampOp {
1784  struct pattern_tag {};
1785  A a;
1786  B b;
1788 
1790 
1793 
1794  constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1795 
1796  template<uint32_t bound>
1797  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1798  if (e.node_type != Ramp::_node_type) {
1799  return false;
1800  }
1801  const Ramp &op = (const Ramp &)e;
1802  if (a.template match<bound>(*op.base.get(), state) &&
1803  b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1804  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1805  return true;
1806  } else {
1807  return false;
1808  }
1809  }
1810 
1811  template<uint32_t bound, typename A2, typename B2, typename C2>
1812  HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1813  return (a.template match<bound>(unwrap(op.a), state) &&
1814  b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1815  lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1816  }
1817 
1819  Expr make(MatcherState &state, halide_type_t type_hint) const {
1820  halide_scalar_value_t lanes_val;
1821  halide_type_t ty;
1822  lanes.make_folded_const(lanes_val, ty, state);
1823  int32_t l = (int32_t)lanes_val.u.i64;
1824  type_hint.lanes /= l;
1825  Expr ea, eb;
1826  eb = b.make(state, type_hint);
1827  ea = a.make(state, eb.type());
1828  return Ramp::make(ea, eb, l);
1829  }
1830 
// Ramps are never reported as foldable.
1831  constexpr static bool foldable = false;
1832 };
1833 
1834 template<typename A, typename B, typename C>
1835 std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1836  s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1837  return s;
1838 }
1839 
1840 template<typename A, typename B, typename C>
1841 HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1842  assert_is_lvalue_if_expr<A>();
1843  assert_is_lvalue_if_expr<B>();
1844  assert_is_lvalue_if_expr<C>();
1845  return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1846 }
1847 
// Pattern node for a VectorReduce with a fixed reduction operator.
// NOTE(review): source lines 1849 (the `struct VectorReduceOp {` line), 1852
// (presumably `B lanes;`) and 1856-1857 are missing from this listing.
1848 template<typename A, typename B, VectorReduce::Operator reduce_op>
1850  struct pattern_tag {};
1851  A a;
1853 
// NOTE(review): binds covers only `a`, not `lanes`, even though `lanes` is
// matched as a pattern too — confirm whether a wildcard used only in `lanes`
// should also be reported here.
1854  constexpr static uint32_t binds = bindings<A>::mask;
1855 
1858  constexpr static bool canonical = A::canonical;
1859 
1860  template<uint32_t bound>
1861  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1862  if (e.node_type == VectorReduce::_node_type) {
1863  const VectorReduce &op = (const VectorReduce &)e;
1864  if (op.op == reduce_op &&
1865  a.template match<bound>(*op.value.get(), state) &&
1866  lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1867  return true;
1868  }
1869  }
1870  return false;
1871  }
1872 
// Match another VectorReduceOp pattern instance: the operators must agree,
// then value and lane count are matched in turn.
// NOTE(review): the signature line (source line 1874) is missing from this
// listing.
1873  template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1875  return (reduce_op == reduce_op_2 &&
1876  a.template match<bound>(unwrap(op.a), state) &&
1877  lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1878  }
1879 
1881  Expr make(MatcherState &state, halide_type_t type_hint) const {
1882  halide_scalar_value_t lanes_val;
1883  halide_type_t ty;
1884  lanes.make_folded_const(lanes_val, ty, state);
1885  int l = (int)lanes_val.u.i64;
1886  return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1887  }
1888 
// Vector reductions are never reported as foldable.
1889  constexpr static bool foldable = false;
1890 };
1891 
1892 template<typename A, typename B, VectorReduce::Operator reduce_op>
1893 inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1894  s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1895  return s;
1896 }
1897 
1898 template<typename A, typename B>
1899 HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1900  assert_is_lvalue_if_expr<A>();
1901  return {pattern_arg(a), pattern_arg(lanes)};
1902 }
1903 
1904 template<typename A, typename B>
1905 HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1906  assert_is_lvalue_if_expr<A>();
1907  return {pattern_arg(a), pattern_arg(lanes)};
1908 }
1909 
1910 template<typename A, typename B>
1911 HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1912  assert_is_lvalue_if_expr<A>();
1913  return {pattern_arg(a), pattern_arg(lanes)};
1914 }
1915 
1916 template<typename A, typename B>
1917 HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1918  assert_is_lvalue_if_expr<A>();
1919  return {pattern_arg(a), pattern_arg(lanes)};
1920 }
1921 
1922 template<typename A, typename B>
1923 HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1924  assert_is_lvalue_if_expr<A>();
1925  return {pattern_arg(a), pattern_arg(lanes)};
1926 }
1927 
// Pattern node for arithmetic negation, represented in IR as Sub(0, a).
// NOTE(review): source lines 1935-1936 are missing from this listing.
1928 template<typename A>
1929 struct NegateOp {
1930  struct pattern_tag {};
1931  A a;
1932 
1933  constexpr static uint32_t binds = bindings<A>::mask;
1934 
1937 
1938  constexpr static bool canonical = A::canonical;
1939 
1940  template<uint32_t bound>
1941  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1942  if (e.node_type != Sub::_node_type) {
1943  return false;
1944  }
1945  const Sub &op = (const Sub &)e;
1946  return (a.template match<bound>(*op.b.get(), state) &&
1947  is_const_zero(op.a));
1948  }
1949 
1950  template<uint32_t bound, typename A2>
1951  HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1952  return a.template match<bound>(unwrap(p.a), state);
1953  }
1954 
1956  Expr make(MatcherState &state, halide_type_t type_hint) const {
1957  Expr ea = a.make(state, type_hint);
1958  Expr z = make_zero(ea.type());
1959  return Sub::make(std::move(z), std::move(ea));
1960  }
1961 
// Foldable whenever the operand is.
1962  constexpr static bool foldable = A::foldable;
1963 
// Fold the operand and negate it in place, wrapping the result to ty.bits
// for integer types.
// NOTE(review): the signature line (source line 1965) and the body of the
// overflow branch (source line 1972; presumably it flags overflow in
// ty.lanes) are missing from this listing — confirm.
1964  template<typename A1 = A>
1966  a.make_folded_const(val, ty, state);
1967  int dead_bits = 64 - ty.bits;
1968  switch (ty.code) {
1969  case halide_type_int:
// For signed types of 32 bits or more, this test appears to detect the
// most negative representable value (non-zero, with every bit below the
// sign bit clear), whose negation would overflow.
1970  if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1971  // Trying to negate the most negative signed int for a no-overflow type.
1973  } else {
1974  // Negate, drop the high bits, and then sign-extend them back
1975  val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
1976  }
1977  break;
1978  case halide_type_uint:
// Unsigned negation wraps modulo 2^ty.bits.
1979  val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
1980  break;
1981  case halide_type_float:
1982  case halide_type_bfloat:
1983  val.u.f64 = -val.u.f64;
1984  break;
1985  default:
1986  // unreachable
1987  ;
1988  }
1989  }
1990 };
1991 
1992 template<typename A>
1993 std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
1994  s << "-" << op.a;
1995  return s;
1996 }
1997 
1998 template<typename A>
1999 HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
2000  assert_is_lvalue_if_expr<A>();
2001  return {pattern_arg(a)};
2002 }
2003 
2004 template<typename A>
2005 HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a)) {
2006  assert_is_lvalue_if_expr<A>();
2007  return IRMatcher::operator-(a);
2008 }
2009 
// Pattern node for a Cast of operand `a` to a fixed type.
// NOTE(review): source lines 2013 and 2018-2019 are missing from this
// listing. Given the `{t, pattern_arg(a)}` initializer in cast(), 2013
// presumably declares the target type `t` ahead of `a` — confirm.
2010 template<typename A>
2011 struct CastOp {
2012  struct pattern_tag {};
2014  A a;
2015 
2016  constexpr static uint32_t binds = bindings<A>::mask;
2017 
2020  constexpr static bool canonical = A::canonical;
2021 
2022  template<uint32_t bound>
2023  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2024  if (e.node_type != Cast::_node_type) {
2025  return false;
2026  }
2027  const Cast &op = (const Cast &)e;
2028  return (e.type == t &&
2029  a.template match<bound>(*op.value.get(), state));
2030  }
2031  template<uint32_t bound, typename A2>
2032  HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2033  return t == op.t && a.template match<bound>(unwrap(op.a), state);
2034  }
2035 
2037  Expr make(MatcherState &state, halide_type_t type_hint) const {
2038  return cast(t, a.make(state, {}));
2039  }
2040 
// Casts are never reported as foldable.
2041  constexpr static bool foldable = false;
2042 };
2043 
2044 template<typename A>
2045 std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2046  s << "cast(" << op.t << ", " << op.a << ")";
2047  return s;
2048 }
2049 
2050 template<typename A>
2051 HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2052  assert_is_lvalue_if_expr<A>();
2053  return {t, pattern_arg(a)};
2054 }
2055 
// Pattern node that constant-folds its operand when the rewrite's output is
// built (fold(...) in a rewrite rule).
// NOTE(review): source lines 2063-2064, 2067, 2069 and 2094 are missing from
// this listing. 2069 presumably declares the scalar `c` used by make(), and
// 2094 is the make_folded_const signature — confirm against the header.
2056 template<typename A>
2057 struct Fold {
2058  struct pattern_tag {};
2059  A a;
2060 
2061  constexpr static uint32_t binds = bindings<A>::mask;
2062 
2065  constexpr static bool canonical = true;
2066 
// Fold `a` to a scalar and emit it as a constant Expr, adopting the type
// hint's code and bits when the hint specifies a width.
// NOTE(review): this is noexcept yet constructs an Expr; an allocation
// failure would therefore terminate. Also, only int -> float is converted
// numerically below; any other code change (e.g. uint -> float) reinterprets
// the stored bits — presumably callers never hit that case; confirm.
2068  Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
2070  halide_type_t ty = type_hint;
2071  a.make_folded_const(c, ty, state);
2072 
2073  // The result of the fold may have an underspecified type
2074  // (e.g. because it's from an int literal). Make the type code
2075  // and bits match the required type, if there is one (we can
2076  // tell from the bits field).
2077  if (type_hint.bits) {
2078  if (((int)ty.code == (int)halide_type_int) &&
2079  ((int)type_hint.code == (int)halide_type_float)) {
2080  int64_t x = c.u.i64;
2081  c.u.f64 = (double)x;
2082  }
2083  ty.code = type_hint.code;
2084  ty.bits = type_hint.bits;
2085  }
2086 
2087  Expr e = make_const_expr(c, ty);
2088  return e;
2089  }
2090 
// Foldable whenever the operand is; folding simply forwards to `a`.
2091  constexpr static bool foldable = A::foldable;
2092 
2093  template<typename A1 = A>
2095  a.make_folded_const(val, ty, state);
2096  }
2097 };
2098 
2099 template<typename A>
2100 HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2101  assert_is_lvalue_if_expr<A>();
2102  return {pattern_arg(a)};
2103 }
2104 
2105 template<typename A>
2106 std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2107  s << "fold(" << op.a << ")";
2108  return s;
2109 }
2110 
// Overflows<A>: boolean predicate over pattern `a`. It folds `a`, then
// yields 1 if the folded lanes field has any MatcherState::special_values_mask
// bit set, else 0. (evaluate_predicate and fuzz_test_rule treat those bits
// as overflow.) The result is a scalar (lanes = 1) 64-bit uint value.
// NOTE(review): source lines 2120-2121 and 2127 are omitted from this
// listing. They are presumably the node-type bounds and the
// make_folded_const signature; confirm against the real header.
2111 template<typename A>
2112 struct Overflows {
2113  struct pattern_tag {};
2114  A a;
2115 
2116  constexpr static uint32_t binds = bindings<A>::mask;
2117 
2118  // This rule is a predicate, so it always evaluates to a boolean,
2119  // which has IRNodeType UIntImm
2122  constexpr static bool canonical = true;
2123 
2124  constexpr static bool foldable = A::foldable;
2125 
2126  template<typename A1 = A>
2128  a.make_folded_const(val, ty, state);
2129  ty.code = halide_type_uint;
2130  ty.bits = 64;
// Read the special-value flags from lanes before lanes is overwritten below.
2131  val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2132  ty.lanes = 1;
2133  }
2134 };
2135 
2136 template<typename A>
2137 HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2138  assert_is_lvalue_if_expr<A>();
2139  return {pattern_arg(a)};
2140 }
2141 
2142 template<typename A>
2143 std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2144  s << "overflows(" << op.a << ")";
2145  return s;
2146 }
2147 
// Overflow: a leaf pattern (it binds no wildcards) for the overflow
// intrinsic, which per the comment below is represented as a Call node.
// NOTE(review): source lines 2154-2155, 2164, 2167, 2169, 2175 and 2178 are
// omitted from this listing. In particular, the final check in match()
// (after the cast to Call) is not visible here, and neither is anything
// make_folded_const does to `ty` after zeroing `val`.
2148 struct Overflow {
2149  struct pattern_tag {};
2150 
2151  constexpr static uint32_t binds = 0;
2152 
2153  // Overflow is an intrinsic, represented as a Call node
2156  constexpr static bool canonical = true;
2157 
// Rejects anything that is not a Call node. The remaining test on `op`
// is on an omitted line.
2158  template<uint32_t bound>
2159  HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2160  if (e.node_type != Call::_node_type) {
2161  return false;
2162  }
2163  const Call &op = (const Call &)e;
2165  }
2166 
// Builds the replacement via make_const_special_expr(type_hint).
2168  Expr make(MatcherState &state, halide_type_t type_hint) const {
2170  return make_const_special_expr(type_hint);
2171  }
2172 
2173  constexpr static bool foldable = true;
2174 
// Folds to a zero payload. Any further update (e.g. of `ty`) is on
// omitted line 2178.
2176  void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
2177  val.u.u64 = 0;
2179  }
2180 };
2181 
2182 inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2183  s << "overflow()";
2184  return s;
2185 }
2186 
// IsConst<A>: boolean predicate. make_folded_const builds `a` as an Expr,
// then yields 1 if ::Halide::Internal::is_const accepts it, else 0, as a
// scalar 64-bit uint. When check_v is set, the check is against the
// value `v`.
// NOTE(review): source lines 2194-2195, 2200 and 2205 are omitted from this
// listing. Line 2200 presumably declares `v` (the is_const(a, value)
// factory passes an int64_t); confirm against the real header.
2187 template<typename A>
2188 struct IsConst {
2189  struct pattern_tag {};
2190 
2191  constexpr static uint32_t binds = bindings<A>::mask;
2192 
2193  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2196  constexpr static bool canonical = true;
2197 
// Member order (a, check_v, then v) is relied on by the brace
// initialization in the is_const() factories.
2198  A a;
2199  bool check_v;
2201 
2202  constexpr static bool foldable = true;
2203 
2204  template<typename A1 = A>
2206  Expr e = a.make(state, {});
2207  ty.code = halide_type_uint;
2208  ty.bits = 64;
2209  ty.lanes = 1;
// The fully-qualified name picks the Expr overload of is_const, not the
// pattern-building is_const() factories in this namespace.
2210  if (check_v) {
2211  val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2212  } else {
2213  val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2214  }
2215  }
2216 };
2217 
2218 template<typename A>
2219 HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2220  assert_is_lvalue_if_expr<A>();
2221  return {pattern_arg(a), false, 0};
2222 }
2223 
2224 template<typename A>
2225 HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2226  assert_is_lvalue_if_expr<A>();
2227  return {pattern_arg(a), true, value};
2228 }
2229 
2230 template<typename A>
2231 std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2232  if (op.check_v) {
2233  s << "is_const(" << op.a << ")";
2234  } else {
2235  s << "is_const(" << op.a << ", " << op.v << ")";
2236  }
2237  return s;
2238 }
2239 
// CanProve<A, Prover>: boolean predicate. It builds `a` as an Expr and runs
// it through prover->mutate(..., nullptr). It yields 1 only if the result
// is the constant one (is_const_one). The result has uint type, 1 bit,
// and the lane count of the simplified condition.
// NOTE(review): `prover` is a non-owning raw pointer that is dereferenced
// without a null check, so it must be valid whenever this is folded.
// NOTE(review): IsConst and Overflows report bits = 64, but this
// predicate reports bits = 1.
// NOTE(review): source lines 2249-2250 and 2256 are omitted from this
// listing. Line 2256 is presumably the make_folded_const signature.
2240 template<typename A, typename Prover>
2241 struct CanProve {
2242  struct pattern_tag {};
2243  A a;
2244  Prover *prover; // An existing simplifying mutator
2245 
2246  constexpr static uint32_t binds = bindings<A>::mask;
2247 
2248  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2251  constexpr static bool canonical = true;
2252 
2253  constexpr static bool foldable = true;
2254 
2255  // Includes a raw call to an inlined make method, so don't inline.
2257  Expr condition = a.make(state, {});
2258  condition = prover->mutate(condition, nullptr);
2259  val.u.u64 = is_const_one(condition);
2260  ty.code = halide_type_uint;
2261  ty.bits = 1;
2262  ty.lanes = condition.type().lanes();
2263  }
2264 };
2265 
2266 template<typename A, typename Prover>
2267 HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2268  assert_is_lvalue_if_expr<A>();
2269  return {pattern_arg(a), p};
2270 }
2271 
2272 template<typename A, typename Prover>
2273 std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2274  s << "can_prove(" << op.a << ")";
2275  return s;
2276 }
2277 
// IsFloat<A>: boolean predicate. It builds `a` and yields t.is_float() for
// the built Expr's type t. The result is typed as a 1-bit uint with t's
// lane count.
// NOTE(review): source lines 2286-2287 and 2292-2293 are omitted from this
// listing. They are presumably the node-type bounds and the
// make_folded_const signature.
2278 template<typename A>
2279 struct IsFloat {
2280  struct pattern_tag {};
2281  A a;
2282 
2283  constexpr static uint32_t binds = bindings<A>::mask;
2284 
2285  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2288  constexpr static bool canonical = true;
2289 
2290  constexpr static bool foldable = true;
2291 
2294  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2295  Type t = a.make(state, {}).type();
2296  val.u.u64 = t.is_float();
2297  ty.code = halide_type_uint;
2298  ty.bits = 1;
2299  ty.lanes = t.lanes();
2300  }
2301 };
2302 
2303 template<typename A>
2304 HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2305  assert_is_lvalue_if_expr<A>();
2306  return {pattern_arg(a)};
2307 }
2308 
2309 template<typename A>
2310 std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2311  s << "is_float(" << op.a << ")";
2312  return s;
2313 }
2314 
// IsInt<A>: boolean predicate. It builds `a` and yields 1 if the built
// Expr's type t is a signed int and either `bits` is 0 (any width) or
// t.bits() == bits. The result is typed as a 1-bit uint with t's lane
// count.
// NOTE(review): source lines 2324-2325 and 2330-2331 are omitted from this
// listing. They are presumably the node-type bounds and the
// make_folded_const signature.
2315 template<typename A>
2316 struct IsInt {
2317  struct pattern_tag {};
2318  A a;
2319  int bits;
2320 
2321  constexpr static uint32_t binds = bindings<A>::mask;
2322 
2323  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2326  constexpr static bool canonical = true;
2327 
2328  constexpr static bool foldable = true;
2329 
2332  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2333  Type t = a.make(state, {}).type();
2334  val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits);
2335  ty.code = halide_type_uint;
2336  ty.bits = 1;
2337  ty.lanes = t.lanes();
2338  }
2339 };
2340 
2341 template<typename A>
2342 HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2343  assert_is_lvalue_if_expr<A>();
2344  return {pattern_arg(a), bits};
2345 }
2346 
2347 template<typename A>
2348 std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2349  s << "is_int(" << op.a;
2350  if (op.bits > 0) {
2351  s << ", " << op.bits;
2352  }
2353  s << ")";
2354  return s;
2355 }
2356 
// IsUInt<A>: boolean predicate. It builds `a` and yields 1 if the built
// Expr's type t is an unsigned int and either `bits` is 0 (any width) or
// t.bits() == bits. The result is typed as a 1-bit uint with t's lane
// count.
// NOTE(review): source lines 2366-2367 and 2372-2373 are omitted from this
// listing. They are presumably the node-type bounds and the
// make_folded_const signature.
2357 template<typename A>
2358 struct IsUInt {
2359  struct pattern_tag {};
2360  A a;
2361  int bits;
2362 
2363  constexpr static uint32_t binds = bindings<A>::mask;
2364 
2365  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2368  constexpr static bool canonical = true;
2369 
2370  constexpr static bool foldable = true;
2371 
2374  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2375  Type t = a.make(state, {}).type();
2376  val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits);
2377  ty.code = halide_type_uint;
2378  ty.bits = 1;
2379  ty.lanes = t.lanes();
2380  }
2381 };
2382 
2383 template<typename A>
2384 HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2385  assert_is_lvalue_if_expr<A>();
2386  return {pattern_arg(a), bits};
2387 }
2388 
2389 template<typename A>
2390 std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2391  s << "is_uint(" << op.a;
2392  if (op.bits > 0) {
2393  s << ", " << op.bits;
2394  }
2395  s << ")";
2396  return s;
2397 }
2398 
// IsScalar<A>: boolean predicate. It builds `a` and yields t.is_scalar()
// for the built Expr's type t. The result is typed as a 1-bit uint with
// t's lane count.
// NOTE(review): source lines 2407-2408 and 2413-2414 are omitted from this
// listing. They are presumably the node-type bounds and the
// make_folded_const signature.
2399 template<typename A>
2400 struct IsScalar {
2401  struct pattern_tag {};
2402  A a;
2403 
2404  constexpr static uint32_t binds = bindings<A>::mask;
2405 
2406  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2409  constexpr static bool canonical = true;
2410 
2411  constexpr static bool foldable = true;
2412 
2415  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2416  Type t = a.make(state, {}).type();
2417  val.u.u64 = t.is_scalar();
2418  ty.code = halide_type_uint;
2419  ty.bits = 1;
2420  ty.lanes = t.lanes();
2421  }
2422 };
2423 
2424 template<typename A>
2425 HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2426  assert_is_lvalue_if_expr<A>();
2427  return {pattern_arg(a)};
2428 }
2429 
2430 template<typename A>
2431 std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2432  s << "is_scalar(" << op.a << ")";
2433  return s;
2434 }
2435 
// IsMaxValue<A>: boolean predicate. It folds `a` and yields 1 if `a` is an
// int/uint constant equal to the largest value of its type, else 0. The
// largest value has all value bits set, excluding the sign bit for signed
// types. Non-integer types always yield 0. The result type is a 1-bit
// uint; lanes are left as produced by the fold.
// NOTE(review): source lines 2444-2445 and 2450-2451 are omitted from this
// listing. They are presumably the node-type bounds and the
// make_folded_const signature.
2436 template<typename A>
2437 struct IsMaxValue {
2438  struct pattern_tag {};
2439  A a;
2440 
2441  constexpr static uint32_t binds = bindings<A>::mask;
2442 
2443  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2446  constexpr static bool canonical = true;
2447 
2448  constexpr static bool foldable = true;
2449 
2452  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
// NOTE(review): the comment above looks copied from the Is* type
// predicates; this method calls make_folded_const, not make.
2453  a.make_folded_const(val, ty, state);
// All-ones shifted down to ty.bits. For signed types one more bit is
// dropped so the sign bit is excluded. This is computed before the type
// check and is unused for non-integer types.
2454  const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2455  if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2456  val.u.u64 = (val.u.u64 == max_bits);
2457  } else {
2458  val.u.u64 = 0;
2459  }
2460  ty.code = halide_type_uint;
2461  ty.bits = 1;
2462  }
2463 };
2464 
2465 template<typename A>
2466 HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2467  assert_is_lvalue_if_expr<A>();
2468  return {pattern_arg(a)};
2469 }
2470 
2471 template<typename A>
2472 std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2473  s << "is_max_value(" << op.a << ")";
2474  return s;
2475 }
2476 
// IsMinValue<A>: boolean predicate. It folds `a` and yields 1 if `a` is the
// smallest value of its integer type (0 for uint, -2^(bits-1) for int),
// else 0. Non-integer types always yield 0. The result type is a 1-bit
// uint; lanes are left as produced by the fold.
// NOTE(review): source lines 2485-2486 and 2491-2492 are omitted from this
// listing. They are presumably the node-type bounds and the
// make_folded_const signature.
2477 template<typename A>
2478 struct IsMinValue {
2479  struct pattern_tag {};
2480  A a;
2481 
2482  constexpr static uint32_t binds = bindings<A>::mask;
2483 
2484  // This rule is a boolean-valued predicate. Bools have type UIntImm.
2487  constexpr static bool canonical = true;
2488 
2489  constexpr static bool foldable = true;
2490 
2493  // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
// NOTE(review): the comment above looks copied from the Is* type
// predicates; this method calls make_folded_const, not make.
2494  a.make_folded_const(val, ty, state);
2495  if (ty.code == halide_type_int) {
// Signed minimum as a 64-bit pattern: every bit at and above the sign bit
// is set. NOTE(review): this assumes folded int payloads are stored
// sign-extended in u.u64; confirm against the leaf patterns.
2496  const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2497  val.u.u64 = (val.u.u64 == min_bits);
2498  } else if (ty.code == halide_type_uint) {
2499  val.u.u64 = (val.u.u64 == 0);
2500  } else {
2501  val.u.u64 = 0;
2502  }
2503  ty.code = halide_type_uint;
2504  ty.bits = 1;
2505  }
2506 };
2507 
2508 template<typename A>
2509 HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2510  assert_is_lvalue_if_expr<A>();
2511  return {pattern_arg(a)};
2512 }
2513 
2514 template<typename A>
2515 std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2516  s << "is_min_value(" << op.a << ")";
2517  return s;
2518 }
2519 
// fuzz_test_rule (foldable overload): selected by enable_if only when both
// sides of the rule are constant-foldable. It runs at most once per
// distinct wildcard type. It first logs the rule in a "validate(...)" form
// for an external python/z3 validator. It then runs 100 trials that bind
// every wildcard slot and constant slot to random scalar values. A trial
// is skipped when the predicate fails or when either side folds to a
// special (overflow) value. Any trial where the two sides disagree is
// reported.
// NOTE(review): source lines 2631 and 2641 are omitted from this listing.
// Line 2631 presumably declares the `val` passed to get_bound_const.
// Line 2641 presumably raises an error; this page cross-references
// internal_error. Confirm against the real header.
2520 // Verify properties of each rewrite rule. Currently just fuzz tests them.
2521 template<typename Before,
2522  typename After,
2523  typename Predicate,
2524  typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2525  std::decay<After>::type::foldable>::type>
2526 HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2527  halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2528 
2529  // We only validate the rules in the scalar case
2530  wildcard_type.lanes = output_type.lanes = 1;
2531 
2532  // Track which types this rule has been tested for before
// NOTE(review): `tested` and `rng` below are function-local statics, so
// their state persists across all calls of this template instantiation.
2533  static std::set<uint32_t> tested;
2534 
2535  if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2536  return;
2537  }
2538 
2539  // Print it in a form where it can be piped into a python/z3 validator
2540  debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2541 
2542  // Substitute some random constants into the before and after
2543  // expressions and see if the rule holds true. This should catch
2544  // silly errors, but not necessarily corner cases.
2545  static std::mt19937_64 rng(0);
2546  MatcherState state;
2547 
2548  Expr exprs[max_wild];
2549 
2550  for (int trials = 0; trials < 100; trials++) {
2551  // We want to test small constants more frequently than
2552  // large ones, otherwise we'll just get coverage of
2553  // overflow rules.
// NOTE(review): masking with bits - 1 gives a shift in [0, bits) only
// when bits is a power of two.
2554  int shift = (int)(rng() & (wildcard_type.bits - 1));
2555 
2556  for (int i = 0; i < max_wild; i++) {
2557  // Bind all the exprs and constants
2558  switch (wildcard_type.code) {
2559  case halide_type_uint: {
2560  // Normalize to the type's range by adding zero
2561  uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2562  state.set_bound_const(i, val, wildcard_type);
2563  val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2564  exprs[i] = make_const(wildcard_type, val);
// NOTE(review): redundant; set_binding is called again for every case
// right after the switch.
2565  state.set_binding(i, *exprs[i].get());
2566  } break;
2567  case halide_type_int: {
2568  int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2569  state.set_bound_const(i, val, wildcard_type);
2570  val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2571  exprs[i] = make_const(wildcard_type, val);
2572  } break;
2573  case halide_type_float:
2574  case halide_type_bfloat: {
2575  // Use a very narrow range of precise floats, so
2576  // that none of the rules a human is likely to
2577  // write have instabilities.
2578  double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2579  state.set_bound_const(i, val, wildcard_type);
2580  val = ((int64_t)(rng() & 15) - 8) / 2.0;
2581  exprs[i] = make_const(wildcard_type, val);
2582  } break;
2583  default:
2584  return; // Don't care about handles
2585  }
2586  state.set_binding(i, *exprs[i].get());
2587  }
2588 
// NOTE(review): val_pred is never used.
2589  halide_scalar_value_t val_pred, val_before, val_after;
2590  halide_type_t type = output_type;
2591  if (!evaluate_predicate(pred, state)) {
2592  continue;
2593  }
2594  before.make_folded_const(val_before, type, state);
2595  uint16_t lanes = type.lanes;
2596  after.make_folded_const(val_after, type, state);
// OR the lanes fields of both sides so that a special-value flag from
// either side skips the trial.
2597  lanes |= type.lanes;
2598 
2599  if (lanes & MatcherState::special_values_mask) {
2600  continue;
2601  }
2602 
2603  bool ok = true;
2604  switch (output_type.code) {
2605  case halide_type_uint:
2606  // Compare normalized representations
2607  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2608  constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2609  break;
2610  case halide_type_int:
2611  ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2612  constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2613  break;
2614  case halide_type_float:
2615  case halide_type_bfloat: {
2616  double error = std::abs(val_before.u.f64 - val_after.u.f64);
2617  // We accept an equal bit pattern (e.g. inf vs inf),
2618  // a small floating point difference, or turning a nan into not-a-nan.
2619  ok &= (error < 0.01 ||
2620  val_before.u.u64 == val_after.u.u64 ||
2621  std::isnan(val_before.u.f64));
2622  break;
2623  }
2624  default:
2625  return;
2626  }
2627 
2628  if (!ok) {
2629  debug(0) << "Fails with values:\n";
2630  for (int i = 0; i < max_wild; i++) {
2632  state.get_bound_const(i, val, wildcard_type);
2633  debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2634  }
2635  for (int i = 0; i < max_wild; i++) {
2636  debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2637  }
2638  debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2639  debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2640  debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
2642  }
2643  }
2644 }
2645 
2646 template<typename Before,
2647  typename After,
2648  typename Predicate,
2649  typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2650  std::decay<After>::type::foldable)>::type>
2651 HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2652  halide_type_t, halide_type_t, int dummy = 0) noexcept {
2653  // We can't verify rewrite rules that can't be constant-folded.
2654 }
2655 
2657 bool evaluate_predicate(bool x, MatcherState &) noexcept {
2658  return x;
2659 }
2660 
// Evaluates a pattern predicate by constant-folding it as a bool. Returns
// true only if the folded value is nonzero and no special-value (overflow)
// flag is set in the folded lanes field.
// NOTE(review): source lines 2663-2664 are omitted from this listing. They
// presumably hold the function signature and the declaration of `c`.
2661 template<typename Pattern,
2662  typename = typename enable_if_pattern<Pattern>::type>
2665  halide_type_t ty = halide_type_of<bool>();
2666  p.make_folded_const(c, ty, state);
2667  // Overflow counts as a failed predicate
2668  return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2669 }
2670 
// #defines for testing. Each one defaults to 0 (off). Because of the
// #ifndef guards, each can be overridden from the build line (e.g.
// -DHALIDE_FUZZ_TEST_RULES=1) instead of by editing this header.

// Print all successful (HALIDE_DEBUG_MATCHED_RULES) or failed
// (HALIDE_DEBUG_UNMATCHED_RULES) matches.
#ifndef HALIDE_DEBUG_MATCHED_RULES
#define HALIDE_DEBUG_MATCHED_RULES 0
#endif
#ifndef HALIDE_DEBUG_UNMATCHED_RULES
#define HALIDE_DEBUG_UNMATCHED_RULES 0
#endif

// Set to 1 to fuzz test every rewrite passed to operator(). This checks
// that the input and the output have the same value for lots of random
// values of the wildcards. Run correctness_simplify with this on.
#ifndef HALIDE_FUZZ_TEST_RULES
#define HALIDE_FUZZ_TEST_RULES 0
#endif
2682 
// Rewriter<Instance>: holds an instance (an expression, or a pattern with
// concrete leaves) and tries rewrite rules against it via operator(). On a
// match, each operator() overload stores the replacement in `result` and
// returns true. Otherwise it returns false.
// NOTE(review): source lines 2686-2688, 2691-2692 and 2697 are omitted from
// this listing. They presumably declare the `state`, `result`, `output_type`
// and `wildcard_type` members used below, the constructor signature, and
// the build_replacement signature.
2683 template<typename Instance>
2684 struct Rewriter {
2685  Instance instance;
// NOTE(review): `validate` is not referenced anywhere in the visible code.
2689  bool validate;
2690 
2693  : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
2694  }
2695 
// Instantiates the replacement pattern with the current bindings, using
// output_type as the type hint, and stores it in `result`.
2696  template<typename After>
2698  result = after.make(state, output_type);
2699  }
2700 
// Rule: pattern -> pattern. When HALIDE_FUZZ_TEST_RULES is on, the rule is
// first fuzz-tested with the constant predicate `true`.
2701  template<typename Before,
2702  typename After,
2703  typename = typename enable_if_pattern<Before>::type,
2704  typename = typename enable_if_pattern<After>::type>
2705  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
2706  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2707  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2708  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2709 #if HALIDE_FUZZ_TEST_RULES
2710  fuzz_test_rule(before, after, true, wildcard_type, output_type);
2711 #endif
2712  if (before.template match<0>(unwrap(instance), state)) {
2713  build_replacement(after);
2714 #if HALIDE_DEBUG_MATCHED_RULES
2715  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2716 #endif
2717  return true;
2718  } else {
2719 #if HALIDE_DEBUG_UNMATCHED_RULES
2720  debug(0) << instance << " does not match " << before << "\n";
2721 #endif
2722  return false;
2723  }
2724  }
2725 
// Rule: pattern -> concrete Expr. The replacement is copied as-is. This
// overload is never fuzz-tested.
2726  template<typename Before,
2727  typename = typename enable_if_pattern<Before>::type>
2728  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
2729  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2730  if (before.template match<0>(unwrap(instance), state)) {
2731  result = after;
2732 #if HALIDE_DEBUG_MATCHED_RULES
2733  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2734 #endif
2735  return true;
2736  } else {
2737 #if HALIDE_DEBUG_UNMATCHED_RULES
2738  debug(0) << instance << " does not match " << before << "\n";
2739 #endif
2740  return false;
2741  }
2742  }
2743 
// Rule: pattern -> integer literal, materialized with
// make_const(output_type, after).
2744  template<typename Before,
2745  typename = typename enable_if_pattern<Before>::type>
2746  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
2747  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2748 #if HALIDE_FUZZ_TEST_RULES
2749  fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
2750 #endif
2751  if (before.template match<0>(unwrap(instance), state)) {
2752  result = make_const(output_type, after);
2753 #if HALIDE_DEBUG_MATCHED_RULES
2754  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2755 #endif
2756  return true;
2757  } else {
2758 #if HALIDE_DEBUG_UNMATCHED_RULES
2759  debug(0) << instance << " does not match " << before << "\n";
2760 #endif
2761  return false;
2762  }
2763  }
2764 
// Conditional rule: pattern -> pattern when `pred` holds. The predicate is
// evaluated only after a successful match, because && short-circuits.
2765  template<typename Before,
2766  typename After,
2767  typename Predicate,
2768  typename = typename enable_if_pattern<Before>::type,
2769  typename = typename enable_if_pattern<After>::type,
2770  typename = typename enable_if_pattern<Predicate>::type>
2771  HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
2772  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2773  static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2774  static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2775  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2776  static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2777 
2778 #if HALIDE_FUZZ_TEST_RULES
2779  fuzz_test_rule(before, after, pred, wildcard_type, output_type);
2780 #endif
2781  if (before.template match<0>(unwrap(instance), state) &&
2782  evaluate_predicate(pred, state)) {
2783  build_replacement(after);
2784 #if HALIDE_DEBUG_MATCHED_RULES
2785  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2786 #endif
2787  return true;
2788  } else {
2789 #if HALIDE_DEBUG_UNMATCHED_RULES
2790  debug(0) << instance << " does not match " << before << "\n";
2791 #endif
2792  return false;
2793  }
2794  }
2795 
// Conditional rule: pattern -> concrete Expr when `pred` holds. This
// overload is not fuzz-tested.
2796  template<typename Before,
2797  typename Predicate,
2798  typename = typename enable_if_pattern<Before>::type,
2799  typename = typename enable_if_pattern<Predicate>::type>
2800  HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
2801  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2802  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2803 
2804  if (before.template match<0>(unwrap(instance), state) &&
2805  evaluate_predicate(pred, state)) {
2806  result = after;
2807 #if HALIDE_DEBUG_MATCHED_RULES
2808  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2809 #endif
2810  return true;
2811  } else {
2812 #if HALIDE_DEBUG_UNMATCHED_RULES
2813  debug(0) << instance << " does not match " << before << "\n";
2814 #endif
2815  return false;
2816  }
2817  }
2818 
// Conditional rule: pattern -> integer literal when `pred` holds.
2819  template<typename Before,
2820  typename Predicate,
2821  typename = typename enable_if_pattern<Before>::type,
2822  typename = typename enable_if_pattern<Predicate>::type>
2823  HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
2824  static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2825  static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2826 #if HALIDE_FUZZ_TEST_RULES
2827  fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
2828 #endif
2829  if (before.template match<0>(unwrap(instance), state) &&
2830  evaluate_predicate(pred, state)) {
2831  result = make_const(output_type, after);
2832 #if HALIDE_DEBUG_MATCHED_RULES
2833  debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2834 #endif
2835  return true;
2836  } else {
2837 #if HALIDE_DEBUG_UNMATCHED_RULES
2838  debug(0) << instance << " does not match " << before << "\n";
2839 #endif
2840  return false;
2841  }
2842  }
2843 };
2844 
2845 /** Construct a rewriter for the given instance, which may be a pattern
2846  * with concrete expressions as leaves, or just an expression. The
2847  * optional wildcard_type argument is a hint as to what the type of the
2848  * wildcards is likely to be. If omitted, it defaults to the output type
2849  * (which, for an Expr instance, is the Expr's own type). The wildcards
2850  * are not required to be of this type, but the rule will only be tested
2851  * for wildcards of that type when testing is enabled.
2852  *
2853  * The rewriter can be used to check whether the instance matches one of
2854  * some number of patterns and, if so, rewrite it into another form
2855  * using its operator() method. See Simplify.cpp for many examples of
2856  * its usage.
2857  *
2858  * Important: Any Exprs in patterns are captured by reference, not by
2859  * value, so ensure they outlive the rewriter.
2860  */
2861 // @{
2862 template<typename Instance,
2863  typename = typename enable_if_pattern<Instance>::type>
2864 HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
2865  return {pattern_arg(instance), output_type, wildcard_type};
2866 }
2867 
2868 template<typename Instance,
2869  typename = typename enable_if_pattern<Instance>::type>
2870 HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
2871  return {pattern_arg(instance), output_type, output_type};
2872 }
2873 
2875 auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
2876  return {pattern_arg(e), e.type(), wildcard_type};
2877 }
2878 
2880 auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
2881  return {pattern_arg(e), e.type(), e.type()};
2882 }
2883 // @}
2884 
2885 } // namespace IRMatcher
2886 
2887 } // namespace Internal
2888 } // namespace Halide
2889 
2890 #endif
#define internal_error
Definition: Errors.h:23
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
Definition: HalideRuntime.h:39
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:38
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Halide expressions.
For optional debugging during codegen, use the debug class as follows:
Definition: Debug.h:49
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition: IRMatch.h:229
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1571
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1563
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leaves, or just an expression.
Definition: IRMatch.h:2864
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition: IRMatch.h:579
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition: IRMatch.h:1287
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:1628
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1088
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2342
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition: IRMatch.h:2657
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1044
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition: IRMatch.h:1262
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition: IRMatch.h:2005
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1182
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:933
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2466
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition: IRMatch.h:1313
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition: IRMatch.h:1917
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition: IRMatch.h:1162
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2219
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
Definition: IRMatch.h:1522
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1192
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:999
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1555
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1575
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition: IRMatch.h:940
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition: IRMatch.h:1039
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1539
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition: IRMatch.h:1006
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1110
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1841
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1032
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1535
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1073
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1318
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition: IRMatch.h:571
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1157
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition: IRMatch.h:2051
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition: IRMatch.h:2137
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1527
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition: IRMatch.h:588
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1059
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:980
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition: IRMatch.h:2425
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition: IRMatch.h:2100
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition: IRMatch.h:1634
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1547
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1117
constexpr bool and_reduce()
Definition: IRMatch.h:1342
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1282
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1531
constexpr int max_wild
Definition: IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1257
HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept
Definition: IRMatch.h:195
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition: IRMatch.h:2304
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1207
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1132
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1308
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition: IRMatch.h:1923
constexpr bool commutative(IRNodeType t)
Definition: IRMatch.h:627
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition: IRMatch.h:973
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition: IRMatch.h:1911
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition: IRMatch.h:1777
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition: IRMatch.h:1704
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition: IRMatch.h:2509
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1095
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition: IRMatch.h:2526
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1167
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1551
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1543
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1013
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition: IRMatch.h:2384
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1579
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1567
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1217
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:966
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition: IRMatch.h:1187
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition: IRMatch.h:1137
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition: IRMatch.h:2225
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1142
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition: IRMatch.h:1905
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition: IRMatch.h:1899
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1292
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition: IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition: IRMatch.h:1333
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition: IRMatch.h:1583
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1242
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition: IRMatch.h:149
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition: IRMatch.h:1212
constexpr int const_min(int a, int b)
Definition: IRMatch.h:1352
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition: IRMatch.h:1267
auto rounding_halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1559
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1066
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition: IRMatch.h:1232
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition: IRMatch.h:947
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition: IRMatch.h:2267
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition: IRMatch.h:1237
T div_imp(T a, T b)
Definition: IROperator.h:260
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
constexpr IRNodeType StrongestExprNodeType
Definition: Expr.h:79
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition: IROperator.h:239
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have lanes lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition: Expr.h:25
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
@ C
No name mangling.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition: Expr.h:256
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition: Expr.h:320
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:314
The sum of two expressions.
Definition: IR.h:38
Logical and - are both expressions true.
Definition: IR.h:157
A base class for expression nodes.
Definition: Expr.h:141
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:241
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition: IR.h:247
A function call.
Definition: IR.h:466
@ signed_integer_overflow
Definition: IR.h:556
@ rounding_mul_shift_right
Definition: IR.h:547
bool is_intrinsic() const
Definition: IR.h:649
static const IRNodeType _node_type
Definition: IR.h:694
The actual IR nodes begin here.
Definition: IR.h:29
static const IRNodeType _node_type
Definition: IR.h:34
The ratio of two expressions.
Definition: IR.h:65
Is the first expression equal to the second.
Definition: IR.h:103
Floating point constants.
Definition: Expr.h:234
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Definition: IR.h:148
Is the first expression greater than the second.
Definition: IR.h:139
constexpr static uint32_t binds
Definition: IRMatch.h:645
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:648
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:676
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:657
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:647
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:719
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:667
constexpr static bool canonical
Definition: IRMatch.h:653
constexpr static bool foldable
Definition: IRMatch.h:673
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1743
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1737
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1725
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1719
constexpr static uint32_t binds
Definition: IRMatch.h:1717
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1760
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1720
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2256
constexpr static bool foldable
Definition: IRMatch.h:2253
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2249
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2250
constexpr static uint32_t binds
Definition: IRMatch.h:2246
constexpr static bool canonical
Definition: IRMatch.h:2251
constexpr static bool canonical
Definition: IRMatch.h:2020
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2019
constexpr static bool foldable
Definition: IRMatch.h:2041
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2023
constexpr static uint32_t binds
Definition: IRMatch.h:2016
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2018
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:2032
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2037
constexpr static bool canonical
Definition: IRMatch.h:760
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:820
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:758
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:759
constexpr static bool foldable
Definition: IRMatch.h:783
constexpr static uint32_t binds
Definition: IRMatch.h:756
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:767
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:786
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:777
constexpr static bool foldable
Definition: IRMatch.h:2091
constexpr static uint32_t binds
Definition: IRMatch.h:2061
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2063
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2064
constexpr static bool canonical
Definition: IRMatch.h:2065
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition: IRMatch.h:2068
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2094
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:516
constexpr static bool canonical
Definition: IRMatch.h:508
constexpr static uint32_t binds
Definition: IRMatch.h:504
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition: IRMatch.h:511
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition: IRMatch.h:539
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:551
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:506
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:507
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:544
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition: IRMatch.h:534
constexpr static bool foldable
Definition: IRMatch.h:548
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1378
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1411
constexpr static bool canonical
Definition: IRMatch.h:1366
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition: IRMatch.h:1406
constexpr static bool foldable
Definition: IRMatch.h:1463
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1465
static constexpr uint32_t binds
Definition: IRMatch.h:1362
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition: IRMatch.h:1371
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition: IRMatch.h:1393
std::tuple< Args... > args
Definition: IRMatch.h:1360
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1383
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition: IRMatch.h:1402
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
Definition: IRMatch.h:1508
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1365
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1364
constexpr static bool canonical
Definition: IRMatch.h:2196
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2205
constexpr static bool foldable
Definition: IRMatch.h:2202
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2195
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2194
constexpr static uint32_t binds
Definition: IRMatch.h:2191
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2286
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2293
constexpr static bool canonical
Definition: IRMatch.h:2288
constexpr static uint32_t binds
Definition: IRMatch.h:2283
constexpr static bool foldable
Definition: IRMatch.h:2290
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2287
constexpr static uint32_t binds
Definition: IRMatch.h:2321
constexpr static bool foldable
Definition: IRMatch.h:2328
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2324
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2331
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2325
constexpr static bool canonical
Definition: IRMatch.h:2326
constexpr static bool canonical
Definition: IRMatch.h:2446
constexpr static bool foldable
Definition: IRMatch.h:2448
constexpr static uint32_t binds
Definition: IRMatch.h:2441
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2444
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2445
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2451
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2485
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2486
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2492
constexpr static bool canonical
Definition: IRMatch.h:2487
constexpr static uint32_t binds
Definition: IRMatch.h:2482
constexpr static bool foldable
Definition: IRMatch.h:2489
constexpr static bool foldable
Definition: IRMatch.h:2411
constexpr static bool canonical
Definition: IRMatch.h:2409
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2414
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2408
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2407
constexpr static uint32_t binds
Definition: IRMatch.h:2404
constexpr static bool canonical
Definition: IRMatch.h:2368
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:2373
constexpr static uint32_t binds
Definition: IRMatch.h:2363
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2366
constexpr static bool foldable
Definition: IRMatch.h:2370
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2367
To save stack space, the matcher objects are largely stateless and immutable.
Definition: IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition: IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition: IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition: IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition: IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition: IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition: IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition: IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition: IRMatch.h:134
halide_scalar_value_t bound_const[max_wild]
Definition: IRMatch.h:84
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition: IRMatch.h:98
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition: IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition: IRMatch.h:87
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1935
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1936
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1941
constexpr static uint32_t binds
Definition: IRMatch.h:1933
constexpr static bool canonical
Definition: IRMatch.h:1938
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1956
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition: IRMatch.h:1951
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1965
constexpr static bool foldable
Definition: IRMatch.h:1962
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1594
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1599
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1608
constexpr static uint32_t binds
Definition: IRMatch.h:1592
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1595
constexpr static bool foldable
Definition: IRMatch.h:1617
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1613
constexpr static bool canonical
Definition: IRMatch.h:1596
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1620
constexpr static bool canonical
Definition: IRMatch.h:2156
constexpr static bool foldable
Definition: IRMatch.h:2173
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:2159
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:2168
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2155
constexpr static uint32_t binds
Definition: IRMatch.h:2151
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2176
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2154
constexpr static bool foldable
Definition: IRMatch.h:2124
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:2127
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:2121
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:2120
constexpr static uint32_t binds
Definition: IRMatch.h:2116
constexpr static bool canonical
Definition: IRMatch.h:2122
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1819
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1792
constexpr static bool canonical
Definition: IRMatch.h:1794
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1791
constexpr static bool foldable
Definition: IRMatch.h:1831
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1812
constexpr static uint32_t binds
Definition: IRMatch.h:1789
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1797
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition: IRMatch.h:2697
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition: IRMatch.h:2771
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition: IRMatch.h:2746
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition: IRMatch.h:2692
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition: IRMatch.h:2800
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition: IRMatch.h:2728
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition: IRMatch.h:2823
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition: IRMatch.h:2705
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1655
constexpr static bool canonical
Definition: IRMatch.h:1657
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:1684
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1654
constexpr static uint32_t binds
Definition: IRMatch.h:1652
constexpr static bool foldable
Definition: IRMatch.h:1681
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition: IRMatch.h:1670
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1660
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1677
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:210
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:217
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:222
constexpr static uint32_t binds
Definition: IRMatch.h:207
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:211
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition: IRMatch.h:1874
constexpr static uint32_t binds
Definition: IRMatch.h:1854
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:1857
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:1861
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:1856
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:1881
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:364
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:360
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:385
constexpr static uint32_t binds
Definition: IRMatch.h:357
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:359
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:395
constexpr static uint32_t binds
Definition: IRMatch.h:411
constexpr static bool foldable
Definition: IRMatch.h:450
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:443
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:414
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:418
constexpr static bool canonical
Definition: IRMatch.h:415
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:453
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:413
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition: IRMatch.h:437
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:240
constexpr static uint32_t binds
Definition: IRMatch.h:238
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:279
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition: IRMatch.h:289
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:245
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition: IRMatch.h:266
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:241
constexpr static uint32_t binds
Definition: IRMatch.h:304
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:307
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:306
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:311
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition: IRMatch.h:342
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:332
constexpr static IRNodeType max_node_type
Definition: IRMatch.h:472
constexpr static bool foldable
Definition: IRMatch.h:489
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition: IRMatch.h:485
constexpr static bool canonical
Definition: IRMatch.h:473
constexpr static IRNodeType min_node_type
Definition: IRMatch.h:471
constexpr static uint32_t binds
Definition: IRMatch.h:469
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition: IRMatch.h:476
constexpr static uint32_t mask
Definition: IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:111
Integer constants.
Definition: Expr.h:216
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Definition: IR.h:130
Is the first expression less than the second.
Definition: IR.h:121
The greater of two values.
Definition: IR.h:94
The lesser of two values.
Definition: IR.h:85
The remainder of a / b.
Definition: IR.h:76
The product of two expressions.
Definition: IR.h:56
Is the first expression not equal to the second.
Definition: IR.h:112
Logical not - true if the expression false.
Definition: IR.h:175
static Expr make(Expr a)
Logical or - is at least one of the expressions true.
Definition: IR.h:166
A linear ramp vector node.
Definition: IR.h:229
static const IRNodeType _node_type
Definition: IR.h:235
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition: IR.h:186
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition: IR.h:191
The difference of two expressions.
Definition: IR.h:47
static const IRNodeType _node_type
Definition: IR.h:52
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition: Expr.h:225
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associative binary operator.
Definition: IR.h:888
static const IRNodeType _node_type
Definition: IR.h:907
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition: Type.h:266
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition: Type.h:414
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition: Type.h:334
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition: Type.h:420
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition: Type.h:328
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
Definition: Type.h:389
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition: Type.h:396
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition: Type.h:402
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filter argument.
union halide_scalar_value_t::@3 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.