OR-Tools  8.2
expressions.cc
Go to the documentation of this file.
1 // Copyright 2010-2018 Google LLC
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 
14 #include <algorithm>
15 #include <cmath>
16 #include <memory>
17 #include <string>
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_format.h"
26 #include "ortools/base/logging.h"
27 #include "ortools/base/map_util.h"
28 #include "ortools/base/mathutil.h"
29 #include "ortools/base/stl_util.h"
32 #include "ortools/util/bitset.h"
35 
36 ABSL_FLAG(bool, cp_disable_expression_optimization, false,
37  "Disable special optimization when creating expressions.");
38 ABSL_FLAG(bool, cp_share_int_consts, true,
39  "Share IntConst's with the same value.");
40 
41 #if defined(_MSC_VER)
42 #pragma warning(disable : 4351 4355)
43 #endif
44 
45 namespace operations_research {
46 
47 // ---------- IntExpr ----------
48 
49 IntVar* IntExpr::VarWithName(const std::string& name) {
50  IntVar* const var = Var();
51  var->set_name(name);
52  return var;
53 }
54 
55 // ---------- IntVar ----------
56 
57 IntVar::IntVar(Solver* const s) : IntExpr(s), index_(s->GetNewIntVarIndex()) {}
58 
59 IntVar::IntVar(Solver* const s, const std::string& name)
60  : IntExpr(s), index_(s->GetNewIntVarIndex()) {
61  set_name(name);
62 }
63 
64 // ----- Boolean variable -----
65 
67 
69  if (m <= 0) return;
70  if (m > 1) solver()->Fail();
71  SetValue(1);
72 }
73 
75  if (m >= 1) return;
76  if (m < 0) solver()->Fail();
77  SetValue(0);
78 }
79 
81  if (mi > 1 || ma < 0 || mi > ma) {
82  solver()->Fail();
83  }
84  if (mi == 1) {
85  SetValue(1);
86  } else if (ma == 0) {
87  SetValue(0);
88  }
89 }
90 
93  if (v == 0) {
94  SetValue(1);
95  } else if (v == 1) {
96  SetValue(0);
97  }
98  } else if (v == value_) {
99  solver()->Fail();
100  }
101 }
102 
104  if (u < l) return;
105  if (l <= 0 && u >= 1) {
106  solver()->Fail();
107  } else if (l == 1) {
108  SetValue(0);
109  } else if (u == 0) {
110  SetValue(1);
111  }
112 }
113 
116  if (d->priority() == Solver::DELAYED_PRIORITY) {
117  delayed_bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
118  } else {
119  bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
120  }
121  }
122 }
123 
125  return (1 + (value_ == kUnboundBooleanVarValue));
126 }
127 
129  return ((v == 0 && value_ != 1) || (v == 1 && value_ != 0));
130 }
131 
133  if (constant > 1 || constant < 0) {
134  return solver()->MakeIntConst(0);
135  }
136  if (constant == 1) {
137  return this;
138  } else { // constant == 0.
139  return solver()->MakeDifference(1, this)->Var();
140  }
141 }
142 
144  if (constant > 1 || constant < 0) {
145  return solver()->MakeIntConst(1);
146  }
147  if (constant == 1) {
148  return solver()->MakeDifference(1, this)->Var();
149  } else { // constant == 0.
150  return this;
151  }
152 }
153 
155  if (constant > 1) {
156  return solver()->MakeIntConst(0);
157  } else if (constant <= 0) {
158  return solver()->MakeIntConst(1);
159  } else {
160  return this;
161  }
162 }
163 
165  if (constant < 0) {
166  return solver()->MakeIntConst(0);
167  } else if (constant >= 1) {
168  return solver()->MakeIntConst(1);
169  } else {
170  return IsEqual(0);
171  }
172 }
173 
174 std::string BooleanVar::DebugString() const {
175  std::string out;
176  const std::string& var_name = name();
177  if (!var_name.empty()) {
178  out = var_name + "(";
179  } else {
180  out = "BooleanVar(";
181  }
182  switch (value_) {
183  case 0:
184  out += "0";
185  break;
186  case 1:
187  out += "1";
188  break;
190  out += "0 .. 1";
191  break;
192  }
193  out += ")";
194  return out;
195 }
196 
197 namespace {
198 // ---------- Subclasses of IntVar ----------
199 
200 // ----- Domain Int Var: base class for variables -----
201 // It Contains bounds and a bitset representation of possible values.
202 class DomainIntVar : public IntVar {
203  public:
204  // Utility classes
205  class BitSetIterator : public BaseObject {
206  public:
207  BitSetIterator(uint64* const bitset, int64 omin)
208  : bitset_(bitset), omin_(omin), max_(kint64min), current_(kint64max) {}
209 
210  ~BitSetIterator() override {}
211 
212  void Init(int64 min, int64 max) {
213  max_ = max;
214  current_ = min;
215  }
216 
217  bool Ok() const { return current_ <= max_; }
218 
219  int64 Value() const { return current_; }
220 
221  void Next() {
222  if (++current_ <= max_) {
224  bitset_, current_ - omin_, max_ - omin_) +
225  omin_;
226  }
227  }
228 
229  std::string DebugString() const override { return "BitSetIterator"; }
230 
231  private:
232  uint64* const bitset_;
233  const int64 omin_;
234  int64 max_;
235  int64 current_;
236  };
237 
238  class BitSet : public BaseObject {
239  public:
240  explicit BitSet(Solver* const s) : solver_(s), holes_stamp_(0) {}
241  ~BitSet() override {}
242 
243  virtual int64 ComputeNewMin(int64 nmin, int64 cmin, int64 cmax) = 0;
244  virtual int64 ComputeNewMax(int64 nmax, int64 cmin, int64 cmax) = 0;
245  virtual bool Contains(int64 val) const = 0;
246  virtual bool SetValue(int64 val) = 0;
247  virtual bool RemoveValue(int64 val) = 0;
248  virtual uint64 Size() const = 0;
249  virtual void DelayRemoveValue(int64 val) = 0;
250  virtual void ApplyRemovedValues(DomainIntVar* var) = 0;
251  virtual void ClearRemovedValues() = 0;
252  virtual std::string pretty_DebugString(int64 min, int64 max) const = 0;
253  virtual BitSetIterator* MakeIterator() = 0;
254 
255  void InitHoles() {
256  const uint64 current_stamp = solver_->stamp();
257  if (holes_stamp_ < current_stamp) {
258  holes_.clear();
259  holes_stamp_ = current_stamp;
260  }
261  }
262 
263  virtual void ClearHoles() { holes_.clear(); }
264 
265  const std::vector<int64>& Holes() { return holes_; }
266 
267  void AddHole(int64 value) { holes_.push_back(value); }
268 
269  int NumHoles() const {
270  return holes_stamp_ < solver_->stamp() ? 0 : holes_.size();
271  }
272 
273  protected:
274  Solver* const solver_;
275 
276  private:
277  std::vector<int64> holes_;
278  uint64 holes_stamp_;
279  };
280 
281  class QueueHandler : public Demon {
282  public:
283  explicit QueueHandler(DomainIntVar* const var) : var_(var) {}
284  ~QueueHandler() override {}
285  void Run(Solver* const s) override {
286  s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
287  var_->Process();
288  s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
289  }
290  Solver::DemonPriority priority() const override {
291  return Solver::VAR_PRIORITY;
292  }
293  std::string DebugString() const override {
294  return absl::StrFormat("Handler(%s)", var_->DebugString());
295  }
296 
297  private:
298  DomainIntVar* const var_;
299  };
300 
301  // Bounds and Value watchers
302 
303  // This class stores the watchers variables attached to values. It is
304  // reversible and it helps maintaining the set of 'active' watchers
305  // (variables not bound to a single value).
306  template <class T>
307  class RevIntPtrMap {
308  public:
309  RevIntPtrMap(Solver* const solver, int64 rmin, int64 rmax)
310  : solver_(solver), range_min_(rmin), start_(0) {}
311 
312  ~RevIntPtrMap() {}
313 
314  bool Empty() const { return start_.Value() == elements_.size(); }
315 
316  void SortActive() { std::sort(elements_.begin(), elements_.end()); }
317 
318  // Access with value API.
319 
320  // Add the pointer to the map attached to the given value.
321  void UnsafeRevInsert(int64 value, T* elem) {
322  elements_.push_back(std::make_pair(value, elem));
323  if (solver_->state() != Solver::OUTSIDE_SEARCH) {
324  solver_->AddBacktrackAction(
325  [this, value](Solver* s) { Uninsert(value); }, false);
326  }
327  }
328 
329  T* FindPtrOrNull(int64 value, int* position) {
330  for (int pos = start_.Value(); pos < elements_.size(); ++pos) {
331  if (elements_[pos].first == value) {
332  if (position != nullptr) *position = pos;
333  return At(pos).second;
334  }
335  }
336  return nullptr;
337  }
338 
339  // Access map through the underlying vector.
340  void RemoveAt(int position) {
341  const int start = start_.Value();
342  DCHECK_GE(position, start);
343  DCHECK_LT(position, elements_.size());
344  if (position > start) {
345  // Swap the current element with the one at the start position, and
346  // increase start.
347  const std::pair<int64, T*> copy = elements_[start];
348  elements_[start] = elements_[position];
349  elements_[position] = copy;
350  }
351  start_.Incr(solver_);
352  }
353 
354  const std::pair<int64, T*>& At(int position) const {
355  DCHECK_GE(position, start_.Value());
356  DCHECK_LT(position, elements_.size());
357  return elements_[position];
358  }
359 
360  void RemoveAll() { start_.SetValue(solver_, elements_.size()); }
361 
362  int start() const { return start_.Value(); }
363  int end() const { return elements_.size(); }
364  // Number of active elements.
365  int Size() const { return elements_.size() - start_.Value(); }
366 
367  // Removes the object permanently from the map.
368  void Uninsert(int64 value) {
369  for (int pos = 0; pos < elements_.size(); ++pos) {
370  if (elements_[pos].first == value) {
371  DCHECK_GE(pos, start_.Value());
372  const int last = elements_.size() - 1;
373  if (pos != last) { // Swap the current with the last.
374  elements_[pos] = elements_.back();
375  }
376  elements_.pop_back();
377  return;
378  }
379  }
380  LOG(FATAL) << "The element should have been removed";
381  }
382 
383  private:
384  Solver* const solver_;
385  const int64 range_min_;
386  NumericalRev<int> start_;
387  std::vector<std::pair<int64, T*>> elements_;
388  };
389 
390  // Base class for value watchers
391  class BaseValueWatcher : public Constraint {
392  public:
393  explicit BaseValueWatcher(Solver* const solver) : Constraint(solver) {}
394 
395  ~BaseValueWatcher() override {}
396 
397  virtual IntVar* GetOrMakeValueWatcher(int64 value) = 0;
398 
399  virtual void SetValueWatcher(IntVar* const boolvar, int64 value) = 0;
400  };
401 
402  // This class monitors the domain of the variable and updates the
403  // IsEqual/IsDifferent boolean variables accordingly.
404  class ValueWatcher : public BaseValueWatcher {
405  public:
406  class WatchDemon : public Demon {
407  public:
408  WatchDemon(ValueWatcher* const watcher, int64 value, IntVar* var)
409  : value_watcher_(watcher), value_(value), var_(var) {}
410  ~WatchDemon() override {}
411 
412  void Run(Solver* const solver) override {
413  value_watcher_->ProcessValueWatcher(value_, var_);
414  }
415 
416  private:
417  ValueWatcher* const value_watcher_;
418  const int64 value_;
419  IntVar* const var_;
420  };
421 
422  class VarDemon : public Demon {
423  public:
424  explicit VarDemon(ValueWatcher* const watcher)
425  : value_watcher_(watcher) {}
426 
427  ~VarDemon() override {}
428 
429  void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
430 
431  private:
432  ValueWatcher* const value_watcher_;
433  };
434 
435  ValueWatcher(Solver* const solver, DomainIntVar* const variable)
436  : BaseValueWatcher(solver),
437  variable_(variable),
438  hole_iterator_(variable_->MakeHoleIterator(true)),
439  var_demon_(nullptr),
440  watchers_(solver, variable->Min(), variable->Max()) {}
441 
442  ~ValueWatcher() override {}
443 
444  IntVar* GetOrMakeValueWatcher(int64 value) override {
445  IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
446  if (watcher != nullptr) return watcher;
447  if (variable_->Contains(value)) {
448  if (variable_->Bound()) {
449  return solver()->MakeIntConst(1);
450  } else {
451  const std::string vname = variable_->HasName()
452  ? variable_->name()
453  : variable_->DebugString();
454  const std::string bname =
455  absl::StrFormat("Watch<%s == %d>", vname, value);
456  IntVar* const boolvar = solver()->MakeBoolVar(bname);
457  watchers_.UnsafeRevInsert(value, boolvar);
458  if (posted_.Switched()) {
459  boolvar->WhenBound(
460  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
461  var_demon_->desinhibit(solver());
462  }
463  return boolvar;
464  }
465  } else {
466  return variable_->solver()->MakeIntConst(0);
467  }
468  }
469 
470  void SetValueWatcher(IntVar* const boolvar, int64 value) override {
471  CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
472  if (!boolvar->Bound()) {
473  watchers_.UnsafeRevInsert(value, boolvar);
474  if (posted_.Switched() && !boolvar->Bound()) {
475  boolvar->WhenBound(
476  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
477  var_demon_->desinhibit(solver());
478  }
479  }
480  }
481 
482  void Post() override {
483  var_demon_ = solver()->RevAlloc(new VarDemon(this));
484  variable_->WhenDomain(var_demon_);
485  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
486  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
487  const int64 value = w.first;
488  IntVar* const boolvar = w.second;
489  if (!boolvar->Bound() && variable_->Contains(value)) {
490  boolvar->WhenBound(
491  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
492  }
493  }
494  posted_.Switch(solver());
495  }
496 
497  void InitialPropagate() override {
498  if (variable_->Bound()) {
499  VariableBound();
500  } else {
501  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
502  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
503  const int64 value = w.first;
504  IntVar* const boolvar = w.second;
505  if (!variable_->Contains(value)) {
506  boolvar->SetValue(0);
507  watchers_.RemoveAt(pos);
508  } else {
509  if (boolvar->Bound()) {
510  ProcessValueWatcher(value, boolvar);
511  watchers_.RemoveAt(pos);
512  }
513  }
514  }
515  CheckInhibit();
516  }
517  }
518 
519  void ProcessValueWatcher(int64 value, IntVar* boolvar) {
520  if (boolvar->Min() == 0) {
521  if (variable_->Size() < 0xFFFFFF) {
522  variable_->RemoveValue(value);
523  } else {
524  // Delay removal.
525  solver()->AddConstraint(solver()->MakeNonEquality(variable_, value));
526  }
527  } else {
528  variable_->SetValue(value);
529  }
530  }
531 
532  void ProcessVar() {
533  const int kSmallList = 16;
534  if (variable_->Bound()) {
535  VariableBound();
536  } else if (watchers_.Size() <= kSmallList ||
537  variable_->Min() != variable_->OldMin() ||
538  variable_->Max() != variable_->OldMax()) {
539  // Brute force loop for small numbers of watchers, or if the bounds have
540  // changed, which would have required a sort (n log(n)) anyway to take
541  // advantage of.
542  ScanWatchers();
543  CheckInhibit();
544  } else {
545  // If there is no bitset, then there are no holes.
546  // In that case, the two loops above should have performed all
547  // propagation. Otherwise, scan the remaining watchers.
548  BitSet* const bitset = variable_->bitset();
549  if (bitset != nullptr && !watchers_.Empty()) {
550  if (bitset->NumHoles() * 2 < watchers_.Size()) {
551  for (const int64 hole : InitAndGetValues(hole_iterator_)) {
552  int pos = 0;
553  IntVar* const boolvar = watchers_.FindPtrOrNull(hole, &pos);
554  if (boolvar != nullptr) {
555  boolvar->SetValue(0);
556  watchers_.RemoveAt(pos);
557  }
558  }
559  } else {
560  ScanWatchers();
561  }
562  }
563  CheckInhibit();
564  }
565  }
566 
567  // Optimized case if the variable is bound.
568  void VariableBound() {
569  DCHECK(variable_->Bound());
570  const int64 value = variable_->Min();
571  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
572  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
573  w.second->SetValue(w.first == value);
574  }
575  watchers_.RemoveAll();
576  var_demon_->inhibit(solver());
577  }
578 
579  // Scans all the watchers to check and assign them.
580  void ScanWatchers() {
581  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
582  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
583  if (!variable_->Contains(w.first)) {
584  IntVar* const boolvar = w.second;
585  boolvar->SetValue(0);
586  watchers_.RemoveAt(pos);
587  }
588  }
589  }
590 
591  // If the set of active watchers is empty, we can inhibit the demon on the
592  // main variable.
593  void CheckInhibit() {
594  if (watchers_.Empty()) {
595  var_demon_->inhibit(solver());
596  }
597  }
598 
599  void Accept(ModelVisitor* const visitor) const override {
600  visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
601  visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
602  variable_);
603  std::vector<int64> all_coefficients;
604  std::vector<IntVar*> all_bool_vars;
605  for (int position = watchers_.start(); position < watchers_.end();
606  ++position) {
607  const std::pair<int64, IntVar*>& w = watchers_.At(position);
608  all_coefficients.push_back(w.first);
609  all_bool_vars.push_back(w.second);
610  }
611  visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
612  all_bool_vars);
613  visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
614  all_coefficients);
615  visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
616  }
617 
618  std::string DebugString() const override {
619  return absl::StrFormat("ValueWatcher(%s)", variable_->DebugString());
620  }
621 
622  private:
623  DomainIntVar* const variable_;
624  IntVarIterator* const hole_iterator_;
625  RevSwitch posted_;
626  Demon* var_demon_;
627  RevIntPtrMap<IntVar> watchers_;
628  };
629 
630  // Optimized case for small maps.
631  class DenseValueWatcher : public BaseValueWatcher {
632  public:
633  class WatchDemon : public Demon {
634  public:
635  WatchDemon(DenseValueWatcher* const watcher, int64 value, IntVar* var)
636  : value_watcher_(watcher), value_(value), var_(var) {}
637  ~WatchDemon() override {}
638 
639  void Run(Solver* const solver) override {
640  value_watcher_->ProcessValueWatcher(value_, var_);
641  }
642 
643  private:
644  DenseValueWatcher* const value_watcher_;
645  const int64 value_;
646  IntVar* const var_;
647  };
648 
649  class VarDemon : public Demon {
650  public:
651  explicit VarDemon(DenseValueWatcher* const watcher)
652  : value_watcher_(watcher) {}
653 
654  ~VarDemon() override {}
655 
656  void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
657 
658  private:
659  DenseValueWatcher* const value_watcher_;
660  };
661 
662  DenseValueWatcher(Solver* const solver, DomainIntVar* const variable)
663  : BaseValueWatcher(solver),
664  variable_(variable),
665  hole_iterator_(variable_->MakeHoleIterator(true)),
666  var_demon_(nullptr),
667  offset_(variable->Min()),
668  watchers_(variable->Max() - variable->Min() + 1, nullptr),
669  active_watchers_(0) {}
670 
671  ~DenseValueWatcher() override {}
672 
673  IntVar* GetOrMakeValueWatcher(int64 value) override {
674  const int64 var_max = offset_ + watchers_.size() - 1; // Bad cast.
675  if (value < offset_ || value > var_max) {
676  return solver()->MakeIntConst(0);
677  }
678  const int index = value - offset_;
679  IntVar* const watcher = watchers_[index];
680  if (watcher != nullptr) return watcher;
681  if (variable_->Contains(value)) {
682  if (variable_->Bound()) {
683  return solver()->MakeIntConst(1);
684  } else {
685  const std::string vname = variable_->HasName()
686  ? variable_->name()
687  : variable_->DebugString();
688  const std::string bname =
689  absl::StrFormat("Watch<%s == %d>", vname, value);
690  IntVar* const boolvar = solver()->MakeBoolVar(bname);
691  RevInsert(index, boolvar);
692  if (posted_.Switched()) {
693  boolvar->WhenBound(
694  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
695  var_demon_->desinhibit(solver());
696  }
697  return boolvar;
698  }
699  } else {
700  return variable_->solver()->MakeIntConst(0);
701  }
702  }
703 
704  void SetValueWatcher(IntVar* const boolvar, int64 value) override {
705  const int index = value - offset_;
706  CHECK(watchers_[index] == nullptr);
707  if (!boolvar->Bound()) {
708  RevInsert(index, boolvar);
709  if (posted_.Switched() && !boolvar->Bound()) {
710  boolvar->WhenBound(
711  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
712  var_demon_->desinhibit(solver());
713  }
714  }
715  }
716 
717  void Post() override {
718  var_demon_ = solver()->RevAlloc(new VarDemon(this));
719  variable_->WhenDomain(var_demon_);
720  for (int pos = 0; pos < watchers_.size(); ++pos) {
721  const int64 value = pos + offset_;
722  IntVar* const boolvar = watchers_[pos];
723  if (boolvar != nullptr && !boolvar->Bound() &&
724  variable_->Contains(value)) {
725  boolvar->WhenBound(
726  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
727  }
728  }
729  posted_.Switch(solver());
730  }
731 
732  void InitialPropagate() override {
733  if (variable_->Bound()) {
734  VariableBound();
735  } else {
736  for (int pos = 0; pos < watchers_.size(); ++pos) {
737  IntVar* const boolvar = watchers_[pos];
738  if (boolvar == nullptr) continue;
739  const int64 value = pos + offset_;
740  if (!variable_->Contains(value)) {
741  boolvar->SetValue(0);
742  RevRemove(pos);
743  } else if (boolvar->Bound()) {
744  ProcessValueWatcher(value, boolvar);
745  RevRemove(pos);
746  }
747  }
748  if (active_watchers_.Value() == 0) {
749  var_demon_->inhibit(solver());
750  }
751  }
752  }
753 
754  void ProcessValueWatcher(int64 value, IntVar* boolvar) {
755  if (boolvar->Min() == 0) {
756  variable_->RemoveValue(value);
757  } else {
758  variable_->SetValue(value);
759  }
760  }
761 
762  void ProcessVar() {
763  if (variable_->Bound()) {
764  VariableBound();
765  } else {
766  // Brute force loop for small numbers of watchers.
767  ScanWatchers();
768  if (active_watchers_.Value() == 0) {
769  var_demon_->inhibit(solver());
770  }
771  }
772  }
773 
774  // Optimized case if the variable is bound.
775  void VariableBound() {
776  DCHECK(variable_->Bound());
777  const int64 value = variable_->Min();
778  for (int pos = 0; pos < watchers_.size(); ++pos) {
779  IntVar* const boolvar = watchers_[pos];
780  if (boolvar != nullptr) {
781  boolvar->SetValue(pos + offset_ == value);
782  RevRemove(pos);
783  }
784  }
785  var_demon_->inhibit(solver());
786  }
787 
788  // Scans all the watchers to check and assign them.
789  void ScanWatchers() {
790  const int64 old_min_index = variable_->OldMin() - offset_;
791  const int64 old_max_index = variable_->OldMax() - offset_;
792  const int64 min_index = variable_->Min() - offset_;
793  const int64 max_index = variable_->Max() - offset_;
794  for (int pos = old_min_index; pos < min_index; ++pos) {
795  IntVar* const boolvar = watchers_[pos];
796  if (boolvar != nullptr) {
797  boolvar->SetValue(0);
798  RevRemove(pos);
799  }
800  }
801  for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
802  IntVar* const boolvar = watchers_[pos];
803  if (boolvar != nullptr) {
804  boolvar->SetValue(0);
805  RevRemove(pos);
806  }
807  }
808  BitSet* const bitset = variable_->bitset();
809  if (bitset != nullptr) {
810  if (bitset->NumHoles() * 2 < active_watchers_.Value()) {
811  for (const int64 hole : InitAndGetValues(hole_iterator_)) {
812  IntVar* const boolvar = watchers_[hole - offset_];
813  if (boolvar != nullptr) {
814  boolvar->SetValue(0);
815  RevRemove(hole - offset_);
816  }
817  }
818  } else {
819  for (int pos = min_index + 1; pos < max_index; ++pos) {
820  IntVar* const boolvar = watchers_[pos];
821  if (boolvar != nullptr && !variable_->Contains(offset_ + pos)) {
822  boolvar->SetValue(0);
823  RevRemove(pos);
824  }
825  }
826  }
827  }
828  }
829 
830  void RevRemove(int pos) {
831  solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
832  watchers_[pos] = nullptr;
833  active_watchers_.Decr(solver());
834  }
835 
836  void RevInsert(int pos, IntVar* boolvar) {
837  solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
838  watchers_[pos] = boolvar;
839  active_watchers_.Incr(solver());
840  }
841 
842  void Accept(ModelVisitor* const visitor) const override {
843  visitor->BeginVisitConstraint(ModelVisitor::kVarValueWatcher, this);
844  visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
845  variable_);
846  std::vector<int64> all_coefficients;
847  std::vector<IntVar*> all_bool_vars;
848  for (int position = 0; position < watchers_.size(); ++position) {
849  if (watchers_[position] != nullptr) {
850  all_coefficients.push_back(position + offset_);
851  all_bool_vars.push_back(watchers_[position]);
852  }
853  }
854  visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
855  all_bool_vars);
856  visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
857  all_coefficients);
858  visitor->EndVisitConstraint(ModelVisitor::kVarValueWatcher, this);
859  }
860 
861  std::string DebugString() const override {
862  return absl::StrFormat("DenseValueWatcher(%s)", variable_->DebugString());
863  }
864 
865  private:
866  DomainIntVar* const variable_;
867  IntVarIterator* const hole_iterator_;
868  RevSwitch posted_;
869  Demon* var_demon_;
870  const int64 offset_;
871  std::vector<IntVar*> watchers_;
872  NumericalRev<int> active_watchers_;
873  };
874 
875  class BaseUpperBoundWatcher : public Constraint {
876  public:
877  explicit BaseUpperBoundWatcher(Solver* const solver) : Constraint(solver) {}
878 
879  ~BaseUpperBoundWatcher() override {}
880 
881  virtual IntVar* GetOrMakeUpperBoundWatcher(int64 value) = 0;
882 
883  virtual void SetUpperBoundWatcher(IntVar* const boolvar, int64 value) = 0;
884  };
885 
886  // This class watches the bounds of the variable and updates the
887  // IsGreater/IsGreaterOrEqual/IsLess/IsLessOrEqual demons
888  // accordingly.
889  class UpperBoundWatcher : public BaseUpperBoundWatcher {
890  public:
891  class WatchDemon : public Demon {
892  public:
893  WatchDemon(UpperBoundWatcher* const watcher, int64 index,
894  IntVar* const var)
895  : value_watcher_(watcher), index_(index), var_(var) {}
896  ~WatchDemon() override {}
897 
898  void Run(Solver* const solver) override {
899  value_watcher_->ProcessUpperBoundWatcher(index_, var_);
900  }
901 
902  private:
903  UpperBoundWatcher* const value_watcher_;
904  const int64 index_;
905  IntVar* const var_;
906  };
907 
908  class VarDemon : public Demon {
909  public:
910  explicit VarDemon(UpperBoundWatcher* const watcher)
911  : value_watcher_(watcher) {}
912  ~VarDemon() override {}
913 
914  void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
915 
916  private:
917  UpperBoundWatcher* const value_watcher_;
918  };
919 
920  UpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
921  : BaseUpperBoundWatcher(solver),
922  variable_(variable),
923  var_demon_(nullptr),
924  watchers_(solver, variable->Min(), variable->Max()),
925  start_(0),
926  end_(0),
927  sorted_(false) {}
928 
929  ~UpperBoundWatcher() override {}
930 
931  IntVar* GetOrMakeUpperBoundWatcher(int64 value) override {
932  IntVar* const watcher = watchers_.FindPtrOrNull(value, nullptr);
933  if (watcher != nullptr) {
934  return watcher;
935  }
936  if (variable_->Max() >= value) {
937  if (variable_->Min() >= value) {
938  return solver()->MakeIntConst(1);
939  } else {
940  const std::string vname = variable_->HasName()
941  ? variable_->name()
942  : variable_->DebugString();
943  const std::string bname =
944  absl::StrFormat("Watch<%s >= %d>", vname, value);
945  IntVar* const boolvar = solver()->MakeBoolVar(bname);
946  watchers_.UnsafeRevInsert(value, boolvar);
947  if (posted_.Switched()) {
948  boolvar->WhenBound(
949  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
950  var_demon_->desinhibit(solver());
951  sorted_ = false;
952  }
953  return boolvar;
954  }
955  } else {
956  return variable_->solver()->MakeIntConst(0);
957  }
958  }
959 
960  void SetUpperBoundWatcher(IntVar* const boolvar, int64 value) override {
961  CHECK(watchers_.FindPtrOrNull(value, nullptr) == nullptr);
962  watchers_.UnsafeRevInsert(value, boolvar);
963  if (posted_.Switched() && !boolvar->Bound()) {
964  boolvar->WhenBound(
965  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
966  var_demon_->desinhibit(solver());
967  sorted_ = false;
968  }
969  }
970 
971  void Post() override {
972  const int kTooSmallToSort = 8;
973  var_demon_ = solver()->RevAlloc(new VarDemon(this));
974  variable_->WhenRange(var_demon_);
975 
976  if (watchers_.Size() > kTooSmallToSort) {
977  watchers_.SortActive();
978  sorted_ = true;
979  start_.SetValue(solver(), watchers_.start());
980  end_.SetValue(solver(), watchers_.end() - 1);
981  }
982 
983  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
984  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
985  IntVar* const boolvar = w.second;
986  const int64 value = w.first;
987  if (!boolvar->Bound() && value > variable_->Min() &&
988  value <= variable_->Max()) {
989  boolvar->WhenBound(
990  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
991  }
992  }
993  posted_.Switch(solver());
994  }
995 
996  void InitialPropagate() override {
997  const int64 var_min = variable_->Min();
998  const int64 var_max = variable_->Max();
999  if (sorted_) {
1000  while (start_.Value() <= end_.Value()) {
1001  const std::pair<int64, IntVar*>& w = watchers_.At(start_.Value());
1002  if (w.first <= var_min) {
1003  w.second->SetValue(1);
1004  start_.Incr(solver());
1005  } else {
1006  break;
1007  }
1008  }
1009  while (end_.Value() >= start_.Value()) {
1010  const std::pair<int64, IntVar*>& w = watchers_.At(end_.Value());
1011  if (w.first > var_max) {
1012  w.second->SetValue(0);
1013  end_.Decr(solver());
1014  } else {
1015  break;
1016  }
1017  }
1018  for (int i = start_.Value(); i <= end_.Value(); ++i) {
1019  const std::pair<int64, IntVar*>& w = watchers_.At(i);
1020  if (w.second->Bound()) {
1021  ProcessUpperBoundWatcher(w.first, w.second);
1022  }
1023  }
1024  if (start_.Value() > end_.Value()) {
1025  var_demon_->inhibit(solver());
1026  }
1027  } else {
1028  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1029  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
1030  const int64 value = w.first;
1031  IntVar* const boolvar = w.second;
1032 
1033  if (value <= var_min) {
1034  boolvar->SetValue(1);
1035  watchers_.RemoveAt(pos);
1036  } else if (value > var_max) {
1037  boolvar->SetValue(0);
1038  watchers_.RemoveAt(pos);
1039  } else if (boolvar->Bound()) {
1040  ProcessUpperBoundWatcher(value, boolvar);
1041  watchers_.RemoveAt(pos);
1042  }
1043  }
1044  }
1045  }
1046 
1047  void Accept(ModelVisitor* const visitor) const override {
1048  visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1049  visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1050  variable_);
1051  std::vector<int64> all_coefficients;
1052  std::vector<IntVar*> all_bool_vars;
1053  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1054  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
1055  all_coefficients.push_back(w.first);
1056  all_bool_vars.push_back(w.second);
1057  }
1058  visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1059  all_bool_vars);
1060  visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1061  all_coefficients);
1062  visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1063  }
1064 
1065  std::string DebugString() const override {
1066  return absl::StrFormat("UpperBoundWatcher(%s)", variable_->DebugString());
1067  }
1068 
1069  private:
1070  void ProcessUpperBoundWatcher(int64 value, IntVar* const boolvar) {
1071  if (boolvar->Min() == 0) {
1072  variable_->SetMax(value - 1);
1073  } else {
1074  variable_->SetMin(value);
1075  }
1076  }
1077 
1078  void ProcessVar() {
1079  const int64 var_min = variable_->Min();
1080  const int64 var_max = variable_->Max();
1081  if (sorted_) {
1082  while (start_.Value() <= end_.Value()) {
1083  const std::pair<int64, IntVar*>& w = watchers_.At(start_.Value());
1084  if (w.first <= var_min) {
1085  w.second->SetValue(1);
1086  start_.Incr(solver());
1087  } else {
1088  break;
1089  }
1090  }
1091  while (end_.Value() >= start_.Value()) {
1092  const std::pair<int64, IntVar*>& w = watchers_.At(end_.Value());
1093  if (w.first > var_max) {
1094  w.second->SetValue(0);
1095  end_.Decr(solver());
1096  } else {
1097  break;
1098  }
1099  }
1100  if (start_.Value() > end_.Value()) {
1101  var_demon_->inhibit(solver());
1102  }
1103  } else {
1104  for (int pos = watchers_.start(); pos < watchers_.end(); ++pos) {
1105  const std::pair<int64, IntVar*>& w = watchers_.At(pos);
1106  const int64 value = w.first;
1107  IntVar* const boolvar = w.second;
1108 
1109  if (value <= var_min) {
1110  boolvar->SetValue(1);
1111  watchers_.RemoveAt(pos);
1112  } else if (value > var_max) {
1113  boolvar->SetValue(0);
1114  watchers_.RemoveAt(pos);
1115  }
1116  }
1117  if (watchers_.Empty()) {
1118  var_demon_->inhibit(solver());
1119  }
1120  }
1121  }
1122 
1123  DomainIntVar* const variable_;
1124  RevSwitch posted_;
1125  Demon* var_demon_;
1126  RevIntPtrMap<IntVar> watchers_;
1127  NumericalRev<int> start_;
1128  NumericalRev<int> end_;
1129  bool sorted_;
1130  };
1131 
1132  // Optimized case for small maps.
1133  class DenseUpperBoundWatcher : public BaseUpperBoundWatcher {
1134  public:
1135  class WatchDemon : public Demon {
1136  public:
1137  WatchDemon(DenseUpperBoundWatcher* const watcher, int64 value,
1138  IntVar* var)
1139  : value_watcher_(watcher), value_(value), var_(var) {}
1140  ~WatchDemon() override {}
1141 
1142  void Run(Solver* const solver) override {
1143  value_watcher_->ProcessUpperBoundWatcher(value_, var_);
1144  }
1145 
1146  private:
1147  DenseUpperBoundWatcher* const value_watcher_;
1148  const int64 value_;
1149  IntVar* const var_;
1150  };
1151 
1152  class VarDemon : public Demon {
1153  public:
1154  explicit VarDemon(DenseUpperBoundWatcher* const watcher)
1155  : value_watcher_(watcher) {}
1156 
1157  ~VarDemon() override {}
1158 
1159  void Run(Solver* const solver) override { value_watcher_->ProcessVar(); }
1160 
1161  private:
1162  DenseUpperBoundWatcher* const value_watcher_;
1163  };
1164 
1165  DenseUpperBoundWatcher(Solver* const solver, DomainIntVar* const variable)
1166  : BaseUpperBoundWatcher(solver),
1167  variable_(variable),
1168  var_demon_(nullptr),
1169  offset_(variable->Min()),
1170  watchers_(variable->Max() - variable->Min() + 1, nullptr),
1171  active_watchers_(0) {}
1172 
1173  ~DenseUpperBoundWatcher() override {}
1174 
1175  IntVar* GetOrMakeUpperBoundWatcher(int64 value) override {
1176  if (variable_->Max() >= value) {
1177  if (variable_->Min() >= value) {
1178  return solver()->MakeIntConst(1);
1179  } else {
1180  const std::string vname = variable_->HasName()
1181  ? variable_->name()
1182  : variable_->DebugString();
1183  const std::string bname =
1184  absl::StrFormat("Watch<%s >= %d>", vname, value);
1185  IntVar* const boolvar = solver()->MakeBoolVar(bname);
1186  RevInsert(value - offset_, boolvar);
1187  if (posted_.Switched()) {
1188  boolvar->WhenBound(
1189  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1190  var_demon_->desinhibit(solver());
1191  }
1192  return boolvar;
1193  }
1194  } else {
1195  return variable_->solver()->MakeIntConst(0);
1196  }
1197  }
1198 
1199  void SetUpperBoundWatcher(IntVar* const boolvar, int64 value) override {
1200  const int index = value - offset_;
1201  CHECK(watchers_[index] == nullptr);
1202  if (!boolvar->Bound()) {
1203  RevInsert(index, boolvar);
1204  if (posted_.Switched() && !boolvar->Bound()) {
1205  boolvar->WhenBound(
1206  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1207  var_demon_->desinhibit(solver());
1208  }
1209  }
1210  }
1211 
1212  void Post() override {
1213  var_demon_ = solver()->RevAlloc(new VarDemon(this));
1214  variable_->WhenRange(var_demon_);
1215  for (int pos = 0; pos < watchers_.size(); ++pos) {
1216  const int64 value = pos + offset_;
1217  IntVar* const boolvar = watchers_[pos];
1218  if (boolvar != nullptr && !boolvar->Bound() &&
1219  value > variable_->Min() && value <= variable_->Max()) {
1220  boolvar->WhenBound(
1221  solver()->RevAlloc(new WatchDemon(this, value, boolvar)));
1222  }
1223  }
1224  posted_.Switch(solver());
1225  }
1226 
1227  void InitialPropagate() override {
1228  for (int pos = 0; pos < watchers_.size(); ++pos) {
1229  IntVar* const boolvar = watchers_[pos];
1230  if (boolvar == nullptr) continue;
1231  const int64 value = pos + offset_;
1232  if (value <= variable_->Min()) {
1233  boolvar->SetValue(1);
1234  RevRemove(pos);
1235  } else if (value > variable_->Max()) {
1236  boolvar->SetValue(0);
1237  RevRemove(pos);
1238  } else if (boolvar->Bound()) {
1239  ProcessUpperBoundWatcher(value, boolvar);
1240  RevRemove(pos);
1241  }
1242  }
1243  if (active_watchers_.Value() == 0) {
1244  var_demon_->inhibit(solver());
1245  }
1246  }
1247 
1248  void ProcessUpperBoundWatcher(int64 value, IntVar* boolvar) {
1249  if (boolvar->Min() == 0) {
1250  variable_->SetMax(value - 1);
1251  } else {
1252  variable_->SetMin(value);
1253  }
1254  }
1255 
1256  void ProcessVar() {
1257  const int64 old_min_index = variable_->OldMin() - offset_;
1258  const int64 old_max_index = variable_->OldMax() - offset_;
1259  const int64 min_index = variable_->Min() - offset_;
1260  const int64 max_index = variable_->Max() - offset_;
1261  for (int pos = old_min_index; pos <= min_index; ++pos) {
1262  IntVar* const boolvar = watchers_[pos];
1263  if (boolvar != nullptr) {
1264  boolvar->SetValue(1);
1265  RevRemove(pos);
1266  }
1267  }
1268 
1269  for (int pos = max_index + 1; pos <= old_max_index; ++pos) {
1270  IntVar* const boolvar = watchers_[pos];
1271  if (boolvar != nullptr) {
1272  boolvar->SetValue(0);
1273  RevRemove(pos);
1274  }
1275  }
1276  if (active_watchers_.Value() == 0) {
1277  var_demon_->inhibit(solver());
1278  }
1279  }
1280 
1281  void RevRemove(int pos) {
1282  solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1283  watchers_[pos] = nullptr;
1284  active_watchers_.Decr(solver());
1285  }
1286 
1287  void RevInsert(int pos, IntVar* boolvar) {
1288  solver()->SaveValue(reinterpret_cast<void**>(&watchers_[pos]));
1289  watchers_[pos] = boolvar;
1290  active_watchers_.Incr(solver());
1291  }
1292 
1293  void Accept(ModelVisitor* const visitor) const override {
1294  visitor->BeginVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1295  visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
1296  variable_);
1297  std::vector<int64> all_coefficients;
1298  std::vector<IntVar*> all_bool_vars;
1299  for (int position = 0; position < watchers_.size(); ++position) {
1300  if (watchers_[position] != nullptr) {
1301  all_coefficients.push_back(position + offset_);
1302  all_bool_vars.push_back(watchers_[position]);
1303  }
1304  }
1305  visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1306  all_bool_vars);
1307  visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,
1308  all_coefficients);
1309  visitor->EndVisitConstraint(ModelVisitor::kVarBoundWatcher, this);
1310  }
1311 
1312  std::string DebugString() const override {
1313  return absl::StrFormat("DenseUpperBoundWatcher(%s)",
1314  variable_->DebugString());
1315  }
1316 
1317  private:
1318  DomainIntVar* const variable_;
1319  RevSwitch posted_;
1320  Demon* var_demon_;
1321  const int64 offset_;
1322  std::vector<IntVar*> watchers_;
1323  NumericalRev<int> active_watchers_;
1324  };
1325 
1326  // ----- Main Class -----
1327  DomainIntVar(Solver* const s, int64 vmin, int64 vmax,
1328  const std::string& name);
1329  DomainIntVar(Solver* const s, const std::vector<int64>& sorted_values,
1330  const std::string& name);
1331  ~DomainIntVar() override;
1332 
1333  int64 Min() const override { return min_.Value(); }
1334  void SetMin(int64 m) override;
1335  int64 Max() const override { return max_.Value(); }
1336  void SetMax(int64 m) override;
1337  void SetRange(int64 mi, int64 ma) override;
1338  void SetValue(int64 v) override;
1339  bool Bound() const override { return (min_.Value() == max_.Value()); }
1340  int64 Value() const override {
1341  CHECK_EQ(min_.Value(), max_.Value())
1342  << " variable " << DebugString() << " is not bound.";
1343  return min_.Value();
1344  }
1345  void RemoveValue(int64 v) override;
1346  void RemoveInterval(int64 l, int64 u) override;
1347  void CreateBits();
1348  void WhenBound(Demon* d) override {
1349  if (min_.Value() != max_.Value()) {
1350  if (d->priority() == Solver::DELAYED_PRIORITY) {
1351  delayed_bound_demons_.PushIfNotTop(solver(),
1352  solver()->RegisterDemon(d));
1353  } else {
1354  bound_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1355  }
1356  }
1357  }
1358  void WhenRange(Demon* d) override {
1359  if (min_.Value() != max_.Value()) {
1360  if (d->priority() == Solver::DELAYED_PRIORITY) {
1361  delayed_range_demons_.PushIfNotTop(solver(),
1362  solver()->RegisterDemon(d));
1363  } else {
1364  range_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1365  }
1366  }
1367  }
1368  void WhenDomain(Demon* d) override {
1369  if (min_.Value() != max_.Value()) {
1370  if (d->priority() == Solver::DELAYED_PRIORITY) {
1371  delayed_domain_demons_.PushIfNotTop(solver(),
1372  solver()->RegisterDemon(d));
1373  } else {
1374  domain_demons_.PushIfNotTop(solver(), solver()->RegisterDemon(d));
1375  }
1376  }
1377  }
1378 
1379  IntVar* IsEqual(int64 constant) override {
1380  Solver* const s = solver();
1381  if (constant == min_.Value() && value_watcher_ == nullptr) {
1382  return s->MakeIsLessOrEqualCstVar(this, constant);
1383  }
1384  if (constant == max_.Value() && value_watcher_ == nullptr) {
1385  return s->MakeIsGreaterOrEqualCstVar(this, constant);
1386  }
1387  if (!Contains(constant)) {
1388  return s->MakeIntConst(int64{0});
1389  }
1390  if (Bound() && min_.Value() == constant) {
1391  return s->MakeIntConst(int64{1});
1392  }
1393  IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1394  this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1395  if (cache != nullptr) {
1396  return cache->Var();
1397  } else {
1398  if (value_watcher_ == nullptr) {
1399  if (CapSub(Max(), Min()) <= 256) {
1400  solver()->SaveAndSetValue(
1401  reinterpret_cast<void**>(&value_watcher_),
1402  reinterpret_cast<void*>(
1403  solver()->RevAlloc(new DenseValueWatcher(solver(), this))));
1404 
1405  } else {
1406  solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1407  reinterpret_cast<void*>(solver()->RevAlloc(
1408  new ValueWatcher(solver(), this))));
1409  }
1410  solver()->AddConstraint(value_watcher_);
1411  }
1412  IntVar* const boolvar = value_watcher_->GetOrMakeValueWatcher(constant);
1413  s->Cache()->InsertExprConstantExpression(
1414  boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_EQUAL);
1415  return boolvar;
1416  }
1417  }
1418 
1419  Constraint* SetIsEqual(const std::vector<int64>& values,
1420  const std::vector<IntVar*>& vars) {
1421  if (value_watcher_ == nullptr) {
1422  solver()->SaveAndSetValue(reinterpret_cast<void**>(&value_watcher_),
1423  reinterpret_cast<void*>(solver()->RevAlloc(
1424  new ValueWatcher(solver(), this))));
1425  for (int i = 0; i < vars.size(); ++i) {
1426  value_watcher_->SetValueWatcher(vars[i], values[i]);
1427  }
1428  }
1429  return value_watcher_;
1430  }
1431 
1432  IntVar* IsDifferent(int64 constant) override {
1433  Solver* const s = solver();
1434  if (constant == min_.Value() && value_watcher_ == nullptr) {
1435  return s->MakeIsGreaterOrEqualCstVar(this, constant + 1);
1436  }
1437  if (constant == max_.Value() && value_watcher_ == nullptr) {
1438  return s->MakeIsLessOrEqualCstVar(this, constant - 1);
1439  }
1440  if (!Contains(constant)) {
1441  return s->MakeIntConst(int64{1});
1442  }
1443  if (Bound() && min_.Value() == constant) {
1444  return s->MakeIntConst(int64{0});
1445  }
1446  IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1447  this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1448  if (cache != nullptr) {
1449  return cache->Var();
1450  } else {
1451  IntVar* const boolvar = s->MakeDifference(1, IsEqual(constant))->Var();
1452  s->Cache()->InsertExprConstantExpression(
1453  boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_NOT_EQUAL);
1454  return boolvar;
1455  }
1456  }
1457 
1458  IntVar* IsGreaterOrEqual(int64 constant) override {
1459  Solver* const s = solver();
1460  if (max_.Value() < constant) {
1461  return s->MakeIntConst(int64{0});
1462  }
1463  if (min_.Value() >= constant) {
1464  return s->MakeIntConst(int64{1});
1465  }
1466  IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1468  if (cache != nullptr) {
1469  return cache->Var();
1470  } else {
1471  if (bound_watcher_ == nullptr) {
1472  if (CapSub(Max(), Min()) <= 256) {
1473  solver()->SaveAndSetValue(
1474  reinterpret_cast<void**>(&bound_watcher_),
1475  reinterpret_cast<void*>(solver()->RevAlloc(
1476  new DenseUpperBoundWatcher(solver(), this))));
1477  solver()->AddConstraint(bound_watcher_);
1478  } else {
1479  solver()->SaveAndSetValue(
1480  reinterpret_cast<void**>(&bound_watcher_),
1481  reinterpret_cast<void*>(
1482  solver()->RevAlloc(new UpperBoundWatcher(solver(), this))));
1483  solver()->AddConstraint(bound_watcher_);
1484  }
1485  }
1486  IntVar* const boolvar =
1487  bound_watcher_->GetOrMakeUpperBoundWatcher(constant);
1488  s->Cache()->InsertExprConstantExpression(
1489  boolvar, this, constant,
1491  return boolvar;
1492  }
1493  }
1494 
1495  Constraint* SetIsGreaterOrEqual(const std::vector<int64>& values,
1496  const std::vector<IntVar*>& vars) {
1497  if (bound_watcher_ == nullptr) {
1498  if (CapSub(Max(), Min()) <= 256) {
1499  solver()->SaveAndSetValue(
1500  reinterpret_cast<void**>(&bound_watcher_),
1501  reinterpret_cast<void*>(solver()->RevAlloc(
1502  new DenseUpperBoundWatcher(solver(), this))));
1503  solver()->AddConstraint(bound_watcher_);
1504  } else {
1505  solver()->SaveAndSetValue(reinterpret_cast<void**>(&bound_watcher_),
1506  reinterpret_cast<void*>(solver()->RevAlloc(
1507  new UpperBoundWatcher(solver(), this))));
1508  solver()->AddConstraint(bound_watcher_);
1509  }
1510  for (int i = 0; i < values.size(); ++i) {
1511  bound_watcher_->SetUpperBoundWatcher(vars[i], values[i]);
1512  }
1513  }
1514  return bound_watcher_;
1515  }
1516 
1517  IntVar* IsLessOrEqual(int64 constant) override {
1518  Solver* const s = solver();
1519  IntExpr* const cache = s->Cache()->FindExprConstantExpression(
1521  if (cache != nullptr) {
1522  return cache->Var();
1523  } else {
1524  IntVar* const boolvar =
1525  s->MakeDifference(1, IsGreaterOrEqual(constant + 1))->Var();
1526  s->Cache()->InsertExprConstantExpression(
1527  boolvar, this, constant, ModelCache::EXPR_CONSTANT_IS_LESS_OR_EQUAL);
1528  return boolvar;
1529  }
1530  }
1531 
1532  void Process();
1533  void Push();
1534  void CleanInProcess();
1535  uint64 Size() const override {
1536  if (bits_ != nullptr) return bits_->Size();
1537  return (static_cast<uint64>(max_.Value()) -
1538  static_cast<uint64>(min_.Value()) + 1);
1539  }
1540  bool Contains(int64 v) const override {
1541  if (v < min_.Value() || v > max_.Value()) return false;
1542  return (bits_ == nullptr ? true : bits_->Contains(v));
1543  }
1544  IntVarIterator* MakeHoleIterator(bool reversible) const override;
1545  IntVarIterator* MakeDomainIterator(bool reversible) const override;
1546  int64 OldMin() const override { return std::min(old_min_, min_.Value()); }
1547  int64 OldMax() const override { return std::max(old_max_, max_.Value()); }
1548 
1549  std::string DebugString() const override;
1550  BitSet* bitset() const { return bits_; }
1551  int VarType() const override { return DOMAIN_INT_VAR; }
1552  std::string BaseName() const override { return "IntegerVar"; }
1553 
1554  friend class PlusCstDomainIntVar;
1555  friend class LinkExprAndDomainIntVar;
1556 
1557  private:
1558  void CheckOldMin() {
1559  if (old_min_ > min_.Value()) {
1560  old_min_ = min_.Value();
1561  }
1562  }
1563  void CheckOldMax() {
1564  if (old_max_ < max_.Value()) {
1565  old_max_ = max_.Value();
1566  }
1567  }
1568  Rev<int64> min_;
1569  Rev<int64> max_;
1570  int64 old_min_;
1571  int64 old_max_;
1572  int64 new_min_;
1573  int64 new_max_;
1574  SimpleRevFIFO<Demon*> bound_demons_;
1575  SimpleRevFIFO<Demon*> range_demons_;
1576  SimpleRevFIFO<Demon*> domain_demons_;
1577  SimpleRevFIFO<Demon*> delayed_bound_demons_;
1578  SimpleRevFIFO<Demon*> delayed_range_demons_;
1579  SimpleRevFIFO<Demon*> delayed_domain_demons_;
1580  QueueHandler handler_;
1581  bool in_process_;
1582  BitSet* bits_;
1583  BaseValueWatcher* value_watcher_;
1584  BaseUpperBoundWatcher* bound_watcher_;
1585 };
1586 
1587 // ----- BitSet -----
1588 
1589 // Return whether an integer interval [a..b] (inclusive) contains at most
1590 // K values, i.e. b - a < K, in a way that's robust to overflows.
1591 // For performance reasons, in opt mode it doesn't check that [a, b] is a
1592 // valid interval, nor that K is nonnegative.
1593 inline bool ClosedIntervalNoLargerThan(int64 a, int64 b, int64 K) {
1594  DCHECK_LE(a, b);
1595  DCHECK_GE(K, 0);
1596  if (a > 0) {
1597  return a > b - K;
1598  } else {
1599  return a + K > b;
1600  }
1601 }
1602 
1603 class SimpleBitSet : public DomainIntVar::BitSet {
1604  public:
1605  SimpleBitSet(Solver* const s, int64 vmin, int64 vmax)
1606  : BitSet(s),
1607  bits_(nullptr),
1608  stamps_(nullptr),
1609  omin_(vmin),
1610  omax_(vmax),
1611  size_(vmax - vmin + 1),
1612  bsize_(BitLength64(size_.Value())) {
1613  CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1614  << "Bitset too large: [" << vmin << ", " << vmax << "]";
1615  bits_ = new uint64[bsize_];
1616  stamps_ = new uint64[bsize_];
1617  for (int i = 0; i < bsize_; ++i) {
1618  const int bs =
1619  (i == size_.Value() - 1) ? 63 - BitPos64(size_.Value()) : 0;
1620  bits_[i] = kAllBits64 >> bs;
1621  stamps_[i] = s->stamp() - 1;
1622  }
1623  }
1624 
1625  SimpleBitSet(Solver* const s, const std::vector<int64>& sorted_values,
1626  int64 vmin, int64 vmax)
1627  : BitSet(s),
1628  bits_(nullptr),
1629  stamps_(nullptr),
1630  omin_(vmin),
1631  omax_(vmax),
1632  size_(sorted_values.size()),
1633  bsize_(BitLength64(vmax - vmin + 1)) {
1634  CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 0xFFFFFFFF))
1635  << "Bitset too large: [" << vmin << ", " << vmax << "]";
1636  bits_ = new uint64[bsize_];
1637  stamps_ = new uint64[bsize_];
1638  for (int i = 0; i < bsize_; ++i) {
1639  bits_[i] = uint64_t{0};
1640  stamps_[i] = s->stamp() - 1;
1641  }
1642  for (int i = 0; i < sorted_values.size(); ++i) {
1643  const int64 val = sorted_values[i];
1644  DCHECK(!bit(val));
1645  const int offset = BitOffset64(val - omin_);
1646  const int pos = BitPos64(val - omin_);
1647  bits_[offset] |= OneBit64(pos);
1648  }
1649  }
1650 
1651  ~SimpleBitSet() override {
1652  delete[] bits_;
1653  delete[] stamps_;
1654  }
1655 
1656  bool bit(int64 val) const { return IsBitSet64(bits_, val - omin_); }
1657 
1658  int64 ComputeNewMin(int64 nmin, int64 cmin, int64 cmax) override {
1659  DCHECK_GE(nmin, cmin);
1660  DCHECK_LE(nmin, cmax);
1661  DCHECK_LE(cmin, cmax);
1662  DCHECK_GE(cmin, omin_);
1663  DCHECK_LE(cmax, omax_);
1664  const int64 new_min =
1665  UnsafeLeastSignificantBitPosition64(bits_, nmin - omin_, cmax - omin_) +
1666  omin_;
1667  const uint64 removed_bits =
1668  BitCountRange64(bits_, cmin - omin_, new_min - omin_ - 1);
1669  size_.Add(solver_, -removed_bits);
1670  return new_min;
1671  }
1672 
1673  int64 ComputeNewMax(int64 nmax, int64 cmin, int64 cmax) override {
1674  DCHECK_GE(nmax, cmin);
1675  DCHECK_LE(nmax, cmax);
1676  DCHECK_LE(cmin, cmax);
1677  DCHECK_GE(cmin, omin_);
1678  DCHECK_LE(cmax, omax_);
1679  const int64 new_max =
1680  UnsafeMostSignificantBitPosition64(bits_, cmin - omin_, nmax - omin_) +
1681  omin_;
1682  const uint64 removed_bits =
1683  BitCountRange64(bits_, new_max - omin_ + 1, cmax - omin_);
1684  size_.Add(solver_, -removed_bits);
1685  return new_max;
1686  }
1687 
1688  bool SetValue(int64 val) override {
1689  DCHECK_GE(val, omin_);
1690  DCHECK_LE(val, omax_);
1691  if (bit(val)) {
1692  size_.SetValue(solver_, 1);
1693  return true;
1694  }
1695  return false;
1696  }
1697 
1698  bool Contains(int64 val) const override {
1699  DCHECK_GE(val, omin_);
1700  DCHECK_LE(val, omax_);
1701  return bit(val);
1702  }
1703 
1704  bool RemoveValue(int64 val) override {
1705  if (val < omin_ || val > omax_ || !bit(val)) {
1706  return false;
1707  }
1708  // Bitset.
1709  const int64 val_offset = val - omin_;
1710  const int offset = BitOffset64(val_offset);
1711  const uint64 current_stamp = solver_->stamp();
1712  if (stamps_[offset] < current_stamp) {
1713  stamps_[offset] = current_stamp;
1714  solver_->SaveValue(&bits_[offset]);
1715  }
1716  const int pos = BitPos64(val_offset);
1717  bits_[offset] &= ~OneBit64(pos);
1718  // Size.
1719  size_.Decr(solver_);
1720  // Holes.
1721  InitHoles();
1722  AddHole(val);
1723  return true;
1724  }
1725  uint64 Size() const override { return size_.Value(); }
1726 
1727  std::string DebugString() const override {
1728  std::string out;
1729  absl::StrAppendFormat(&out, "SimpleBitSet(%d..%d : ", omin_, omax_);
1730  for (int i = 0; i < bsize_; ++i) {
1731  absl::StrAppendFormat(&out, "%x", bits_[i]);
1732  }
1733  out += ")";
1734  return out;
1735  }
1736 
1737  void DelayRemoveValue(int64 val) override { removed_.push_back(val); }
1738 
1739  void ApplyRemovedValues(DomainIntVar* var) override {
1740  std::sort(removed_.begin(), removed_.end());
1741  for (std::vector<int64>::iterator it = removed_.begin();
1742  it != removed_.end(); ++it) {
1743  var->RemoveValue(*it);
1744  }
1745  }
1746 
1747  void ClearRemovedValues() override { removed_.clear(); }
1748 
1749  std::string pretty_DebugString(int64 min, int64 max) const override {
1750  std::string out;
1751  DCHECK(bit(min));
1752  DCHECK(bit(max));
1753  if (max != min) {
1754  int cumul = true;
1755  int64 start_cumul = min;
1756  for (int64 v = min + 1; v < max; ++v) {
1757  if (bit(v)) {
1758  if (!cumul) {
1759  cumul = true;
1760  start_cumul = v;
1761  }
1762  } else {
1763  if (cumul) {
1764  if (v == start_cumul + 1) {
1765  absl::StrAppendFormat(&out, "%d ", start_cumul);
1766  } else if (v == start_cumul + 2) {
1767  absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1768  } else {
1769  absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1770  }
1771  cumul = false;
1772  }
1773  }
1774  }
1775  if (cumul) {
1776  if (max == start_cumul + 1) {
1777  absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1778  } else {
1779  absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1780  }
1781  } else {
1782  absl::StrAppendFormat(&out, "%d", max);
1783  }
1784  } else {
1785  absl::StrAppendFormat(&out, "%d", min);
1786  }
1787  return out;
1788  }
1789 
1790  DomainIntVar::BitSetIterator* MakeIterator() override {
1791  return new DomainIntVar::BitSetIterator(bits_, omin_);
1792  }
1793 
1794  private:
1795  uint64* bits_;
1796  uint64* stamps_;
1797  const int64 omin_;
1798  const int64 omax_;
1799  NumericalRev<int64> size_;
1800  const int bsize_;
1801  std::vector<int64> removed_;
1802 };
1803 
1804 // This is a special case where the bitset fits into one 64 bit integer.
1805 // In that case, there are no offset to compute.
1806 // Overflows are caught by the robust ClosedIntervalNoLargerThan() method.
1807 class SmallBitSet : public DomainIntVar::BitSet {
1808  public:
1809  SmallBitSet(Solver* const s, int64 vmin, int64 vmax)
1810  : BitSet(s),
1811  bits_(uint64_t{0}),
1812  stamp_(s->stamp() - 1),
1813  omin_(vmin),
1814  omax_(vmax),
1815  size_(vmax - vmin + 1) {
1816  CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1817  bits_ = OneRange64(0, size_.Value() - 1);
1818  }
1819 
1820  SmallBitSet(Solver* const s, const std::vector<int64>& sorted_values,
1821  int64 vmin, int64 vmax)
1822  : BitSet(s),
1823  bits_(uint64_t{0}),
1824  stamp_(s->stamp() - 1),
1825  omin_(vmin),
1826  omax_(vmax),
1827  size_(sorted_values.size()) {
1828  CHECK(ClosedIntervalNoLargerThan(vmin, vmax, 64)) << vmin << ", " << vmax;
1829  // We know the array is sorted and does not contains duplicate values.
1830  for (int i = 0; i < sorted_values.size(); ++i) {
1831  const int64 val = sorted_values[i];
1832  DCHECK_GE(val, vmin);
1833  DCHECK_LE(val, vmax);
1834  DCHECK(!IsBitSet64(&bits_, val - omin_));
1835  bits_ |= OneBit64(val - omin_);
1836  }
1837  }
1838 
1839  ~SmallBitSet() override {}
1840 
1841  bool bit(int64 val) const {
1842  DCHECK_GE(val, omin_);
1843  DCHECK_LE(val, omax_);
1844  return (bits_ & OneBit64(val - omin_)) != 0;
1845  }
1846 
1847  int64 ComputeNewMin(int64 nmin, int64 cmin, int64 cmax) override {
1848  DCHECK_GE(nmin, cmin);
1849  DCHECK_LE(nmin, cmax);
1850  DCHECK_LE(cmin, cmax);
1851  DCHECK_GE(cmin, omin_);
1852  DCHECK_LE(cmax, omax_);
1853  // We do not clean the bits between cmin and nmin.
1854  // But we use mask to look only at 'active' bits.
1855 
1856  // Create the mask and compute new bits
1857  const uint64 new_bits = bits_ & OneRange64(nmin - omin_, cmax - omin_);
1858  if (new_bits != uint64_t{0}) {
1859  // Compute new size and new min
1860  size_.SetValue(solver_, BitCount64(new_bits));
1861  if (bit(nmin)) { // Common case, the new min is inside the bitset
1862  return nmin;
1863  }
1864  return LeastSignificantBitPosition64(new_bits) + omin_;
1865  } else { // == 0 -> Fail()
1866  solver_->Fail();
1867  return kint64max;
1868  }
1869  }
1870 
1871  int64 ComputeNewMax(int64 nmax, int64 cmin, int64 cmax) override {
1872  DCHECK_GE(nmax, cmin);
1873  DCHECK_LE(nmax, cmax);
1874  DCHECK_LE(cmin, cmax);
1875  DCHECK_GE(cmin, omin_);
1876  DCHECK_LE(cmax, omax_);
1877  // We do not clean the bits between nmax and cmax.
1878  // But we use mask to look only at 'active' bits.
1879 
1880  // Create the mask and compute new_bits
1881  const uint64 new_bits = bits_ & OneRange64(cmin - omin_, nmax - omin_);
1882  if (new_bits != uint64_t{0}) {
1883  // Compute new size and new min
1884  size_.SetValue(solver_, BitCount64(new_bits));
1885  if (bit(nmax)) { // Common case, the new max is inside the bitset
1886  return nmax;
1887  }
1888  return MostSignificantBitPosition64(new_bits) + omin_;
1889  } else { // == 0 -> Fail()
1890  solver_->Fail();
1891  return kint64min;
1892  }
1893  }
1894 
1895  bool SetValue(int64 val) override {
1896  DCHECK_GE(val, omin_);
1897  DCHECK_LE(val, omax_);
1898  // We do not clean the bits. We will use masks to ignore the bits
1899  // that should have been cleaned.
1900  if (bit(val)) {
1901  size_.SetValue(solver_, 1);
1902  return true;
1903  }
1904  return false;
1905  }
1906 
1907  bool Contains(int64 val) const override {
1908  DCHECK_GE(val, omin_);
1909  DCHECK_LE(val, omax_);
1910  return bit(val);
1911  }
1912 
1913  bool RemoveValue(int64 val) override {
1914  DCHECK_GE(val, omin_);
1915  DCHECK_LE(val, omax_);
1916  if (bit(val)) {
1917  // Bitset.
1918  const uint64 current_stamp = solver_->stamp();
1919  if (stamp_ < current_stamp) {
1920  stamp_ = current_stamp;
1921  solver_->SaveValue(&bits_);
1922  }
1923  bits_ &= ~OneBit64(val - omin_);
1924  DCHECK(!bit(val));
1925  // Size.
1926  size_.Decr(solver_);
1927  // Holes.
1928  InitHoles();
1929  AddHole(val);
1930  return true;
1931  } else {
1932  return false;
1933  }
1934  }
1935 
1936  uint64 Size() const override { return size_.Value(); }
1937 
1938  std::string DebugString() const override {
1939  return absl::StrFormat("SmallBitSet(%d..%d : %llx)", omin_, omax_, bits_);
1940  }
1941 
1942  void DelayRemoveValue(int64 val) override {
1943  DCHECK_GE(val, omin_);
1944  DCHECK_LE(val, omax_);
1945  removed_.push_back(val);
1946  }
1947 
1948  void ApplyRemovedValues(DomainIntVar* var) override {
1949  std::sort(removed_.begin(), removed_.end());
1950  for (std::vector<int64>::iterator it = removed_.begin();
1951  it != removed_.end(); ++it) {
1952  var->RemoveValue(*it);
1953  }
1954  }
1955 
1956  void ClearRemovedValues() override { removed_.clear(); }
1957 
1958  std::string pretty_DebugString(int64 min, int64 max) const override {
1959  std::string out;
1960  DCHECK(bit(min));
1961  DCHECK(bit(max));
1962  if (max != min) {
1963  int cumul = true;
1964  int64 start_cumul = min;
1965  for (int64 v = min + 1; v < max; ++v) {
1966  if (bit(v)) {
1967  if (!cumul) {
1968  cumul = true;
1969  start_cumul = v;
1970  }
1971  } else {
1972  if (cumul) {
1973  if (v == start_cumul + 1) {
1974  absl::StrAppendFormat(&out, "%d ", start_cumul);
1975  } else if (v == start_cumul + 2) {
1976  absl::StrAppendFormat(&out, "%d %d ", start_cumul, v - 1);
1977  } else {
1978  absl::StrAppendFormat(&out, "%d..%d ", start_cumul, v - 1);
1979  }
1980  cumul = false;
1981  }
1982  }
1983  }
1984  if (cumul) {
1985  if (max == start_cumul + 1) {
1986  absl::StrAppendFormat(&out, "%d %d", start_cumul, max);
1987  } else {
1988  absl::StrAppendFormat(&out, "%d..%d", start_cumul, max);
1989  }
1990  } else {
1991  absl::StrAppendFormat(&out, "%d", max);
1992  }
1993  } else {
1994  absl::StrAppendFormat(&out, "%d", min);
1995  }
1996  return out;
1997  }
1998 
1999  DomainIntVar::BitSetIterator* MakeIterator() override {
2000  return new DomainIntVar::BitSetIterator(&bits_, omin_);
2001  }
2002 
2003  private:
2004  uint64 bits_;
2005  uint64 stamp_;
2006  const int64 omin_;
2007  const int64 omax_;
2008  NumericalRev<int64> size_;
2009  std::vector<int64> removed_;
2010 };
2011 
2012 class EmptyIterator : public IntVarIterator {
2013  public:
2014  ~EmptyIterator() override {}
2015  void Init() override {}
2016  bool Ok() const override { return false; }
2017  int64 Value() const override {
2018  LOG(FATAL) << "Should not be called";
2019  return 0LL;
2020  }
2021  void Next() override {}
2022 };
2023 
2024 class RangeIterator : public IntVarIterator {
2025  public:
2026  explicit RangeIterator(const IntVar* const var)
2027  : var_(var), min_(kint64max), max_(kint64min), current_(-1) {}
2028 
2029  ~RangeIterator() override {}
2030 
2031  void Init() override {
2032  min_ = var_->Min();
2033  max_ = var_->Max();
2034  current_ = min_;
2035  }
2036 
2037  bool Ok() const override { return current_ <= max_; }
2038 
2039  int64 Value() const override { return current_; }
2040 
2041  void Next() override { current_++; }
2042 
2043  private:
2044  const IntVar* const var_;
2045  int64 min_;
2046  int64 max_;
2047  int64 current_;
2048 };
2049 
2050 class DomainIntVarHoleIterator : public IntVarIterator {
2051  public:
2052  explicit DomainIntVarHoleIterator(const DomainIntVar* const v)
2053  : var_(v), bits_(nullptr), values_(nullptr), size_(0), index_(0) {}
2054 
2055  ~DomainIntVarHoleIterator() override {}
2056 
2057  void Init() override {
2058  bits_ = var_->bitset();
2059  if (bits_ != nullptr) {
2060  bits_->InitHoles();
2061  values_ = bits_->Holes().data();
2062  size_ = bits_->Holes().size();
2063  } else {
2064  values_ = nullptr;
2065  size_ = 0;
2066  }
2067  index_ = 0;
2068  }
2069 
2070  bool Ok() const override { return index_ < size_; }
2071 
2072  int64 Value() const override {
2073  DCHECK(bits_ != nullptr);
2074  DCHECK(index_ < size_);
2075  return values_[index_];
2076  }
2077 
2078  void Next() override { index_++; }
2079 
2080  private:
2081  const DomainIntVar* const var_;
2082  DomainIntVar::BitSet* bits_;
2083  const int64* values_;
2084  int size_;
2085  int index_;
2086 };
2087 
2088 class DomainIntVarDomainIterator : public IntVarIterator {
2089  public:
2090  explicit DomainIntVarDomainIterator(const DomainIntVar* const v,
2091  bool reversible)
2092  : var_(v),
2093  bitset_iterator_(nullptr),
2094  min_(kint64max),
2095  max_(kint64min),
2096  current_(-1),
2097  reversible_(reversible) {}
2098 
2099  ~DomainIntVarDomainIterator() override {
2100  if (!reversible_ && bitset_iterator_) {
2101  delete bitset_iterator_;
2102  }
2103  }
2104 
2105  void Init() override {
2106  if (var_->bitset() != nullptr && !var_->Bound()) {
2107  if (reversible_) {
2108  if (!bitset_iterator_) {
2109  Solver* const solver = var_->solver();
2110  solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2111  bitset_iterator_ = solver->RevAlloc(var_->bitset()->MakeIterator());
2112  }
2113  } else {
2114  if (bitset_iterator_) {
2115  delete bitset_iterator_;
2116  }
2117  bitset_iterator_ = var_->bitset()->MakeIterator();
2118  }
2119  bitset_iterator_->Init(var_->Min(), var_->Max());
2120  } else {
2121  if (bitset_iterator_) {
2122  if (reversible_) {
2123  Solver* const solver = var_->solver();
2124  solver->SaveValue(reinterpret_cast<void**>(&bitset_iterator_));
2125  } else {
2126  delete bitset_iterator_;
2127  }
2128  bitset_iterator_ = nullptr;
2129  }
2130  min_ = var_->Min();
2131  max_ = var_->Max();
2132  current_ = min_;
2133  }
2134  }
2135 
2136  bool Ok() const override {
2137  return bitset_iterator_ ? bitset_iterator_->Ok() : (current_ <= max_);
2138  }
2139 
2140  int64 Value() const override {
2141  return bitset_iterator_ ? bitset_iterator_->Value() : current_;
2142  }
2143 
2144  void Next() override {
2145  if (bitset_iterator_) {
2146  bitset_iterator_->Next();
2147  } else {
2148  current_++;
2149  }
2150  }
2151 
2152  private:
2153  const DomainIntVar* const var_;
2154  DomainIntVar::BitSetIterator* bitset_iterator_;
2155  int64 min_;
2156  int64 max_;
2157  int64 current_;
2158  const bool reversible_;
2159 };
2160 
2161 class UnaryIterator : public IntVarIterator {
2162  public:
2163  UnaryIterator(const IntVar* const v, bool hole, bool reversible)
2164  : iterator_(hole ? v->MakeHoleIterator(reversible)
2165  : v->MakeDomainIterator(reversible)),
2166  reversible_(reversible) {}
2167 
2168  ~UnaryIterator() override {
2169  if (!reversible_) {
2170  delete iterator_;
2171  }
2172  }
2173 
2174  void Init() override { iterator_->Init(); }
2175 
2176  bool Ok() const override { return iterator_->Ok(); }
2177 
2178  void Next() override { iterator_->Next(); }
2179 
2180  protected:
2181  IntVarIterator* const iterator_;
2182  const bool reversible_;
2183 };
2184 
2185 DomainIntVar::DomainIntVar(Solver* const s, int64 vmin, int64 vmax,
2186  const std::string& name)
2187  : IntVar(s, name),
2188  min_(vmin),
2189  max_(vmax),
2190  old_min_(vmin),
2191  old_max_(vmax),
2192  new_min_(vmin),
2193  new_max_(vmax),
2194  handler_(this),
2195  in_process_(false),
2196  bits_(nullptr),
2197  value_watcher_(nullptr),
2198  bound_watcher_(nullptr) {}
2199 
2200 DomainIntVar::DomainIntVar(Solver* const s,
2201  const std::vector<int64>& sorted_values,
2202  const std::string& name)
2203  : IntVar(s, name),
2204  min_(kint64max),
2205  max_(kint64min),
2206  old_min_(kint64max),
2207  old_max_(kint64min),
2208  new_min_(kint64max),
2209  new_max_(kint64min),
2210  handler_(this),
2211  in_process_(false),
2212  bits_(nullptr),
2213  value_watcher_(nullptr),
2214  bound_watcher_(nullptr) {
2215  CHECK_GE(sorted_values.size(), 1);
2216  // We know that the vector is sorted and does not have duplicate values.
2217  const int64 vmin = sorted_values.front();
2218  const int64 vmax = sorted_values.back();
2219  const bool contiguous = vmax - vmin + 1 == sorted_values.size();
2220 
2221  min_.SetValue(solver(), vmin);
2222  old_min_ = vmin;
2223  new_min_ = vmin;
2224  max_.SetValue(solver(), vmax);
2225  old_max_ = vmax;
2226  new_max_ = vmax;
2227 
2228  if (!contiguous) {
2229  if (vmax - vmin + 1 < 65) {
2230  bits_ = solver()->RevAlloc(
2231  new SmallBitSet(solver(), sorted_values, vmin, vmax));
2232  } else {
2233  bits_ = solver()->RevAlloc(
2234  new SimpleBitSet(solver(), sorted_values, vmin, vmax));
2235  }
2236  }
2237 }
2238 
2239 DomainIntVar::~DomainIntVar() {}
2240 
2241 void DomainIntVar::SetMin(int64 m) {
2242  if (m <= min_.Value()) return;
2243  if (m > max_.Value()) solver()->Fail();
2244  if (in_process_) {
2245  if (m > new_min_) {
2246  new_min_ = m;
2247  if (new_min_ > new_max_) {
2248  solver()->Fail();
2249  }
2250  }
2251  } else {
2252  CheckOldMin();
2253  const int64 new_min =
2254  (bits_ == nullptr
2255  ? m
2256  : bits_->ComputeNewMin(m, min_.Value(), max_.Value()));
2257  min_.SetValue(solver(), new_min);
2258  if (min_.Value() > max_.Value()) {
2259  solver()->Fail();
2260  }
2261  Push();
2262  }
2263 }
2264 
2265 void DomainIntVar::SetMax(int64 m) {
2266  if (m >= max_.Value()) return;
2267  if (m < min_.Value()) solver()->Fail();
2268  if (in_process_) {
2269  if (m < new_max_) {
2270  new_max_ = m;
2271  if (new_max_ < new_min_) {
2272  solver()->Fail();
2273  }
2274  }
2275  } else {
2276  CheckOldMax();
2277  const int64 new_max =
2278  (bits_ == nullptr
2279  ? m
2280  : bits_->ComputeNewMax(m, min_.Value(), max_.Value()));
2281  max_.SetValue(solver(), new_max);
2282  if (min_.Value() > max_.Value()) {
2283  solver()->Fail();
2284  }
2285  Push();
2286  }
2287 }
2288 
2289 void DomainIntVar::SetRange(int64 mi, int64 ma) {
2290  if (mi == ma) {
2291  SetValue(mi);
2292  } else {
2293  if (mi > ma || mi > max_.Value() || ma < min_.Value()) solver()->Fail();
2294  if (mi <= min_.Value() && ma >= max_.Value()) return;
2295  if (in_process_) {
2296  if (ma < new_max_) {
2297  new_max_ = ma;
2298  }
2299  if (mi > new_min_) {
2300  new_min_ = mi;
2301  }
2302  if (new_min_ > new_max_) {
2303  solver()->Fail();
2304  }
2305  } else {
2306  if (mi > min_.Value()) {
2307  CheckOldMin();
2308  const int64 new_min =
2309  (bits_ == nullptr
2310  ? mi
2311  : bits_->ComputeNewMin(mi, min_.Value(), max_.Value()));
2312  min_.SetValue(solver(), new_min);
2313  }
2314  if (min_.Value() > ma) {
2315  solver()->Fail();
2316  }
2317  if (ma < max_.Value()) {
2318  CheckOldMax();
2319  const int64 new_max =
2320  (bits_ == nullptr
2321  ? ma
2322  : bits_->ComputeNewMax(ma, min_.Value(), max_.Value()));
2323  max_.SetValue(solver(), new_max);
2324  }
2325  if (min_.Value() > max_.Value()) {
2326  solver()->Fail();
2327  }
2328  Push();
2329  }
2330  }
2331 }
2332 
2333 void DomainIntVar::SetValue(int64 v) {
2334  if (v != min_.Value() || v != max_.Value()) {
2335  if (v < min_.Value() || v > max_.Value()) {
2336  solver()->Fail();
2337  }
2338  if (in_process_) {
2339  if (v > new_max_ || v < new_min_) {
2340  solver()->Fail();
2341  }
2342  new_min_ = v;
2343  new_max_ = v;
2344  } else {
2345  if (bits_ && !bits_->SetValue(v)) {
2346  solver()->Fail();
2347  }
2348  CheckOldMin();
2349  CheckOldMax();
2350  min_.SetValue(solver(), v);
2351  max_.SetValue(solver(), v);
2352  Push();
2353  }
2354  }
2355 }
2356 
2357 void DomainIntVar::RemoveValue(int64 v) {
2358  if (v < min_.Value() || v > max_.Value()) return;
2359  if (v == min_.Value()) {
2360  SetMin(v + 1);
2361  } else if (v == max_.Value()) {
2362  SetMax(v - 1);
2363  } else {
2364  if (bits_ == nullptr) {
2365  CreateBits();
2366  }
2367  if (in_process_) {
2368  if (v >= new_min_ && v <= new_max_ && bits_->Contains(v)) {
2369  bits_->DelayRemoveValue(v);
2370  }
2371  } else {
2372  if (bits_->RemoveValue(v)) {
2373  Push();
2374  }
2375  }
2376  }
2377 }
2378 
2379 void DomainIntVar::RemoveInterval(int64 l, int64 u) {
2380  if (l <= min_.Value()) {
2381  SetMin(u + 1);
2382  } else if (u >= max_.Value()) {
2383  SetMax(l - 1);
2384  } else {
2385  for (int64 v = l; v <= u; ++v) {
2386  RemoveValue(v);
2387  }
2388  }
2389 }
2390 
2391 void DomainIntVar::CreateBits() {
2392  solver()->SaveValue(reinterpret_cast<void**>(&bits_));
2393  if (max_.Value() - min_.Value() < 64) {
2394  bits_ = solver()->RevAlloc(
2395  new SmallBitSet(solver(), min_.Value(), max_.Value()));
2396  } else {
2397  bits_ = solver()->RevAlloc(
2398  new SimpleBitSet(solver(), min_.Value(), max_.Value()));
2399  }
2400 }
2401 
2402 void DomainIntVar::CleanInProcess() {
2403  in_process_ = false;
2404  if (bits_ != nullptr) {
2405  bits_->ClearHoles();
2406  }
2407 }
2408 
2409 void DomainIntVar::Push() {
2410  const bool in_process = in_process_;
2411  EnqueueVar(&handler_);
2412  CHECK_EQ(in_process, in_process_);
2413 }
2414 
2415 void DomainIntVar::Process() {
2416  CHECK(!in_process_);
2417  in_process_ = true;
2418  if (bits_ != nullptr) {
2419  bits_->ClearRemovedValues();
2420  }
2421  set_variable_to_clean_on_fail(this);
2422  new_min_ = min_.Value();
2423  new_max_ = max_.Value();
2424  const bool is_bound = min_.Value() == max_.Value();
2425  const bool range_changed =
2426  min_.Value() != OldMin() || max_.Value() != OldMax();
2427  // Process immediate demons.
2428  if (is_bound) {
2429  ExecuteAll(bound_demons_);
2430  }
2431  if (range_changed) {
2432  ExecuteAll(range_demons_);
2433  }
2434  ExecuteAll(domain_demons_);
2435 
2436  // Process delayed demons.
2437  if (is_bound) {
2438  EnqueueAll(delayed_bound_demons_);
2439  }
2440  if (range_changed) {
2441  EnqueueAll(delayed_range_demons_);
2442  }
2443  EnqueueAll(delayed_domain_demons_);
2444 
2445  // Everything went well if we arrive here. Let's clean the variable.
2446  set_variable_to_clean_on_fail(nullptr);
2447  CleanInProcess();
2448  old_min_ = min_.Value();
2449  old_max_ = max_.Value();
2450  if (min_.Value() < new_min_) {
2451  SetMin(new_min_);
2452  }
2453  if (max_.Value() > new_max_) {
2454  SetMax(new_max_);
2455  }
2456  if (bits_ != nullptr) {
2457  bits_->ApplyRemovedValues(this);
2458  }
2459 }
2460 
// Returns `alloc` wrapped in solver()->RevAlloc(...) when `rev` is true, and
// `alloc` unchanged otherwise. Requires a `solver()` accessor at the expansion
// site. The expansion ends with a semicolon, so `return COND_REV_ALLOC(...);`
// yields an extra empty statement; existing call sites rely on this form.
#define COND_REV_ALLOC(rev, alloc) \
  ((rev) ? solver()->RevAlloc(alloc) : (alloc));
2462 
2463 IntVarIterator* DomainIntVar::MakeHoleIterator(bool reversible) const {
2464  return COND_REV_ALLOC(reversible, new DomainIntVarHoleIterator(this));
2465 }
2466 
2467 IntVarIterator* DomainIntVar::MakeDomainIterator(bool reversible) const {
2468  return COND_REV_ALLOC(reversible,
2469  new DomainIntVarDomainIterator(this, reversible));
2470 }
2471 
2472 std::string DomainIntVar::DebugString() const {
2473  std::string out;
2474  const std::string& var_name = name();
2475  if (!var_name.empty()) {
2476  out = var_name + "(";
2477  } else {
2478  out = "DomainIntVar(";
2479  }
2480  if (min_.Value() == max_.Value()) {
2481  absl::StrAppendFormat(&out, "%d", min_.Value());
2482  } else if (bits_ != nullptr) {
2483  out.append(bits_->pretty_DebugString(min_.Value(), max_.Value()));
2484  } else {
2485  absl::StrAppendFormat(&out, "%d..%d", min_.Value(), max_.Value());
2486  }
2487  out += ")";
2488  return out;
2489 }
2490 
2491 // ----- Real Boolean Var -----
2492 
2493 class ConcreteBooleanVar : public BooleanVar {
2494  public:
2495  // Utility classes
2496  class Handler : public Demon {
2497  public:
2498  explicit Handler(ConcreteBooleanVar* const var) : Demon(), var_(var) {}
2499  ~Handler() override {}
2500  void Run(Solver* const s) override {
2501  s->GetPropagationMonitor()->StartProcessingIntegerVariable(var_);
2502  var_->Process();
2503  s->GetPropagationMonitor()->EndProcessingIntegerVariable(var_);
2504  }
2505  Solver::DemonPriority priority() const override {
2506  return Solver::VAR_PRIORITY;
2507  }
2508  std::string DebugString() const override {
2509  return absl::StrFormat("Handler(%s)", var_->DebugString());
2510  }
2511 
2512  private:
2513  ConcreteBooleanVar* const var_;
2514  };
2515 
2516  ConcreteBooleanVar(Solver* const s, const std::string& name)
2517  : BooleanVar(s, name), handler_(this) {}
2518 
2519  ~ConcreteBooleanVar() override {}
2520 
2521  void SetValue(int64 v) override {
2522  if (value_ == kUnboundBooleanVarValue) {
2523  if ((v & 0xfffffffffffffffe) == 0) {
2524  InternalSaveBooleanVarValue(solver(), this);
2525  value_ = static_cast<int>(v);
2526  EnqueueVar(&handler_);
2527  return;
2528  }
2529  } else if (v == value_) {
2530  return;
2531  }
2532  solver()->Fail();
2533  }
2534 
2535  void Process() {
2536  DCHECK_NE(value_, kUnboundBooleanVarValue);
2537  ExecuteAll(bound_demons_);
2538  for (SimpleRevFIFO<Demon*>::Iterator it(&delayed_bound_demons_); it.ok();
2539  ++it) {
2540  EnqueueDelayedDemon(*it);
2541  }
2542  }
2543 
2544  int64 OldMin() const override { return 0LL; }
2545  int64 OldMax() const override { return 1LL; }
2546  void RestoreValue() override { value_ = kUnboundBooleanVarValue; }
2547 
2548  private:
2549  Handler handler_;
2550 };
2551 
2552 // ----- IntConst -----
2553 
2554 class IntConst : public IntVar {
2555  public:
2556  IntConst(Solver* const s, int64 value, const std::string& name = "")
2557  : IntVar(s, name), value_(value) {}
2558  ~IntConst() override {}
2559 
2560  int64 Min() const override { return value_; }
2561  void SetMin(int64 m) override {
2562  if (m > value_) {
2563  solver()->Fail();
2564  }
2565  }
2566  int64 Max() const override { return value_; }
2567  void SetMax(int64 m) override {
2568  if (m < value_) {
2569  solver()->Fail();
2570  }
2571  }
2572  void SetRange(int64 l, int64 u) override {
2573  if (l > value_ || u < value_) {
2574  solver()->Fail();
2575  }
2576  }
2577  void SetValue(int64 v) override {
2578  if (v != value_) {
2579  solver()->Fail();
2580  }
2581  }
2582  bool Bound() const override { return true; }
2583  int64 Value() const override { return value_; }
2584  void RemoveValue(int64 v) override {
2585  if (v == value_) {
2586  solver()->Fail();
2587  }
2588  }
2589  void RemoveInterval(int64 l, int64 u) override {
2590  if (l <= value_ && value_ <= u) {
2591  solver()->Fail();
2592  }
2593  }
2594  void WhenBound(Demon* d) override {}
2595  void WhenRange(Demon* d) override {}
2596  void WhenDomain(Demon* d) override {}
2597  uint64 Size() const override { return 1; }
2598  bool Contains(int64 v) const override { return (v == value_); }
2599  IntVarIterator* MakeHoleIterator(bool reversible) const override {
2600  return COND_REV_ALLOC(reversible, new EmptyIterator());
2601  }
2602  IntVarIterator* MakeDomainIterator(bool reversible) const override {
2603  return COND_REV_ALLOC(reversible, new RangeIterator(this));
2604  }
2605  int64 OldMin() const override { return value_; }
2606  int64 OldMax() const override { return value_; }
2607  std::string DebugString() const override {
2608  std::string out;
2609  if (solver()->HasName(this)) {
2610  const std::string& var_name = name();
2611  absl::StrAppendFormat(&out, "%s(%d)", var_name, value_);
2612  } else {
2613  absl::StrAppendFormat(&out, "IntConst(%d)", value_);
2614  }
2615  return out;
2616  }
2617 
2618  int VarType() const override { return CONST_VAR; }
2619 
2620  IntVar* IsEqual(int64 constant) override {
2621  if (constant == value_) {
2622  return solver()->MakeIntConst(1);
2623  } else {
2624  return solver()->MakeIntConst(0);
2625  }
2626  }
2627 
2628  IntVar* IsDifferent(int64 constant) override {
2629  if (constant == value_) {
2630  return solver()->MakeIntConst(0);
2631  } else {
2632  return solver()->MakeIntConst(1);
2633  }
2634  }
2635 
2636  IntVar* IsGreaterOrEqual(int64 constant) override {
2637  return solver()->MakeIntConst(value_ >= constant);
2638  }
2639 
2640  IntVar* IsLessOrEqual(int64 constant) override {
2641  return solver()->MakeIntConst(value_ <= constant);
2642  }
2643 
2644  std::string name() const override {
2645  if (solver()->HasName(this)) {
2646  return PropagationBaseObject::name();
2647  } else {
2648  return absl::StrCat(value_);
2649  }
2650  }
2651 
2652  private:
2653  int64 value_;
2654 };
2655 
2656 // ----- x + c variable, optimized case -----
2657 
2658 class PlusCstVar : public IntVar {
2659  public:
2660  PlusCstVar(Solver* const s, IntVar* v, int64 c)
2661  : IntVar(s), var_(v), cst_(c) {}
2662 
2663  ~PlusCstVar() override {}
2664 
2665  void WhenRange(Demon* d) override { var_->WhenRange(d); }
2666 
2667  void WhenBound(Demon* d) override { var_->WhenBound(d); }
2668 
2669  void WhenDomain(Demon* d) override { var_->WhenDomain(d); }
2670 
2671  int64 OldMin() const override { return CapAdd(var_->OldMin(), cst_); }
2672 
2673  int64 OldMax() const override { return CapAdd(var_->OldMax(), cst_); }
2674 
2675  std::string DebugString() const override {
2676  if (HasName()) {
2677  return absl::StrFormat("%s(%s + %d)", name(), var_->DebugString(), cst_);
2678  } else {
2679  return absl::StrFormat("(%s + %d)", var_->DebugString(), cst_);
2680  }
2681  }
2682 
2683  int VarType() const override { return VAR_ADD_CST; }
2684 
2685  void Accept(ModelVisitor* const visitor) const override {
2686  visitor->VisitIntegerVariable(this, ModelVisitor::kSumOperation, cst_,
2687  var_);
2688  }
2689 
2690  IntVar* IsEqual(int64 constant) override {
2691  return var_->IsEqual(constant - cst_);
2692  }
2693 
2694  IntVar* IsDifferent(int64 constant) override {
2695  return var_->IsDifferent(constant - cst_);
2696  }
2697 
2698  IntVar* IsGreaterOrEqual(int64 constant) override {
2699  return var_->IsGreaterOrEqual(constant - cst_);
2700  }
2701 
2702  IntVar* IsLessOrEqual(int64 constant) override {
2703  return var_->IsLessOrEqual(constant - cst_);
2704  }
2705 
2706  IntVar* SubVar() const { return var_; }
2707 
2708  int64 Constant() const { return cst_; }
2709 
2710  protected:
2711  IntVar* const var_;
2712  const int64 cst_;
2713 };
2714 
2715 class PlusCstIntVar : public PlusCstVar {
2716  public:
2717  class PlusCstIntVarIterator : public UnaryIterator {
2718  public:
2719  PlusCstIntVarIterator(const IntVar* const v, int64 c, bool hole, bool rev)
2720  : UnaryIterator(v, hole, rev), cst_(c) {}
2721 
2722  ~PlusCstIntVarIterator() override {}
2723 
2724  int64 Value() const override { return iterator_->Value() + cst_; }
2725 
2726  private:
2727  const int64 cst_;
2728  };
2729 
2730  PlusCstIntVar(Solver* const s, IntVar* v, int64 c) : PlusCstVar(s, v, c) {}
2731 
2732  ~PlusCstIntVar() override {}
2733 
2734  int64 Min() const override { return var_->Min() + cst_; }
2735 
2736  void SetMin(int64 m) override { var_->SetMin(CapSub(m, cst_)); }
2737 
2738  int64 Max() const override { return var_->Max() + cst_; }
2739 
2740  void SetMax(int64 m) override { var_->SetMax(CapSub(m, cst_)); }
2741 
2742  void SetRange(int64 l, int64 u) override {
2743  var_->SetRange(CapSub(l, cst_), CapSub(u, cst_));
2744  }
2745 
2746  void SetValue(int64 v) override { var_->SetValue(v - cst_); }
2747 
2748  int64 Value() const override { return var_->Value() + cst_; }
2749 
2750  bool Bound() const override { return var_->Bound(); }
2751 
2752  void RemoveValue(int64 v) override { var_->RemoveValue(v - cst_); }
2753 
2754  void RemoveInterval(int64 l, int64 u) override {
2755  var_->RemoveInterval(l - cst_, u - cst_);
2756  }
2757 
2758  uint64 Size() const override { return var_->Size(); }
2759 
2760  bool Contains(int64 v) const override { return var_->Contains(v - cst_); }
2761 
2762  IntVarIterator* MakeHoleIterator(bool reversible) const override {
2763  return COND_REV_ALLOC(
2764  reversible, new PlusCstIntVarIterator(var_, cst_, true, reversible));
2765  }
2766  IntVarIterator* MakeDomainIterator(bool reversible) const override {
2767  return COND_REV_ALLOC(
2768  reversible, new PlusCstIntVarIterator(var_, cst_, false, reversible));
2769  }
2770 };
2771 
2772 class PlusCstDomainIntVar : public PlusCstVar {
2773  public:
2774  class PlusCstDomainIntVarIterator : public UnaryIterator {
2775  public:
2776  PlusCstDomainIntVarIterator(const IntVar* const v, int64 c, bool hole,
2777  bool reversible)
2778  : UnaryIterator(v, hole, reversible), cst_(c) {}
2779 
2780  ~PlusCstDomainIntVarIterator() override {}
2781 
2782  int64 Value() const override { return iterator_->Value() + cst_; }
2783 
2784  private:
2785  const int64 cst_;
2786  };
2787 
2788  PlusCstDomainIntVar(Solver* const s, DomainIntVar* v, int64 c)
2789  : PlusCstVar(s, v, c) {}
2790 
2791  ~PlusCstDomainIntVar() override {}
2792 
2793  int64 Min() const override;
2794  void SetMin(int64 m) override;
2795  int64 Max() const override;
2796  void SetMax(int64 m) override;
2797  void SetRange(int64 l, int64 u) override;
2798  void SetValue(int64 v) override;
2799  bool Bound() const override;
2800  int64 Value() const override;
2801  void RemoveValue(int64 v) override;
2802  void RemoveInterval(int64 l, int64 u) override;
2803  uint64 Size() const override;
2804  bool Contains(int64 v) const override;
2805 
2806  DomainIntVar* domain_int_var() const {
2807  return reinterpret_cast<DomainIntVar*>(var_);
2808  }
2809 
2810  IntVarIterator* MakeHoleIterator(bool reversible) const override {
2811  return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2812  var_, cst_, true, reversible));
2813  }
2814  IntVarIterator* MakeDomainIterator(bool reversible) const override {
2815  return COND_REV_ALLOC(reversible, new PlusCstDomainIntVarIterator(
2816  var_, cst_, false, reversible));
2817  }
2818 };
2819 
2820 int64 PlusCstDomainIntVar::Min() const {
2821  return domain_int_var()->min_.Value() + cst_;
2822 }
2823 
2824 void PlusCstDomainIntVar::SetMin(int64 m) {
2825  domain_int_var()->DomainIntVar::SetMin(m - cst_);
2826 }
2827 
2828 int64 PlusCstDomainIntVar::Max() const {
2829  return domain_int_var()->max_.Value() + cst_;
2830 }
2831 
2832 void PlusCstDomainIntVar::SetMax(int64 m) {
2833  domain_int_var()->DomainIntVar::SetMax(m - cst_);
2834 }
2835 
2836 void PlusCstDomainIntVar::SetRange(int64 l, int64 u) {
2837  domain_int_var()->DomainIntVar::SetRange(l - cst_, u - cst_);
2838 }
2839 
2840 void PlusCstDomainIntVar::SetValue(int64 v) {
2841  domain_int_var()->DomainIntVar::SetValue(v - cst_);
2842 }
2843 
2844 bool PlusCstDomainIntVar::Bound() const {
2845  return domain_int_var()->min_.Value() == domain_int_var()->max_.Value();
2846 }
2847 
2849  CHECK_EQ(domain_int_var()->min_.Value(), domain_int_var()->max_.Value())
2850  << " variable is not bound";
2851  return domain_int_var()->min_.Value() + cst_;
2852 }
2853 
2854 void PlusCstDomainIntVar::RemoveValue(int64 v) {
2855  domain_int_var()->DomainIntVar::RemoveValue(v - cst_);
2856 }
2857 
2858 void PlusCstDomainIntVar::RemoveInterval(int64 l, int64 u) {
2859  domain_int_var()->DomainIntVar::RemoveInterval(l - cst_, u - cst_);
2860 }
2861 
2862 uint64 PlusCstDomainIntVar::Size() const {
2863  return domain_int_var()->DomainIntVar::Size();
2864 }
2865 
2866 bool PlusCstDomainIntVar::Contains(int64 v) const {
2867  return domain_int_var()->DomainIntVar::Contains(v - cst_);
2868 }
2869 
2870 // c - x variable, optimized case
2871 
2872 class SubCstIntVar : public IntVar {
2873  public:
2874  class SubCstIntVarIterator : public UnaryIterator {
2875  public:
2876  SubCstIntVarIterator(const IntVar* const v, int64 c, bool hole, bool rev)
2877  : UnaryIterator(v, hole, rev), cst_(c) {}
2878  ~SubCstIntVarIterator() override {}
2879 
2880  int64 Value() const override { return cst_ - iterator_->Value(); }
2881 
2882  private:
2883  const int64 cst_;
2884  };
2885 
2886  SubCstIntVar(Solver* const s, IntVar* v, int64 c);
2887  ~SubCstIntVar() override;
2888 
2889  int64 Min() const override;
2890  void SetMin(int64 m) override;
2891  int64 Max() const override;
2892  void SetMax(int64 m) override;
2893  void SetRange(int64 l, int64 u) override;
2894  void SetValue(int64 v) override;
2895  bool Bound() const override;
2896  int64 Value() const override;
2897  void RemoveValue(int64 v) override;
2898  void RemoveInterval(int64 l, int64 u) override;
2899  uint64 Size() const override;
2900  bool Contains(int64 v) const override;
2901  void WhenRange(Demon* d) override;
2902  void WhenBound(Demon* d) override;
2903  void WhenDomain(Demon* d) override;
2904  IntVarIterator* MakeHoleIterator(bool reversible) const override {
2905  return COND_REV_ALLOC(
2906  reversible, new SubCstIntVarIterator(var_, cst_, true, reversible));
2907  }
2908  IntVarIterator* MakeDomainIterator(bool reversible) const override {
2909  return COND_REV_ALLOC(
2910  reversible, new SubCstIntVarIterator(var_, cst_, false, reversible));
2911  }
2912  int64 OldMin() const override { return CapSub(cst_, var_->OldMax()); }
2913  int64 OldMax() const override { return CapSub(cst_, var_->OldMin()); }
2914  std::string DebugString() const override;
2915  std::string name() const override;
2916  int VarType() const override { return CST_SUB_VAR; }
2917 
2918  void Accept(ModelVisitor* const visitor) const override {
2919  visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation,
2920  cst_, var_);
2921  }
2922 
2923  IntVar* IsEqual(int64 constant) override {
2924  return var_->IsEqual(cst_ - constant);
2925  }
2926 
2927  IntVar* IsDifferent(int64 constant) override {
2928  return var_->IsDifferent(cst_ - constant);
2929  }
2930 
2931  IntVar* IsGreaterOrEqual(int64 constant) override {
2932  return var_->IsLessOrEqual(cst_ - constant);
2933  }
2934 
2935  IntVar* IsLessOrEqual(int64 constant) override {
2936  return var_->IsGreaterOrEqual(cst_ - constant);
2937  }
2938 
2939  IntVar* SubVar() const { return var_; }
2940  int64 Constant() const { return cst_; }
2941 
2942  private:
2943  IntVar* const var_;
2944  const int64 cst_;
2945 };
2946 
2947 SubCstIntVar::SubCstIntVar(Solver* const s, IntVar* v, int64 c)
2948  : IntVar(s), var_(v), cst_(c) {}
2949 
2950 SubCstIntVar::~SubCstIntVar() {}
2951 
2952 int64 SubCstIntVar::Min() const { return cst_ - var_->Max(); }
2953 
2954 void SubCstIntVar::SetMin(int64 m) { var_->SetMax(CapSub(cst_, m)); }
2955 
2956 int64 SubCstIntVar::Max() const { return cst_ - var_->Min(); }
2957 
2958 void SubCstIntVar::SetMax(int64 m) { var_->SetMin(CapSub(cst_, m)); }
2959 
2960 void SubCstIntVar::SetRange(int64 l, int64 u) {
2961  var_->SetRange(CapSub(cst_, u), CapSub(cst_, l));
2962 }
2963 
2964 void SubCstIntVar::SetValue(int64 v) { var_->SetValue(cst_ - v); }
2965 
2966 bool SubCstIntVar::Bound() const { return var_->Bound(); }
2967 
2968 void SubCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
2969 
2970 int64 SubCstIntVar::Value() const { return cst_ - var_->Value(); }
2971 
2972 void SubCstIntVar::RemoveValue(int64 v) { var_->RemoveValue(cst_ - v); }
2973 
2974 void SubCstIntVar::RemoveInterval(int64 l, int64 u) {
2975  var_->RemoveInterval(cst_ - u, cst_ - l);
2976 }
2977 
2978 void SubCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
2979 
2980 void SubCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
2981 
2982 uint64 SubCstIntVar::Size() const { return var_->Size(); }
2983 
2984 bool SubCstIntVar::Contains(int64 v) const { return var_->Contains(cst_ - v); }
2985 
2986 std::string SubCstIntVar::DebugString() const {
2987  if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
2988  return absl::StrFormat("Not(%s)", var_->DebugString());
2989  } else {
2990  return absl::StrFormat("(%d - %s)", cst_, var_->DebugString());
2991  }
2992 }
2993 
2994 std::string SubCstIntVar::name() const {
2995  if (solver()->HasName(this)) {
2996  return PropagationBaseObject::name();
2997  } else if (cst_ == 1 && var_->VarType() == BOOLEAN_VAR) {
2998  return absl::StrFormat("Not(%s)", var_->name());
2999  } else {
3000  return absl::StrFormat("(%d - %s)", cst_, var_->name());
3001  }
3002 }
3003 
3004 // -x variable, optimized case
3005 
3006 class OppIntVar : public IntVar {
3007  public:
3008  class OppIntVarIterator : public UnaryIterator {
3009  public:
3010  OppIntVarIterator(const IntVar* const v, bool hole, bool reversible)
3011  : UnaryIterator(v, hole, reversible) {}
3012  ~OppIntVarIterator() override {}
3013 
3014  int64 Value() const override { return -iterator_->Value(); }
3015  };
3016 
3017  OppIntVar(Solver* const s, IntVar* v);
3018  ~OppIntVar() override;
3019 
3020  int64 Min() const override;
3021  void SetMin(int64 m) override;
3022  int64 Max() const override;
3023  void SetMax(int64 m) override;
3024  void SetRange(int64 l, int64 u) override;
3025  void SetValue(int64 v) override;
3026  bool Bound() const override;
3027  int64 Value() const override;
3028  void RemoveValue(int64 v) override;
3029  void RemoveInterval(int64 l, int64 u) override;
3030  uint64 Size() const override;
3031  bool Contains(int64 v) const override;
3032  void WhenRange(Demon* d) override;
3033  void WhenBound(Demon* d) override;
3034  void WhenDomain(Demon* d) override;
3035  IntVarIterator* MakeHoleIterator(bool reversible) const override {
3036  return COND_REV_ALLOC(reversible,
3037  new OppIntVarIterator(var_, true, reversible));
3038  }
3039  IntVarIterator* MakeDomainIterator(bool reversible) const override {
3040  return COND_REV_ALLOC(reversible,
3041  new OppIntVarIterator(var_, false, reversible));
3042  }
3043  int64 OldMin() const override { return CapOpp(var_->OldMax()); }
3044  int64 OldMax() const override { return CapOpp(var_->OldMin()); }
3045  std::string DebugString() const override;
3046  int VarType() const override { return OPP_VAR; }
3047 
3048  void Accept(ModelVisitor* const visitor) const override {
3049  visitor->VisitIntegerVariable(this, ModelVisitor::kDifferenceOperation, 0,
3050  var_);
3051  }
3052 
3053  IntVar* IsEqual(int64 constant) override { return var_->IsEqual(-constant); }
3054 
3055  IntVar* IsDifferent(int64 constant) override {
3056  return var_->IsDifferent(-constant);
3057  }
3058 
3059  IntVar* IsGreaterOrEqual(int64 constant) override {
3060  return var_->IsLessOrEqual(-constant);
3061  }
3062 
3063  IntVar* IsLessOrEqual(int64 constant) override {
3064  return var_->IsGreaterOrEqual(-constant);
3065  }
3066 
3067  IntVar* SubVar() const { return var_; }
3068 
3069  private:
3070  IntVar* const var_;
3071 };
3072 
3073 OppIntVar::OppIntVar(Solver* const s, IntVar* v) : IntVar(s), var_(v) {}
3074 
3075 OppIntVar::~OppIntVar() {}
3076 
3077 int64 OppIntVar::Min() const { return -var_->Max(); }
3078 
3079 void OppIntVar::SetMin(int64 m) { var_->SetMax(CapOpp(m)); }
3080 
3081 int64 OppIntVar::Max() const { return -var_->Min(); }
3082 
3083 void OppIntVar::SetMax(int64 m) { var_->SetMin(CapOpp(m)); }
3084 
3085 void OppIntVar::SetRange(int64 l, int64 u) {
3086  var_->SetRange(CapOpp(u), CapOpp(l));
3087 }
3088 
3089 void OppIntVar::SetValue(int64 v) { var_->SetValue(CapOpp(v)); }
3090 
3091 bool OppIntVar::Bound() const { return var_->Bound(); }
3092 
3093 void OppIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3094 
3095 int64 OppIntVar::Value() const { return -var_->Value(); }
3096 
3097 void OppIntVar::RemoveValue(int64 v) { var_->RemoveValue(-v); }
3098 
3099 void OppIntVar::RemoveInterval(int64 l, int64 u) {
3100  var_->RemoveInterval(-u, -l);
3101 }
3102 
3103 void OppIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3104 
3105 void OppIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3106 
3107 uint64 OppIntVar::Size() const { return var_->Size(); }
3108 
3109 bool OppIntVar::Contains(int64 v) const { return var_->Contains(-v); }
3110 
3111 std::string OppIntVar::DebugString() const {
3112  return absl::StrFormat("-(%s)", var_->DebugString());
3113 }
3114 
3115 // ----- Utility functions -----
3116 
3117 // x * c variable, optimized case
3118 
3119 class TimesCstIntVar : public IntVar {
3120  public:
3121  TimesCstIntVar(Solver* const s, IntVar* v, int64 c)
3122  : IntVar(s), var_(v), cst_(c) {}
3123  ~TimesCstIntVar() override {}
3124 
3125  IntVar* SubVar() const { return var_; }
3126  int64 Constant() const { return cst_; }
3127 
3128  void Accept(ModelVisitor* const visitor) const override {
3129  visitor->VisitIntegerVariable(this, ModelVisitor::kProductOperation, cst_,
3130  var_);
3131  }
3132 
3133  IntVar* IsEqual(int64 constant) override {
3134  if (constant % cst_ == 0) {
3135  return var_->IsEqual(constant / cst_);
3136  } else {
3137  return solver()->MakeIntConst(0);
3138  }
3139  }
3140 
3141  IntVar* IsDifferent(int64 constant) override {
3142  if (constant % cst_ == 0) {
3143  return var_->IsDifferent(constant / cst_);
3144  } else {
3145  return solver()->MakeIntConst(1);
3146  }
3147  }
3148 
3149  IntVar* IsGreaterOrEqual(int64 constant) override {
3150  if (cst_ > 0) {
3151  return var_->IsGreaterOrEqual(PosIntDivUp(constant, cst_));
3152  } else {
3153  return var_->IsLessOrEqual(PosIntDivDown(-constant, -cst_));
3154  }
3155  }
3156 
3157  IntVar* IsLessOrEqual(int64 constant) override {
3158  if (cst_ > 0) {
3159  return var_->IsLessOrEqual(PosIntDivDown(constant, cst_));
3160  } else {
3161  return var_->IsGreaterOrEqual(PosIntDivUp(-constant, -cst_));
3162  }
3163  }
3164 
3165  std::string DebugString() const override {
3166  return absl::StrFormat("(%s * %d)", var_->DebugString(), cst_);
3167  }
3168 
3169  int VarType() const override { return VAR_TIMES_CST; }
3170 
3171  protected:
3172  IntVar* const var_;
3173  const int64 cst_;
3174 };
3175 
3176 class TimesPosCstIntVar : public TimesCstIntVar {
3177  public:
3178  class TimesPosCstIntVarIterator : public UnaryIterator {
3179  public:
3180  TimesPosCstIntVarIterator(const IntVar* const v, int64 c, bool hole,
3181  bool reversible)
3182  : UnaryIterator(v, hole, reversible), cst_(c) {}
3183  ~TimesPosCstIntVarIterator() override {}
3184 
3185  int64 Value() const override { return iterator_->Value() * cst_; }
3186 
3187  private:
3188  const int64 cst_;
3189  };
3190 
3191  TimesPosCstIntVar(Solver* const s, IntVar* v, int64 c);
3192  ~TimesPosCstIntVar() override;
3193 
3194  int64 Min() const override;
3195  void SetMin(int64 m) override;
3196  int64 Max() const override;
3197  void SetMax(int64 m) override;
3198  void SetRange(int64 l, int64 u) override;
3199  void SetValue(int64 v) override;
3200  bool Bound() const override;
3201  int64 Value() const override;
3202  void RemoveValue(int64 v) override;
3203  void RemoveInterval(int64 l, int64 u) override;
3204  uint64 Size() const override;
3205  bool Contains(int64 v) const override;
3206  void WhenRange(Demon* d) override;
3207  void WhenBound(Demon* d) override;
3208  void WhenDomain(Demon* d) override;
3209  IntVarIterator* MakeHoleIterator(bool reversible) const override {
3210  return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3211  var_, cst_, true, reversible));
3212  }
3213  IntVarIterator* MakeDomainIterator(bool reversible) const override {
3214  return COND_REV_ALLOC(reversible, new TimesPosCstIntVarIterator(
3215  var_, cst_, false, reversible));
3216  }
3217  int64 OldMin() const override { return CapProd(var_->OldMin(), cst_); }
3218  int64 OldMax() const override { return CapProd(var_->OldMax(), cst_); }
3219 };
3220 
3221 // ----- TimesPosCstIntVar -----
3222 
3223 TimesPosCstIntVar::TimesPosCstIntVar(Solver* const s, IntVar* v, int64 c)
3224  : TimesCstIntVar(s, v, c) {}
3225 
3226 TimesPosCstIntVar::~TimesPosCstIntVar() {}
3227 
3228 int64 TimesPosCstIntVar::Min() const { return CapProd(var_->Min(), cst_); }
3229 
3230 void TimesPosCstIntVar::SetMin(int64 m) {
3231  if (m != kint64min) {
3232  var_->SetMin(PosIntDivUp(m, cst_));
3233  }
3234 }
3235 
3236 int64 TimesPosCstIntVar::Max() const { return CapProd(var_->Max(), cst_); }
3237 
3238 void TimesPosCstIntVar::SetMax(int64 m) {
3239  if (m != kint64max) {
3240  var_->SetMax(PosIntDivDown(m, cst_));
3241  }
3242 }
3243 
3244 void TimesPosCstIntVar::SetRange(int64 l, int64 u) {
3245  var_->SetRange(PosIntDivUp(l, cst_), PosIntDivDown(u, cst_));
3246 }
3247 
3248 void TimesPosCstIntVar::SetValue(int64 v) {
3249  if (v % cst_ != 0) {
3250  solver()->Fail();
3251  }
3252  var_->SetValue(v / cst_);
3253 }
3254 
3255 bool TimesPosCstIntVar::Bound() const { return var_->Bound(); }
3256 
3257 void TimesPosCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3258 
3259 int64 TimesPosCstIntVar::Value() const { return CapProd(var_->Value(), cst_); }
3260 
3261 void TimesPosCstIntVar::RemoveValue(int64 v) {
3262  if (v % cst_ == 0) {
3263  var_->RemoveValue(v / cst_);
3264  }
3265 }
3266 
3267 void TimesPosCstIntVar::RemoveInterval(int64 l, int64 u) {
3268  for (int64 v = l; v <= u; ++v) {
3269  RemoveValue(v);
3270  }
3271  // TODO(user) : Improve me
3272 }
3273 
3274 void TimesPosCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3275 
3276 void TimesPosCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3277 
3278 uint64 TimesPosCstIntVar::Size() const { return var_->Size(); }
3279 
3280 bool TimesPosCstIntVar::Contains(int64 v) const {
3281  return (v % cst_ == 0 && var_->Contains(v / cst_));
3282 }
3283 
3284 // b * c variable, optimized case
3285 
3286 class TimesPosCstBoolVar : public TimesCstIntVar {
3287  public:
3288  class TimesPosCstBoolVarIterator : public UnaryIterator {
3289  public:
3290  // TODO(user) : optimize this.
3291  TimesPosCstBoolVarIterator(const IntVar* const v, int64 c, bool hole,
3292  bool reversible)
3293  : UnaryIterator(v, hole, reversible), cst_(c) {}
3294  ~TimesPosCstBoolVarIterator() override {}
3295 
3296  int64 Value() const override { return iterator_->Value() * cst_; }
3297 
3298  private:
3299  const int64 cst_;
3300  };
3301 
3302  TimesPosCstBoolVar(Solver* const s, BooleanVar* v, int64 c);
3303  ~TimesPosCstBoolVar() override;
3304 
3305  int64 Min() const override;
3306  void SetMin(int64 m) override;
3307  int64 Max() const override;
3308  void SetMax(int64 m) override;
3309  void SetRange(int64 l, int64 u) override;
3310  void SetValue(int64 v) override;
3311  bool Bound() const override;
3312  int64 Value() const override;
3313  void RemoveValue(int64 v) override;
3314  void RemoveInterval(int64 l, int64 u) override;
3315  uint64 Size() const override;
3316  bool Contains(int64 v) const override;
3317  void WhenRange(Demon* d) override;
3318  void WhenBound(Demon* d) override;
3319  void WhenDomain(Demon* d) override;
3320  IntVarIterator* MakeHoleIterator(bool reversible) const override {
3321  return COND_REV_ALLOC(reversible, new EmptyIterator());
3322  }
3323  IntVarIterator* MakeDomainIterator(bool reversible) const override {
3324  return COND_REV_ALLOC(
3325  reversible,
3326  new TimesPosCstBoolVarIterator(boolean_var(), cst_, false, reversible));
3327  }
3328  int64 OldMin() const override { return 0; }
3329  int64 OldMax() const override { return cst_; }
3330 
3331  BooleanVar* boolean_var() const {
3332  return reinterpret_cast<BooleanVar*>(var_);
3333  }
3334 };
3335 
3336 // ----- TimesPosCstBoolVar -----
3337 
3338 TimesPosCstBoolVar::TimesPosCstBoolVar(Solver* const s, BooleanVar* v, int64 c)
3339  : TimesCstIntVar(s, v, c) {}
3340 
3341 TimesPosCstBoolVar::~TimesPosCstBoolVar() {}
3342 
3343 int64 TimesPosCstBoolVar::Min() const {
3344  return (boolean_var()->RawValue() == 1) * cst_;
3345 }
3346 
3347 void TimesPosCstBoolVar::SetMin(int64 m) {
3348  if (m > cst_) {
3349  solver()->Fail();
3350  } else if (m > 0) {
3351  boolean_var()->SetMin(1);
3352  }
3353 }
3354 
3355 int64 TimesPosCstBoolVar::Max() const {
3356  return (boolean_var()->RawValue() != 0) * cst_;
3357 }
3358 
3359 void TimesPosCstBoolVar::SetMax(int64 m) {
3360  if (m < 0) {
3361  solver()->Fail();
3362  } else if (m < cst_) {
3363  boolean_var()->SetMax(0);
3364  }
3365 }
3366 
3367 void TimesPosCstBoolVar::SetRange(int64 l, int64 u) {
3368  if (u < 0 || l > cst_ || l > u) {
3369  solver()->Fail();
3370  }
3371  if (l > 0) {
3372  boolean_var()->SetMin(1);
3373  } else if (u < cst_) {
3374  boolean_var()->SetMax(0);
3375  }
3376 }
3377 
3378 void TimesPosCstBoolVar::SetValue(int64 v) {
3379  if (v == 0) {
3380  boolean_var()->SetValue(0);
3381  } else if (v == cst_) {
3382  boolean_var()->SetValue(1);
3383  } else {
3384  solver()->Fail();
3385  }
3386 }
3387 
3388 bool TimesPosCstBoolVar::Bound() const {
3389  return boolean_var()->RawValue() != BooleanVar::kUnboundBooleanVarValue;
3390 }
3391 
3392 void TimesPosCstBoolVar::WhenRange(Demon* d) { boolean_var()->WhenRange(d); }
3393 
3395  CHECK_NE(boolean_var()->RawValue(), BooleanVar::kUnboundBooleanVarValue)
3396  << " variable is not bound";
3397  return boolean_var()->RawValue() * cst_;
3398 }
3399 
3400 void TimesPosCstBoolVar::RemoveValue(int64 v) {
3401  if (v == 0) {
3402  boolean_var()->RemoveValue(0);
3403  } else if (v == cst_) {
3404  boolean_var()->RemoveValue(1);
3405  }
3406 }
3407 
3408 void TimesPosCstBoolVar::RemoveInterval(int64 l, int64 u) {
3409  if (l <= 0 && u >= 0) {
3410  boolean_var()->RemoveValue(0);
3411  }
3412  if (l <= cst_ && u >= cst_) {
3413  boolean_var()->RemoveValue(1);
3414  }
3415 }
3416 
3417 void TimesPosCstBoolVar::WhenBound(Demon* d) { boolean_var()->WhenBound(d); }
3418 
3419 void TimesPosCstBoolVar::WhenDomain(Demon* d) { boolean_var()->WhenDomain(d); }
3420 
3421 uint64 TimesPosCstBoolVar::Size() const {
3422  return (1 +
3423  (boolean_var()->RawValue() == BooleanVar::kUnboundBooleanVarValue));
3424 }
3425 
3426 bool TimesPosCstBoolVar::Contains(int64 v) const {
3427  if (v == 0) {
3428  return boolean_var()->RawValue() != 1;
3429  } else if (v == cst_) {
3430  return boolean_var()->RawValue() != 0;
3431  }
3432  return false;
3433 }
3434 
3435 // TimesNegCstIntVar
3436 
3437 class TimesNegCstIntVar : public TimesCstIntVar {
3438  public:
3439  class TimesNegCstIntVarIterator : public UnaryIterator {
3440  public:
3441  TimesNegCstIntVarIterator(const IntVar* const v, int64 c, bool hole,
3442  bool reversible)
3443  : UnaryIterator(v, hole, reversible), cst_(c) {}
3444  ~TimesNegCstIntVarIterator() override {}
3445 
3446  int64 Value() const override { return iterator_->Value() * cst_; }
3447 
3448  private:
3449  const int64 cst_;
3450  };
3451 
3452  TimesNegCstIntVar(Solver* const s, IntVar* v, int64 c);
3453  ~TimesNegCstIntVar() override;
3454 
3455  int64 Min() const override;
3456  void SetMin(int64 m) override;
3457  int64 Max() const override;
3458  void SetMax(int64 m) override;
3459  void SetRange(int64 l, int64 u) override;
3460  void SetValue(int64 v) override;
3461  bool Bound() const override;
3462  int64 Value() const override;
3463  void RemoveValue(int64 v) override;
3464  void RemoveInterval(int64 l, int64 u) override;
3465  uint64 Size() const override;
3466  bool Contains(int64 v) const override;
3467  void WhenRange(Demon* d) override;
3468  void WhenBound(Demon* d) override;
3469  void WhenDomain(Demon* d) override;
3470  IntVarIterator* MakeHoleIterator(bool reversible) const override {
3471  return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3472  var_, cst_, true, reversible));
3473  }
3474  IntVarIterator* MakeDomainIterator(bool reversible) const override {
3475  return COND_REV_ALLOC(reversible, new TimesNegCstIntVarIterator(
3476  var_, cst_, false, reversible));
3477  }
3478  int64 OldMin() const override { return CapProd(var_->OldMax(), cst_); }
3479  int64 OldMax() const override { return CapProd(var_->OldMin(), cst_); }
3480 };
3481 
3482 // ----- TimesNegCstIntVar -----
3483 
3484 TimesNegCstIntVar::TimesNegCstIntVar(Solver* const s, IntVar* v, int64 c)
3485  : TimesCstIntVar(s, v, c) {}
3486 
3487 TimesNegCstIntVar::~TimesNegCstIntVar() {}
3488 
3489 int64 TimesNegCstIntVar::Min() const { return CapProd(var_->Max(), cst_); }
3490 
3491 void TimesNegCstIntVar::SetMin(int64 m) {
3492  if (m != kint64min) {
3493  var_->SetMax(PosIntDivDown(-m, -cst_));
3494  }
3495 }
3496 
3497 int64 TimesNegCstIntVar::Max() const { return CapProd(var_->Min(), cst_); }
3498 
3499 void TimesNegCstIntVar::SetMax(int64 m) {
3500  if (m != kint64max) {
3501  var_->SetMin(PosIntDivUp(-m, -cst_));
3502  }
3503 }
3504 
3505 void TimesNegCstIntVar::SetRange(int64 l, int64 u) {
3506  var_->SetRange(PosIntDivUp(-u, -cst_), PosIntDivDown(-l, -cst_));
3507 }
3508 
3509 void TimesNegCstIntVar::SetValue(int64 v) {
3510  if (v % cst_ != 0) {
3511  solver()->Fail();
3512  }
3513  var_->SetValue(v / cst_);
3514 }
3515 
3516 bool TimesNegCstIntVar::Bound() const { return var_->Bound(); }
3517 
3518 void TimesNegCstIntVar::WhenRange(Demon* d) { var_->WhenRange(d); }
3519 
3520 int64 TimesNegCstIntVar::Value() const { return CapProd(var_->Value(), cst_); }
3521 
3522 void TimesNegCstIntVar::RemoveValue(int64 v) {
3523  if (v % cst_ == 0) {
3524  var_->RemoveValue(v / cst_);
3525  }
3526 }
3527 
3528 void TimesNegCstIntVar::RemoveInterval(int64 l, int64 u) {
3529  for (int64 v = l; v <= u; ++v) {
3530  RemoveValue(v);
3531  }
3532  // TODO(user) : Improve me
3533 }
3534 
3535 void TimesNegCstIntVar::WhenBound(Demon* d) { var_->WhenBound(d); }
3536 
3537 void TimesNegCstIntVar::WhenDomain(Demon* d) { var_->WhenDomain(d); }
3538 
3539 uint64 TimesNegCstIntVar::Size() const { return var_->Size(); }
3540 
3541 bool TimesNegCstIntVar::Contains(int64 v) const {
3542  return (v % cst_ == 0 && var_->Contains(v / cst_));
3543 }
3544 
3545 // ---------- arithmetic expressions ----------
3546 
3547 // ----- PlusIntExpr -----
3548 
3549 class PlusIntExpr : public BaseIntExpr {
3550  public:
3551  PlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3552  : BaseIntExpr(s), left_(l), right_(r) {}
3553 
3554  ~PlusIntExpr() override {}
3555 
3556  int64 Min() const override { return left_->Min() + right_->Min(); }
3557 
3558  void SetMin(int64 m) override {
3559  if (m > left_->Min() + right_->Min()) {
3560  left_->SetMin(m - right_->Max());
3561  right_->SetMin(m - left_->Max());
3562  }
3563  }
3564 
3565  void SetRange(int64 l, int64 u) override {
3566  const int64 left_min = left_->Min();
3567  const int64 right_min = right_->Min();
3568  const int64 left_max = left_->Max();
3569  const int64 right_max = right_->Max();
3570  if (l > left_min + right_min) {
3571  left_->SetMin(l - right_max);
3572  right_->SetMin(l - left_max);
3573  }
3574  if (u < left_max + right_max) {
3575  left_->SetMax(u - right_min);
3576  right_->SetMax(u - left_min);
3577  }
3578  }
3579 
3580  int64 Max() const override { return left_->Max() + right_->Max(); }
3581 
3582  void SetMax(int64 m) override {
3583  if (m < left_->Max() + right_->Max()) {
3584  left_->SetMax(m - right_->Min());
3585  right_->SetMax(m - left_->Min());
3586  }
3587  }
3588 
3589  bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3590 
3591  void Range(int64* const mi, int64* const ma) override {
3592  *mi = left_->Min() + right_->Min();
3593  *ma = left_->Max() + right_->Max();
3594  }
3595 
3596  std::string name() const override {
3597  return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3598  }
3599 
3600  std::string DebugString() const override {
3601  return absl::StrFormat("(%s + %s)", left_->DebugString(),
3602  right_->DebugString());
3603  }
3604 
3605  void WhenRange(Demon* d) override {
3606  left_->WhenRange(d);
3607  right_->WhenRange(d);
3608  }
3609 
3610  void ExpandPlusIntExpr(IntExpr* const expr, std::vector<IntExpr*>* subs) {
3611  PlusIntExpr* const casted = dynamic_cast<PlusIntExpr*>(expr);
3612  if (casted != nullptr) {
3613  ExpandPlusIntExpr(casted->left_, subs);
3614  ExpandPlusIntExpr(casted->right_, subs);
3615  } else {
3616  subs->push_back(expr);
3617  }
3618  }
3619 
3620  IntVar* CastToVar() override {
3621  if (dynamic_cast<PlusIntExpr*>(left_) != nullptr ||
3622  dynamic_cast<PlusIntExpr*>(right_) != nullptr) {
3623  std::vector<IntExpr*> sub_exprs;
3624  ExpandPlusIntExpr(left_, &sub_exprs);
3625  ExpandPlusIntExpr(right_, &sub_exprs);
3626  if (sub_exprs.size() >= 3) {
3627  std::vector<IntVar*> sub_vars(sub_exprs.size());
3628  for (int i = 0; i < sub_exprs.size(); ++i) {
3629  sub_vars[i] = sub_exprs[i]->Var();
3630  }
3631  return solver()->MakeSum(sub_vars)->Var();
3632  }
3633  }
3634  return BaseIntExpr::CastToVar();
3635  }
3636 
3637  void Accept(ModelVisitor* const visitor) const override {
3638  visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3639  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3640  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3641  right_);
3642  visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3643  }
3644 
3645  private:
3646  IntExpr* const left_;
3647  IntExpr* const right_;
3648 };
3649 
3650 class SafePlusIntExpr : public BaseIntExpr {
3651  public:
3652  SafePlusIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3653  : BaseIntExpr(s), left_(l), right_(r) {}
3654 
3655  ~SafePlusIntExpr() override {}
3656 
3657  int64 Min() const override { return CapAdd(left_->Min(), right_->Min()); }
3658 
3659  void SetMin(int64 m) override {
3660  left_->SetMin(CapSub(m, right_->Max()));
3661  right_->SetMin(CapSub(m, left_->Max()));
3662  }
3663 
3664  void SetRange(int64 l, int64 u) override {
3665  const int64 left_min = left_->Min();
3666  const int64 right_min = right_->Min();
3667  const int64 left_max = left_->Max();
3668  const int64 right_max = right_->Max();
3669  if (l > CapAdd(left_min, right_min)) {
3670  left_->SetMin(CapSub(l, right_max));
3671  right_->SetMin(CapSub(l, left_max));
3672  }
3673  if (u < CapAdd(left_max, right_max)) {
3674  left_->SetMax(CapSub(u, right_min));
3675  right_->SetMax(CapSub(u, left_min));
3676  }
3677  }
3678 
3679  int64 Max() const override { return CapAdd(left_->Max(), right_->Max()); }
3680 
3681  void SetMax(int64 m) override {
3682  left_->SetMax(CapSub(m, right_->Min()));
3683  right_->SetMax(CapSub(m, left_->Min()));
3684  }
3685 
3686  bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3687 
3688  std::string name() const override {
3689  return absl::StrFormat("(%s + %s)", left_->name(), right_->name());
3690  }
3691 
3692  std::string DebugString() const override {
3693  return absl::StrFormat("(%s + %s)", left_->DebugString(),
3694  right_->DebugString());
3695  }
3696 
3697  void WhenRange(Demon* d) override {
3698  left_->WhenRange(d);
3699  right_->WhenRange(d);
3700  }
3701 
3702  void Accept(ModelVisitor* const visitor) const override {
3703  visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3704  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3705  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3706  right_);
3707  visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3708  }
3709 
3710  private:
3711  IntExpr* const left_;
3712  IntExpr* const right_;
3713 };
3714 
3715 // ----- PlusIntCstExpr -----
3716 
3717 class PlusIntCstExpr : public BaseIntExpr {
3718  public:
3719  PlusIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
3720  : BaseIntExpr(s), expr_(e), value_(v) {}
3721  ~PlusIntCstExpr() override {}
3722  int64 Min() const override { return CapAdd(expr_->Min(), value_); }
3723  void SetMin(int64 m) override { expr_->SetMin(CapSub(m, value_)); }
3724  int64 Max() const override { return CapAdd(expr_->Max(), value_); }
3725  void SetMax(int64 m) override { expr_->SetMax(CapSub(m, value_)); }
3726  bool Bound() const override { return (expr_->Bound()); }
3727  std::string name() const override {
3728  return absl::StrFormat("(%s + %d)", expr_->name(), value_);
3729  }
3730  std::string DebugString() const override {
3731  return absl::StrFormat("(%s + %d)", expr_->DebugString(), value_);
3732  }
3733  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3734  IntVar* CastToVar() override;
3735  void Accept(ModelVisitor* const visitor) const override {
3736  visitor->BeginVisitIntegerExpression(ModelVisitor::kSum, this);
3737  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3738  expr_);
3739  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3740  visitor->EndVisitIntegerExpression(ModelVisitor::kSum, this);
3741  }
3742 
3743  private:
3744  IntExpr* const expr_;
3745  const int64 value_;
3746 };
3747 
3748 IntVar* PlusIntCstExpr::CastToVar() {
3749  Solver* const s = solver();
3750  IntVar* const var = expr_->Var();
3751  IntVar* cast = nullptr;
3752  if (AddOverflows(value_, expr_->Max()) ||
3753  AddOverflows(value_, expr_->Min())) {
3754  return BaseIntExpr::CastToVar();
3755  }
3756  switch (var->VarType()) {
3757  case DOMAIN_INT_VAR:
3758  cast = s->RegisterIntVar(s->RevAlloc(new PlusCstDomainIntVar(
3759  s, reinterpret_cast<DomainIntVar*>(var), value_)));
3760  // FIXME: Break was inserted during fallthrough cleanup. Please check.
3761  break;
3762  default:
3763  cast = s->RegisterIntVar(s->RevAlloc(new PlusCstIntVar(s, var, value_)));
3764  break;
3765  }
3766  return cast;
3767 }
3768 
3769 // ----- SubIntExpr -----
3770 
3771 class SubIntExpr : public BaseIntExpr {
3772  public:
3773  SubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3774  : BaseIntExpr(s), left_(l), right_(r) {}
3775 
3776  ~SubIntExpr() override {}
3777 
3778  int64 Min() const override { return left_->Min() - right_->Max(); }
3779 
3780  void SetMin(int64 m) override {
3781  left_->SetMin(CapAdd(m, right_->Min()));
3782  right_->SetMax(CapSub(left_->Max(), m));
3783  }
3784 
3785  int64 Max() const override { return left_->Max() - right_->Min(); }
3786 
3787  void SetMax(int64 m) override {
3788  left_->SetMax(CapAdd(m, right_->Max()));
3789  right_->SetMin(CapSub(left_->Min(), m));
3790  }
3791 
3792  void Range(int64* mi, int64* ma) override {
3793  *mi = left_->Min() - right_->Max();
3794  *ma = left_->Max() - right_->Min();
3795  }
3796 
3797  void SetRange(int64 l, int64 u) override {
3798  const int64 left_min = left_->Min();
3799  const int64 right_min = right_->Min();
3800  const int64 left_max = left_->Max();
3801  const int64 right_max = right_->Max();
3802  if (l > left_min - right_max) {
3803  left_->SetMin(CapAdd(l, right_min));
3804  right_->SetMax(CapSub(left_max, l));
3805  }
3806  if (u < left_max - right_min) {
3807  left_->SetMax(CapAdd(u, right_max));
3808  right_->SetMin(CapSub(left_min, u));
3809  }
3810  }
3811 
3812  bool Bound() const override { return (left_->Bound() && right_->Bound()); }
3813 
3814  std::string name() const override {
3815  return absl::StrFormat("(%s - %s)", left_->name(), right_->name());
3816  }
3817 
3818  std::string DebugString() const override {
3819  return absl::StrFormat("(%s - %s)", left_->DebugString(),
3820  right_->DebugString());
3821  }
3822 
3823  void WhenRange(Demon* d) override {
3824  left_->WhenRange(d);
3825  right_->WhenRange(d);
3826  }
3827 
3828  void Accept(ModelVisitor* const visitor) const override {
3829  visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3830  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
3831  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
3832  right_);
3833  visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3834  }
3835 
3836  IntExpr* left() const { return left_; }
3837  IntExpr* right() const { return right_; }
3838 
3839  protected:
3840  IntExpr* const left_;
3841  IntExpr* const right_;
3842 };
3843 
3844 class SafeSubIntExpr : public SubIntExpr {
3845  public:
3846  SafeSubIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
3847  : SubIntExpr(s, l, r) {}
3848 
3849  ~SafeSubIntExpr() override {}
3850 
3851  int64 Min() const override { return CapSub(left_->Min(), right_->Max()); }
3852 
3853  void SetMin(int64 m) override {
3854  left_->SetMin(CapAdd(m, right_->Min()));
3855  right_->SetMax(CapSub(left_->Max(), m));
3856  }
3857 
3858  void SetRange(int64 l, int64 u) override {
3859  const int64 left_min = left_->Min();
3860  const int64 right_min = right_->Min();
3861  const int64 left_max = left_->Max();
3862  const int64 right_max = right_->Max();
3863  if (l > CapSub(left_min, right_max)) {
3864  left_->SetMin(CapAdd(l, right_min));
3865  right_->SetMax(CapSub(left_max, l));
3866  }
3867  if (u < CapSub(left_max, right_min)) {
3868  left_->SetMax(CapAdd(u, right_max));
3869  right_->SetMin(CapSub(left_min, u));
3870  }
3871  }
3872 
3873  void Range(int64* mi, int64* ma) override {
3874  *mi = CapSub(left_->Min(), right_->Max());
3875  *ma = CapSub(left_->Max(), right_->Min());
3876  }
3877 
3878  int64 Max() const override { return CapSub(left_->Max(), right_->Min()); }
3879 
3880  void SetMax(int64 m) override {
3881  left_->SetMax(CapAdd(m, right_->Max()));
3882  right_->SetMin(CapSub(left_->Min(), m));
3883  }
3884 };
3885 
// value - expr: a constant minus an expression.
3887 
3888 // ----- SubIntCstExpr -----
3889 
3890 class SubIntCstExpr : public BaseIntExpr {
3891  public:
3892  SubIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
3893  : BaseIntExpr(s), expr_(e), value_(v) {}
3894  ~SubIntCstExpr() override {}
3895  int64 Min() const override { return CapSub(value_, expr_->Max()); }
3896  void SetMin(int64 m) override { expr_->SetMax(CapSub(value_, m)); }
3897  int64 Max() const override { return CapSub(value_, expr_->Min()); }
3898  void SetMax(int64 m) override { expr_->SetMin(CapSub(value_, m)); }
3899  bool Bound() const override { return (expr_->Bound()); }
3900  std::string name() const override {
3901  return absl::StrFormat("(%d - %s)", value_, expr_->name());
3902  }
3903  std::string DebugString() const override {
3904  return absl::StrFormat("(%d - %s)", value_, expr_->DebugString());
3905  }
3906  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3907  IntVar* CastToVar() override;
3908 
3909  void Accept(ModelVisitor* const visitor) const override {
3910  visitor->BeginVisitIntegerExpression(ModelVisitor::kDifference, this);
3911  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
3912  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3913  expr_);
3914  visitor->EndVisitIntegerExpression(ModelVisitor::kDifference, this);
3915  }
3916 
3917  private:
3918  IntExpr* const expr_;
3919  const int64 value_;
3920 };
3921 
3922 IntVar* SubIntCstExpr::CastToVar() {
3923  if (SubOverflows(value_, expr_->Min()) ||
3924  SubOverflows(value_, expr_->Max())) {
3925  return BaseIntExpr::CastToVar();
3926  }
3927  Solver* const s = solver();
3928  IntVar* const var =
3929  s->RegisterIntVar(s->RevAlloc(new SubCstIntVar(s, expr_->Var(), value_)));
3930  return var;
3931 }
3932 
3933 // ----- OppIntExpr -----
3934 
3935 class OppIntExpr : public BaseIntExpr {
3936  public:
3937  OppIntExpr(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
3938  ~OppIntExpr() override {}
3939  int64 Min() const override { return (-expr_->Max()); }
3940  void SetMin(int64 m) override { expr_->SetMax(-m); }
3941  int64 Max() const override { return (-expr_->Min()); }
3942  void SetMax(int64 m) override { expr_->SetMin(-m); }
3943  bool Bound() const override { return (expr_->Bound()); }
3944  std::string name() const override {
3945  return absl::StrFormat("(-%s)", expr_->name());
3946  }
3947  std::string DebugString() const override {
3948  return absl::StrFormat("(-%s)", expr_->DebugString());
3949  }
3950  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3951  IntVar* CastToVar() override;
3952 
3953  void Accept(ModelVisitor* const visitor) const override {
3954  visitor->BeginVisitIntegerExpression(ModelVisitor::kOpposite, this);
3955  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3956  expr_);
3957  visitor->EndVisitIntegerExpression(ModelVisitor::kOpposite, this);
3958  }
3959 
3960  private:
3961  IntExpr* const expr_;
3962 };
3963 
3964 IntVar* OppIntExpr::CastToVar() {
3965  Solver* const s = solver();
3966  IntVar* const var =
3967  s->RegisterIntVar(s->RevAlloc(new OppIntVar(s, expr_->Var())));
3968  return var;
3969 }
3970 
3971 // ----- TimesIntCstExpr -----
3972 
3973 class TimesIntCstExpr : public BaseIntExpr {
3974  public:
3975  TimesIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
3976  : BaseIntExpr(s), expr_(e), value_(v) {}
3977 
3978  ~TimesIntCstExpr() override {}
3979 
3980  bool Bound() const override { return (expr_->Bound()); }
3981 
3982  std::string name() const override {
3983  return absl::StrFormat("(%s * %d)", expr_->name(), value_);
3984  }
3985 
3986  std::string DebugString() const override {
3987  return absl::StrFormat("(%s * %d)", expr_->DebugString(), value_);
3988  }
3989 
3990  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
3991 
3992  IntExpr* Expr() const { return expr_; }
3993 
3994  int64 Constant() const { return value_; }
3995 
3996  void Accept(ModelVisitor* const visitor) const override {
3997  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
3998  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
3999  expr_);
4000  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4001  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4002  }
4003 
4004  protected:
4005  IntExpr* const expr_;
4006  const int64 value_;
4007 };
4008 
4009 // ----- TimesPosIntCstExpr -----
4010 
4011 class TimesPosIntCstExpr : public TimesIntCstExpr {
4012  public:
4013  TimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
4014  : TimesIntCstExpr(s, e, v) {
4015  CHECK_GT(v, 0);
4016  }
4017 
4018  ~TimesPosIntCstExpr() override {}
4019 
4020  int64 Min() const override { return expr_->Min() * value_; }
4021 
4022  void SetMin(int64 m) override { expr_->SetMin(PosIntDivUp(m, value_)); }
4023 
4024  int64 Max() const override { return expr_->Max() * value_; }
4025 
4026  void SetMax(int64 m) override { expr_->SetMax(PosIntDivDown(m, value_)); }
4027 
4028  IntVar* CastToVar() override {
4029  Solver* const s = solver();
4030  IntVar* var = nullptr;
4031  if (expr_->IsVar() &&
4032  reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4033  var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4034  s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4035  } else {
4036  var = s->RegisterIntVar(
4037  s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4038  }
4039  return var;
4040  }
4041 };
4042 
4043 // This expressions adds safe arithmetic (w.r.t. overflows) compared
4044 // to the previous one.
4045 class SafeTimesPosIntCstExpr : public TimesIntCstExpr {
4046  public:
4047  SafeTimesPosIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
4048  : TimesIntCstExpr(s, e, v) {
4049  CHECK_GT(v, 0);
4050  }
4051 
4052  ~SafeTimesPosIntCstExpr() override {}
4053 
4054  int64 Min() const override { return CapProd(expr_->Min(), value_); }
4055 
4056  void SetMin(int64 m) override {
4057  if (m != kint64min) {
4058  expr_->SetMin(PosIntDivUp(m, value_));
4059  }
4060  }
4061 
4062  int64 Max() const override { return CapProd(expr_->Max(), value_); }
4063 
4064  void SetMax(int64 m) override {
4065  if (m != kint64max) {
4066  expr_->SetMax(PosIntDivDown(m, value_));
4067  }
4068  }
4069 
4070  IntVar* CastToVar() override {
4071  Solver* const s = solver();
4072  IntVar* var = nullptr;
4073  if (expr_->IsVar() &&
4074  reinterpret_cast<IntVar*>(expr_)->VarType() == BOOLEAN_VAR) {
4075  var = s->RegisterIntVar(s->RevAlloc(new TimesPosCstBoolVar(
4076  s, reinterpret_cast<BooleanVar*>(expr_), value_)));
4077  } else {
4078  // TODO(user): Check overflows.
4079  var = s->RegisterIntVar(
4080  s->RevAlloc(new TimesPosCstIntVar(s, expr_->Var(), value_)));
4081  }
4082  return var;
4083  }
4084 };
4085 
4086 // ----- TimesIntNegCstExpr -----
4087 
4088 class TimesIntNegCstExpr : public TimesIntCstExpr {
4089  public:
4090  TimesIntNegCstExpr(Solver* const s, IntExpr* const e, int64 v)
4091  : TimesIntCstExpr(s, e, v) {
4092  CHECK_LT(v, 0);
4093  }
4094 
4095  ~TimesIntNegCstExpr() override {}
4096 
4097  int64 Min() const override { return CapProd(expr_->Max(), value_); }
4098 
4099  void SetMin(int64 m) override {
4100  if (m != kint64min) {
4101  expr_->SetMax(PosIntDivDown(-m, -value_));
4102  }
4103  }
4104 
4105  int64 Max() const override { return CapProd(expr_->Min(), value_); }
4106 
4107  void SetMax(int64 m) override {
4108  if (m != kint64max) {
4109  expr_->SetMin(PosIntDivUp(-m, -value_));
4110  }
4111  }
4112 
4113  IntVar* CastToVar() override {
4114  Solver* const s = solver();
4115  IntVar* var = nullptr;
4116  var = s->RegisterIntVar(
4117  s->RevAlloc(new TimesNegCstIntVar(s, expr_->Var(), value_)));
4118  return var;
4119  }
4120 };
4121 
4122 // ----- Utilities for product expression -----
4123 
4124 // Propagates set_min on left * right, left and right >= 0.
4125 void SetPosPosMinExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4126  DCHECK_GE(left->Min(), 0);
4127  DCHECK_GE(right->Min(), 0);
4128  const int64 lmax = left->Max();
4129  const int64 rmax = right->Max();
4130  if (m > CapProd(lmax, rmax)) {
4131  left->solver()->Fail();
4132  }
4133  if (m > CapProd(left->Min(), right->Min())) {
4134  // Ok for m == 0 due to left and right being positive
4135  if (0 != rmax) {
4136  left->SetMin(PosIntDivUp(m, rmax));
4137  }
4138  if (0 != lmax) {
4139  right->SetMin(PosIntDivUp(m, lmax));
4140  }
4141  }
4142 }
4143 
4144 // Propagates set_max on left * right, left and right >= 0.
4145 void SetPosPosMaxExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4146  DCHECK_GE(left->Min(), 0);
4147  DCHECK_GE(right->Min(), 0);
4148  const int64 lmin = left->Min();
4149  const int64 rmin = right->Min();
4150  if (m < CapProd(lmin, rmin)) {
4151  left->solver()->Fail();
4152  }
4153  if (m < CapProd(left->Max(), right->Max())) {
4154  if (0 != lmin) {
4155  right->SetMax(PosIntDivDown(m, lmin));
4156  }
4157  if (0 != rmin) {
4158  left->SetMax(PosIntDivDown(m, rmin));
4159  }
4160  // else do nothing: 0 is supporting any value from other expr.
4161  }
4162 }
4163 
4164 // Propagates set_min on left * right, left >= 0, right across 0.
4165 void SetPosGenMinExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4166  DCHECK_GE(left->Min(), 0);
4167  DCHECK_GT(right->Max(), 0);
4168  DCHECK_LT(right->Min(), 0);
4169  const int64 lmax = left->Max();
4170  const int64 rmax = right->Max();
4171  if (m > CapProd(lmax, rmax)) {
4172  left->solver()->Fail();
4173  }
4174  if (left->Max() == 0) { // left is bound to 0, product is bound to 0.
4175  DCHECK_EQ(0, left->Min());
4176  DCHECK_LE(m, 0);
4177  } else {
4178  if (m > 0) { // We deduce right > 0.
4179  left->SetMin(PosIntDivUp(m, rmax));
4180  right->SetMin(PosIntDivUp(m, lmax));
4181  } else if (m == 0) {
4182  const int64 lmin = left->Min();
4183  if (lmin > 0) {
4184  right->SetMin(0);
4185  }
4186  } else { // m < 0
4187  const int64 lmin = left->Min();
4188  if (0 != lmin) { // We cannot deduce anything if 0 is in the domain.
4189  right->SetMin(-PosIntDivDown(-m, lmin));
4190  }
4191  }
4192  }
4193 }
4194 
4195 // Propagates set_min on left * right, left and right across 0.
4196 void SetGenGenMinExpr(IntExpr* const left, IntExpr* const right, int64 m) {
4197  DCHECK_LT(left->Min(), 0);
4198  DCHECK_GT(left->Max(), 0);
4199  DCHECK_GT(right->Max(), 0);
4200  DCHECK_LT(right->Min(), 0);
4201  const int64 lmin = left->Min();
4202  const int64 lmax = left->Max();
4203  const int64 rmin = right->Min();
4204  const int64 rmax = right->Max();
4205  if (m > std::max(CapProd(lmin, rmin), CapProd(lmax, rmax))) {
4206  left->solver()->Fail();
4207  }
4208  if (m > lmin * rmin) { // Must be positive section * positive section.
4209  left->SetMin(PosIntDivUp(m, rmax));
4210  right->SetMin(PosIntDivUp(m, lmax));
4211  } else if (m > CapProd(lmax, rmax)) { // Negative section * negative section.
4212  left->SetMax(-PosIntDivUp(m, -rmin));
4213  right->SetMax(-PosIntDivUp(m, -lmin));
4214  }
4215 }
4216 
4217 void TimesSetMin(IntExpr* const left, IntExpr* const right,
4218  IntExpr* const minus_left, IntExpr* const minus_right,
4219  int64 m) {
4220  if (left->Min() >= 0) {
4221  if (right->Min() >= 0) {
4222  SetPosPosMinExpr(left, right, m);
4223  } else if (right->Max() <= 0) {
4224  SetPosPosMaxExpr(left, minus_right, -m);
4225  } else { // right->Min() < 0 && right->Max() > 0
4226  SetPosGenMinExpr(left, right, m);
4227  }
4228  } else if (left->Max() <= 0) {
4229  if (right->Min() >= 0) {
4230  SetPosPosMaxExpr(right, minus_left, -m);
4231  } else if (right->Max() <= 0) {
4232  SetPosPosMinExpr(minus_left, minus_right, m);
4233  } else { // right->Min() < 0 && right->Max() > 0
4234  SetPosGenMinExpr(minus_left, minus_right, m);
4235  }
4236  } else if (right->Min() >= 0) { // left->Min() < 0 && left->Max() > 0
4237  SetPosGenMinExpr(right, left, m);
4238  } else if (right->Max() <= 0) { // left->Min() < 0 && left->Max() > 0
4239  SetPosGenMinExpr(minus_right, minus_left, m);
4240  } else { // left->Min() < 0 && left->Max() > 0 &&
4241  // right->Min() < 0 && right->Max() > 0
4242  SetGenGenMinExpr(left, right, m);
4243  }
4244 }
4245 
4246 class TimesIntExpr : public BaseIntExpr {
4247  public:
4248  TimesIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4249  : BaseIntExpr(s),
4250  left_(l),
4251  right_(r),
4252  minus_left_(s->MakeOpposite(left_)),
4253  minus_right_(s->MakeOpposite(right_)) {}
4254  ~TimesIntExpr() override {}
4255  int64 Min() const override {
4256  const int64 lmin = left_->Min();
4257  const int64 lmax = left_->Max();
4258  const int64 rmin = right_->Min();
4259  const int64 rmax = right_->Max();
4260  return std::min(std::min(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4261  std::min(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4262  }
4263  void SetMin(int64 m) override;
4264  int64 Max() const override {
4265  const int64 lmin = left_->Min();
4266  const int64 lmax = left_->Max();
4267  const int64 rmin = right_->Min();
4268  const int64 rmax = right_->Max();
4269  return std::max(std::max(CapProd(lmin, rmin), CapProd(lmax, rmax)),
4270  std::max(CapProd(lmax, rmin), CapProd(lmin, rmax)));
4271  }
4272  void SetMax(int64 m) override;
4273  bool Bound() const override;
4274  std::string name() const override {
4275  return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4276  }
4277  std::string DebugString() const override {
4278  return absl::StrFormat("(%s * %s)", left_->DebugString(),
4279  right_->DebugString());
4280  }
4281  void WhenRange(Demon* d) override {
4282  left_->WhenRange(d);
4283  right_->WhenRange(d);
4284  }
4285 
4286  void Accept(ModelVisitor* const visitor) const override {
4287  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4288  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4289  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4290  right_);
4291  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4292  }
4293 
4294  private:
4295  IntExpr* const left_;
4296  IntExpr* const right_;
4297  IntExpr* const minus_left_;
4298  IntExpr* const minus_right_;
4299 };
4300 
4301 void TimesIntExpr::SetMin(int64 m) {
4302  if (m != kint64min) {
4303  TimesSetMin(left_, right_, minus_left_, minus_right_, m);
4304  }
4305 }
4306 
4307 void TimesIntExpr::SetMax(int64 m) {
4308  if (m != kint64max) {
4309  TimesSetMin(left_, minus_right_, minus_left_, right_, -m);
4310  }
4311 }
4312 
4313 bool TimesIntExpr::Bound() const {
4314  const bool left_bound = left_->Bound();
4315  const bool right_bound = right_->Bound();
4316  return ((left_bound && left_->Max() == 0) ||
4317  (right_bound && right_->Max() == 0) || (left_bound && right_bound));
4318 }
4319 
4320 // ----- TimesPosIntExpr -----
4321 
4322 class TimesPosIntExpr : public BaseIntExpr {
4323  public:
4324  TimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4325  : BaseIntExpr(s), left_(l), right_(r) {}
4326  ~TimesPosIntExpr() override {}
4327  int64 Min() const override { return (left_->Min() * right_->Min()); }
4328  void SetMin(int64 m) override;
4329  int64 Max() const override { return (left_->Max() * right_->Max()); }
4330  void SetMax(int64 m) override;
4331  bool Bound() const override;
4332  std::string name() const override {
4333  return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4334  }
4335  std::string DebugString() const override {
4336  return absl::StrFormat("(%s * %s)", left_->DebugString(),
4337  right_->DebugString());
4338  }
4339  void WhenRange(Demon* d) override {
4340  left_->WhenRange(d);
4341  right_->WhenRange(d);
4342  }
4343 
4344  void Accept(ModelVisitor* const visitor) const override {
4345  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4346  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4347  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4348  right_);
4349  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4350  }
4351 
4352  private:
4353  IntExpr* const left_;
4354  IntExpr* const right_;
4355 };
4356 
4357 void TimesPosIntExpr::SetMin(int64 m) { SetPosPosMinExpr(left_, right_, m); }
4358 
4359 void TimesPosIntExpr::SetMax(int64 m) { SetPosPosMaxExpr(left_, right_, m); }
4360 
4361 bool TimesPosIntExpr::Bound() const {
4362  return (left_->Max() == 0 || right_->Max() == 0 ||
4363  (left_->Bound() && right_->Bound()));
4364 }
4365 
4366 // ----- SafeTimesPosIntExpr -----
4367 
4368 class SafeTimesPosIntExpr : public BaseIntExpr {
4369  public:
4370  SafeTimesPosIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
4371  : BaseIntExpr(s), left_(l), right_(r) {}
4372  ~SafeTimesPosIntExpr() override {}
4373  int64 Min() const override { return CapProd(left_->Min(), right_->Min()); }
4374  void SetMin(int64 m) override {
4375  if (m != kint64min) {
4376  SetPosPosMinExpr(left_, right_, m);
4377  }
4378  }
4379  int64 Max() const override { return CapProd(left_->Max(), right_->Max()); }
4380  void SetMax(int64 m) override {
4381  if (m != kint64max) {
4382  SetPosPosMaxExpr(left_, right_, m);
4383  }
4384  }
4385  bool Bound() const override {
4386  return (left_->Max() == 0 || right_->Max() == 0 ||
4387  (left_->Bound() && right_->Bound()));
4388  }
4389  std::string name() const override {
4390  return absl::StrFormat("(%s * %s)", left_->name(), right_->name());
4391  }
4392  std::string DebugString() const override {
4393  return absl::StrFormat("(%s * %s)", left_->DebugString(),
4394  right_->DebugString());
4395  }
4396  void WhenRange(Demon* d) override {
4397  left_->WhenRange(d);
4398  right_->WhenRange(d);
4399  }
4400 
4401  void Accept(ModelVisitor* const visitor) const override {
4402  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4403  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
4404  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4405  right_);
4406  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4407  }
4408 
4409  private:
4410  IntExpr* const left_;
4411  IntExpr* const right_;
4412 };
4413 
4414 // ----- TimesBooleanPosIntExpr -----
4415 
4416 class TimesBooleanPosIntExpr : public BaseIntExpr {
4417  public:
4418  TimesBooleanPosIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4419  : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4420  ~TimesBooleanPosIntExpr() override {}
4421  int64 Min() const override {
4422  return (boolvar_->RawValue() == 1 ? expr_->Min() : 0);
4423  }
4424  void SetMin(int64 m) override;
4425  int64 Max() const override {
4426  return (boolvar_->RawValue() == 0 ? 0 : expr_->Max());
4427  }
4428  void SetMax(int64 m) override;
4429  void Range(int64* mi, int64* ma) override;
4430  void SetRange(int64 mi, int64 ma) override;
4431  bool Bound() const override;
4432  std::string name() const override {
4433  return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4434  }
4435  std::string DebugString() const override {
4436  return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4437  expr_->DebugString());
4438  }
4439  void WhenRange(Demon* d) override {
4440  boolvar_->WhenRange(d);
4441  expr_->WhenRange(d);
4442  }
4443 
4444  void Accept(ModelVisitor* const visitor) const override {
4445  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4446  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4447  boolvar_);
4448  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4449  expr_);
4450  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4451  }
4452 
4453  private:
4454  BooleanVar* const boolvar_;
4455  IntExpr* const expr_;
4456 };
4457 
4458 void TimesBooleanPosIntExpr::SetMin(int64 m) {
4459  if (m > 0) {
4460  boolvar_->SetValue(1);
4461  expr_->SetMin(m);
4462  }
4463 }
4464 
4465 void TimesBooleanPosIntExpr::SetMax(int64 m) {
4466  if (m < 0) {
4467  solver()->Fail();
4468  }
4469  if (m < expr_->Min()) {
4470  boolvar_->SetValue(0);
4471  }
4472  if (boolvar_->RawValue() == 1) {
4473  expr_->SetMax(m);
4474  }
4475 }
4476 
4477 void TimesBooleanPosIntExpr::Range(int64* mi, int64* ma) {
4478  const int value = boolvar_->RawValue();
4479  if (value == 0) {
4480  *mi = 0;
4481  *ma = 0;
4482  } else if (value == 1) {
4483  expr_->Range(mi, ma);
4484  } else {
4485  *mi = 0;
4486  *ma = expr_->Max();
4487  }
4488 }
4489 
4490 void TimesBooleanPosIntExpr::SetRange(int64 mi, int64 ma) {
4491  if (ma < 0 || mi > ma) {
4492  solver()->Fail();
4493  }
4494  if (mi > 0) {
4495  boolvar_->SetValue(1);
4496  expr_->SetMin(mi);
4497  }
4498  if (ma < expr_->Min()) {
4499  boolvar_->SetValue(0);
4500  }
4501  if (boolvar_->RawValue() == 1) {
4502  expr_->SetMax(ma);
4503  }
4504 }
4505 
4506 bool TimesBooleanPosIntExpr::Bound() const {
4507  return (boolvar_->RawValue() == 0 || expr_->Max() == 0 ||
4508  (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue &&
4509  expr_->Bound()));
4510 }
4511 
4512 // ----- TimesBooleanIntExpr -----
4513 
4514 class TimesBooleanIntExpr : public BaseIntExpr {
4515  public:
4516  TimesBooleanIntExpr(Solver* const s, BooleanVar* const b, IntExpr* const e)
4517  : BaseIntExpr(s), boolvar_(b), expr_(e) {}
4518  ~TimesBooleanIntExpr() override {}
4519  int64 Min() const override {
4520  switch (boolvar_->RawValue()) {
4521  case 0: {
4522  return 0LL;
4523  }
4524  case 1: {
4525  return expr_->Min();
4526  }
4527  default: {
4528  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4529  return std::min(int64{0}, expr_->Min());
4530  }
4531  }
4532  }
4533  void SetMin(int64 m) override;
4534  int64 Max() const override {
4535  switch (boolvar_->RawValue()) {
4536  case 0: {
4537  return 0LL;
4538  }
4539  case 1: {
4540  return expr_->Max();
4541  }
4542  default: {
4543  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4544  return std::max(int64{0}, expr_->Max());
4545  }
4546  }
4547  }
4548  void SetMax(int64 m) override;
4549  void Range(int64* mi, int64* ma) override;
4550  void SetRange(int64 mi, int64 ma) override;
4551  bool Bound() const override;
4552  std::string name() const override {
4553  return absl::StrFormat("(%s * %s)", boolvar_->name(), expr_->name());
4554  }
4555  std::string DebugString() const override {
4556  return absl::StrFormat("(%s * %s)", boolvar_->DebugString(),
4557  expr_->DebugString());
4558  }
4559  void WhenRange(Demon* d) override {
4560  boolvar_->WhenRange(d);
4561  expr_->WhenRange(d);
4562  }
4563 
4564  void Accept(ModelVisitor* const visitor) const override {
4565  visitor->BeginVisitIntegerExpression(ModelVisitor::kProduct, this);
4566  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument,
4567  boolvar_);
4568  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4569  expr_);
4570  visitor->EndVisitIntegerExpression(ModelVisitor::kProduct, this);
4571  }
4572 
4573  private:
4574  BooleanVar* const boolvar_;
4575  IntExpr* const expr_;
4576 };
4577 
4578 void TimesBooleanIntExpr::SetMin(int64 m) {
4579  switch (boolvar_->RawValue()) {
4580  case 0: {
4581  if (m > 0) {
4582  solver()->Fail();
4583  }
4584  break;
4585  }
4586  case 1: {
4587  expr_->SetMin(m);
4588  break;
4589  }
4590  default: {
4591  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4592  if (m > 0) { // 0 is no longer possible for boolvar because min > 0.
4593  boolvar_->SetValue(1);
4594  expr_->SetMin(m);
4595  } else if (m <= 0 && expr_->Max() < m) {
4596  boolvar_->SetValue(0);
4597  }
4598  }
4599  }
4600 }
4601 
4602 void TimesBooleanIntExpr::SetMax(int64 m) {
4603  switch (boolvar_->RawValue()) {
4604  case 0: {
4605  if (m < 0) {
4606  solver()->Fail();
4607  }
4608  break;
4609  }
4610  case 1: {
4611  expr_->SetMax(m);
4612  break;
4613  }
4614  default: {
4615  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4616  if (m < 0) { // 0 is no longer possible for boolvar because max < 0.
4617  boolvar_->SetValue(1);
4618  expr_->SetMax(m);
4619  } else if (m >= 0 && expr_->Min() > m) {
4620  boolvar_->SetValue(0);
4621  }
4622  }
4623  }
4624 }
4625 
4626 void TimesBooleanIntExpr::Range(int64* mi, int64* ma) {
4627  switch (boolvar_->RawValue()) {
4628  case 0: {
4629  *mi = 0;
4630  *ma = 0;
4631  break;
4632  }
4633  case 1: {
4634  *mi = expr_->Min();
4635  *ma = expr_->Max();
4636  break;
4637  }
4638  default: {
4639  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4640  *mi = std::min(int64{0}, expr_->Min());
4641  *ma = std::max(int64{0}, expr_->Max());
4642  break;
4643  }
4644  }
4645 }
4646 
4647 void TimesBooleanIntExpr::SetRange(int64 mi, int64 ma) {
4648  if (mi > ma) {
4649  solver()->Fail();
4650  }
4651  switch (boolvar_->RawValue()) {
4652  case 0: {
4653  if (mi > 0 || ma < 0) {
4654  solver()->Fail();
4655  }
4656  break;
4657  }
4658  case 1: {
4659  expr_->SetRange(mi, ma);
4660  break;
4661  }
4662  default: {
4663  DCHECK_EQ(BooleanVar::kUnboundBooleanVarValue, boolvar_->RawValue());
4664  if (mi > 0) {
4665  boolvar_->SetValue(1);
4666  expr_->SetMin(mi);
4667  } else if (mi == 0 && expr_->Max() < 0) {
4668  boolvar_->SetValue(0);
4669  }
4670  if (ma < 0) {
4671  boolvar_->SetValue(1);
4672  expr_->SetMax(ma);
4673  } else if (ma == 0 && expr_->Min() > 0) {
4674  boolvar_->SetValue(0);
4675  }
4676  break;
4677  }
4678  }
4679 }
4680 
4681 bool TimesBooleanIntExpr::Bound() const {
4682  return (boolvar_->RawValue() == 0 ||
4683  (expr_->Bound() &&
4684  (boolvar_->RawValue() != BooleanVar::kUnboundBooleanVarValue ||
4685  expr_->Max() == 0)));
4686 }
4687 
4688 // ----- DivPosIntCstExpr -----
4689 
4690 class DivPosIntCstExpr : public BaseIntExpr {
4691  public:
4692  DivPosIntCstExpr(Solver* const s, IntExpr* const e, int64 v)
4693  : BaseIntExpr(s), expr_(e), value_(v) {
4694  CHECK_GE(v, 0);
4695  }
4696  ~DivPosIntCstExpr() override {}
4697 
4698  int64 Min() const override { return expr_->Min() / value_; }
4699 
4700  void SetMin(int64 m) override {
4701  if (m > 0) {
4702  expr_->SetMin(m * value_);
4703  } else {
4704  expr_->SetMin((m - 1) * value_ + 1);
4705  }
4706  }
4707  int64 Max() const override { return expr_->Max() / value_; }
4708 
4709  void SetMax(int64 m) override {
4710  if (m >= 0) {
4711  expr_->SetMax((m + 1) * value_ - 1);
4712  } else {
4713  expr_->SetMax(m * value_);
4714  }
4715  }
4716 
4717  std::string name() const override {
4718  return absl::StrFormat("(%s div %d)", expr_->name(), value_);
4719  }
4720 
4721  std::string DebugString() const override {
4722  return absl::StrFormat("(%s div %d)", expr_->DebugString(), value_);
4723  }
4724 
4725  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
4726 
4727  void Accept(ModelVisitor* const visitor) const override {
4728  visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4729  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
4730  expr_);
4731  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
4732  visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4733  }
4734 
4735  private:
4736  IntExpr* const expr_;
4737  const int64 value_;
4738 };
4739 
4740 // DivPosIntExpr
4741 
4742 class DivPosIntExpr : public BaseIntExpr {
4743  public:
4744  DivPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4745  : BaseIntExpr(s),
4746  num_(num),
4747  denom_(denom),
4748  opp_num_(s->MakeOpposite(num)) {}
4749 
4750  ~DivPosIntExpr() override {}
4751 
4752  int64 Min() const override {
4753  return num_->Min() >= 0
4754  ? num_->Min() / denom_->Max()
4755  : (denom_->Min() == 0 ? num_->Min()
4756  : num_->Min() / denom_->Min());
4757  }
4758 
4759  int64 Max() const override {
4760  return num_->Max() >= 0 ? (denom_->Min() == 0 ? num_->Max()
4761  : num_->Max() / denom_->Min())
4762  : num_->Max() / denom_->Max();
4763  }
4764 
4765  static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64 m) {
4766  num->SetMin(m * denom->Min());
4767  denom->SetMax(num->Max() / m);
4768  }
4769 
4770  static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64 m) {
4771  num->SetMax((m + 1) * denom->Max() - 1);
4772  denom->SetMin(num->Min() / (m + 1) + 1);
4773  }
4774 
4775  void SetMin(int64 m) override {
4776  if (m > 0) {
4777  SetPosMin(num_, denom_, m);
4778  } else {
4779  SetPosMax(opp_num_, denom_, -m);
4780  }
4781  }
4782 
4783  void SetMax(int64 m) override {
4784  if (m >= 0) {
4785  SetPosMax(num_, denom_, m);
4786  } else {
4787  SetPosMin(opp_num_, denom_, -m);
4788  }
4789  }
4790 
4791  std::string name() const override {
4792  return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4793  }
4794  std::string DebugString() const override {
4795  return absl::StrFormat("(%s div %s)", num_->DebugString(),
4796  denom_->DebugString());
4797  }
4798  void WhenRange(Demon* d) override {
4799  num_->WhenRange(d);
4800  denom_->WhenRange(d);
4801  }
4802 
4803  void Accept(ModelVisitor* const visitor) const override {
4804  visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4805  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4806  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4807  denom_);
4808  visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4809  }
4810 
4811  private:
4812  IntExpr* const num_;
4813  IntExpr* const denom_;
4814  IntExpr* const opp_num_;
4815 };
4816 
4817 class DivPosPosIntExpr : public BaseIntExpr {
4818  public:
4819  DivPosPosIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4820  : BaseIntExpr(s), num_(num), denom_(denom) {}
4821 
4822  ~DivPosPosIntExpr() override {}
4823 
4824  int64 Min() const override {
4825  if (denom_->Max() == 0) {
4826  solver()->Fail();
4827  }
4828  return num_->Min() / denom_->Max();
4829  }
4830 
4831  int64 Max() const override {
4832  if (denom_->Min() == 0) {
4833  return num_->Max();
4834  } else {
4835  return num_->Max() / denom_->Min();
4836  }
4837  }
4838 
4839  void SetMin(int64 m) override {
4840  if (m > 0) {
4841  num_->SetMin(m * denom_->Min());
4842  denom_->SetMax(num_->Max() / m);
4843  }
4844  }
4845 
4846  void SetMax(int64 m) override {
4847  if (m >= 0) {
4848  num_->SetMax((m + 1) * denom_->Max() - 1);
4849  denom_->SetMin(num_->Min() / (m + 1) + 1);
4850  } else {
4851  solver()->Fail();
4852  }
4853  }
4854 
4855  std::string name() const override {
4856  return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
4857  }
4858 
4859  std::string DebugString() const override {
4860  return absl::StrFormat("(%s div %s)", num_->DebugString(),
4861  denom_->DebugString());
4862  }
4863 
4864  void WhenRange(Demon* d) override {
4865  num_->WhenRange(d);
4866  denom_->WhenRange(d);
4867  }
4868 
4869  void Accept(ModelVisitor* const visitor) const override {
4870  visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
4871  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
4872  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
4873  denom_);
4874  visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
4875  }
4876 
4877  private:
4878  IntExpr* const num_;
4879  IntExpr* const denom_;
4880 };
4881 
4882 // DivIntExpr
4883 
4884 class DivIntExpr : public BaseIntExpr {
4885  public:
4886  DivIntExpr(Solver* const s, IntExpr* const num, IntExpr* const denom)
4887  : BaseIntExpr(s),
4888  num_(num),
4889  denom_(denom),
4890  opp_num_(s->MakeOpposite(num)) {}
4891 
4892  ~DivIntExpr() override {}
4893 
4894  int64 Min() const override {
4895  const int64 num_min = num_->Min();
4896  const int64 num_max = num_->Max();
4897  const int64 denom_min = denom_->Min();
4898  const int64 denom_max = denom_->Max();
4899 
4900  if (denom_min == 0 && denom_max == 0) {
4901  return kint64max; // TODO(user): Check this convention.
4902  }
4903 
4904  if (denom_min >= 0) { // Denominator strictly positive.
4905  DCHECK_GT(denom_max, 0);
4906  const int64 adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4907  return num_min >= 0 ? num_min / denom_max : num_min / adjusted_denom_min;
4908  } else if (denom_max <= 0) { // Denominator strictly negative.
4909  DCHECK_LT(denom_min, 0);
4910  const int64 adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4911  return num_max >= 0 ? num_max / adjusted_denom_max : num_max / denom_min;
4912  } else { // Denominator across 0.
4913  return std::min(num_min, -num_max);
4914  }
4915  }
4916 
4917  int64 Max() const override {
4918  const int64 num_min = num_->Min();
4919  const int64 num_max = num_->Max();
4920  const int64 denom_min = denom_->Min();
4921  const int64 denom_max = denom_->Max();
4922 
4923  if (denom_min == 0 && denom_max == 0) {
4924  return kint64min; // TODO(user): Check this convention.
4925  }
4926 
4927  if (denom_min >= 0) { // Denominator strictly positive.
4928  DCHECK_GT(denom_max, 0);
4929  const int64 adjusted_denom_min = denom_min == 0 ? 1 : denom_min;
4930  return num_max >= 0 ? num_max / adjusted_denom_min : num_max / denom_max;
4931  } else if (denom_max <= 0) { // Denominator strictly negative.
4932  DCHECK_LT(denom_min, 0);
4933  const int64 adjusted_denom_max = denom_max == 0 ? -1 : denom_max;
4934  return num_min >= 0 ? num_min / denom_min
4935  : -num_min / -adjusted_denom_max;
4936  } else { // Denominator across 0.
4937  return std::max(num_max, -num_min);
4938  }
4939  }
4940 
4941  void AdjustDenominator() {
4942  if (denom_->Min() == 0) {
4943  denom_->SetMin(1);
4944  } else if (denom_->Max() == 0) {
4945  denom_->SetMax(-1);
4946  }
4947  }
4948 
4949  // m > 0.
4950  static void SetPosMin(IntExpr* const num, IntExpr* const denom, int64 m) {
4951  DCHECK_GT(m, 0);
4952  const int64 num_min = num->Min();
4953  const int64 num_max = num->Max();
4954  const int64 denom_min = denom->Min();
4955  const int64 denom_max = denom->Max();
4956  DCHECK_NE(denom_min, 0);
4957  DCHECK_NE(denom_max, 0);
4958  if (denom_min > 0) { // Denominator strictly positive.
4959  num->SetMin(m * denom_min);
4960  denom->SetMax(num_max / m);
4961  } else if (denom_max < 0) { // Denominator strictly negative.
4962  num->SetMax(m * denom_max);
4963  denom->SetMin(num_min / m);
4964  } else { // Denominator across 0.
4965  if (num_min >= 0) {
4966  num->SetMin(m);
4967  denom->SetRange(1, num_max / m);
4968  } else if (num_max <= 0) {
4969  num->SetMax(-m);
4970  denom->SetRange(num_min / m, -1);
4971  } else {
4972  if (m > -num_min) { // Denominator is forced positive.
4973  num->SetMin(m);
4974  denom->SetRange(1, num_max / m);
4975  } else if (m > num_max) { // Denominator is forced negative.
4976  num->SetMax(-m);
4977  denom->SetRange(num_min / m, -1);
4978  } else {
4979  denom->SetRange(num_min / m, num_max / m);
4980  }
4981  }
4982  }
4983  }
4984 
4985  // m >= 0.
4986  static void SetPosMax(IntExpr* const num, IntExpr* const denom, int64 m) {
4987  DCHECK_GE(m, 0);
4988  const int64 num_min = num->Min();
4989  const int64 num_max = num->Max();
4990  const int64 denom_min = denom->Min();
4991  const int64 denom_max = denom->Max();
4992  DCHECK_NE(denom_min, 0);
4993  DCHECK_NE(denom_max, 0);
4994  if (denom_min > 0) { // Denominator strictly positive.
4995  num->SetMax((m + 1) * denom_max - 1);
4996  denom->SetMin((num_min / (m + 1)) + 1);
4997  } else if (denom_max < 0) {
4998  num->SetMin((m + 1) * denom_min + 1);
4999  denom->SetMax(num_max / (m + 1) - 1);
5000  } else if (num_min > (m + 1) * denom_max - 1) {
5001  denom->SetMax(-1);
5002  } else if (num_max < (m + 1) * denom_min + 1) {
5003  denom->SetMin(1);
5004  }
5005  }
5006 
5007  void SetMin(int64 m) override {
5008  AdjustDenominator();
5009  if (m > 0) {
5010  SetPosMin(num_, denom_, m);
5011  } else {
5012  SetPosMax(opp_num_, denom_, -m);
5013  }
5014  }
5015 
5016  void SetMax(int64 m) override {
5017  AdjustDenominator();
5018  if (m >= 0) {
5019  SetPosMax(num_, denom_, m);
5020  } else {
5021  SetPosMin(opp_num_, denom_, -m);
5022  }
5023  }
5024 
5025  std::string name() const override {
5026  return absl::StrFormat("(%s div %s)", num_->name(), denom_->name());
5027  }
5028  std::string DebugString() const override {
5029  return absl::StrFormat("(%s div %s)", num_->DebugString(),
5030  denom_->DebugString());
5031  }
5032  void WhenRange(Demon* d) override {
5033  num_->WhenRange(d);
5034  denom_->WhenRange(d);
5035  }
5036 
5037  void Accept(ModelVisitor* const visitor) const override {
5038  visitor->BeginVisitIntegerExpression(ModelVisitor::kDivide, this);
5039  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, num_);
5040  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5041  denom_);
5042  visitor->EndVisitIntegerExpression(ModelVisitor::kDivide, this);
5043  }
5044 
5045  private:
5046  IntExpr* const num_;
5047  IntExpr* const denom_;
5048  IntExpr* const opp_num_;
5049 };
5050 
5051 // ----- IntAbs And IntAbsConstraint ------
5052 
5053 class IntAbsConstraint : public CastConstraint {
5054  public:
5055  IntAbsConstraint(Solver* const s, IntVar* const sub, IntVar* const target)
5056  : CastConstraint(s, target), sub_(sub) {}
5057 
5058  ~IntAbsConstraint() override {}
5059 
5060  void Post() override {
5061  Demon* const sub_demon = MakeConstraintDemon0(
5062  solver(), this, &IntAbsConstraint::PropagateSub, "PropagateSub");
5063  sub_->WhenRange(sub_demon);
5064  Demon* const target_demon = MakeConstraintDemon0(
5065  solver(), this, &IntAbsConstraint::PropagateTarget, "PropagateTarget");
5066  target_var_->WhenRange(target_demon);
5067  }
5068 
5069  void InitialPropagate() override {
5070  PropagateSub();
5071  PropagateTarget();
5072  }
5073 
5074  void PropagateSub() {
5075  const int64 smin = sub_->Min();
5076  const int64 smax = sub_->Max();
5077  if (smax <= 0) {
5078  target_var_->SetRange(-smax, -smin);
5079  } else if (smin >= 0) {
5080  target_var_->SetRange(smin, smax);
5081  } else {
5082  target_var_->SetRange(0, std::max(-smin, smax));
5083  }
5084  }
5085 
5086  void PropagateTarget() {
5087  const int64 target_max = target_var_->Max();
5088  sub_->SetRange(-target_max, target_max);
5089  const int64 target_min = target_var_->Min();
5090  if (target_min > 0) {
5091  if (sub_->Min() > -target_min) {
5092  sub_->SetMin(target_min);
5093  } else if (sub_->Max() < target_min) {
5094  sub_->SetMax(-target_min);
5095  }
5096  }
5097  }
5098 
5099  std::string DebugString() const override {
5100  return absl::StrFormat("IntAbsConstraint(%s, %s)", sub_->DebugString(),
5101  target_var_->DebugString());
5102  }
5103 
5104  void Accept(ModelVisitor* const visitor) const override {
5105  visitor->BeginVisitConstraint(ModelVisitor::kAbsEqual, this);
5106  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5107  sub_);
5108  visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
5109  target_var_);
5110  visitor->EndVisitConstraint(ModelVisitor::kAbsEqual, this);
5111  }
5112 
5113  private:
5114  IntVar* const sub_;
5115 };
5116 
5117 class IntAbs : public BaseIntExpr {
5118  public:
5119  IntAbs(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5120 
5121  ~IntAbs() override {}
5122 
5123  int64 Min() const override {
5124  int64 emin = 0;
5125  int64 emax = 0;
5126  expr_->Range(&emin, &emax);
5127  if (emin >= 0) {
5128  return emin;
5129  }
5130  if (emax <= 0) {
5131  return -emax;
5132  }
5133  return 0;
5134  }
5135 
5136  void SetMin(int64 m) override {
5137  if (m > 0) {
5138  int64 emin = 0;
5139  int64 emax = 0;
5140  expr_->Range(&emin, &emax);
5141  if (emin > -m) {
5142  expr_->SetMin(m);
5143  } else if (emax < m) {
5144  expr_->SetMax(-m);
5145  }
5146  }
5147  }
5148 
5149  int64 Max() const override {
5150  int64 emin = 0;
5151  int64 emax = 0;
5152  expr_->Range(&emin, &emax);
5153  return std::max(-emin, emax);
5154  }
5155 
5156  void SetMax(int64 m) override { expr_->SetRange(-m, m); }
5157 
5158  void SetRange(int64 mi, int64 ma) override {
5159  expr_->SetRange(-ma, ma);
5160  if (mi > 0) {
5161  int64 emin = 0;
5162  int64 emax = 0;
5163  expr_->Range(&emin, &emax);
5164  if (emin > -mi) {
5165  expr_->SetMin(mi);
5166  } else if (emax < mi) {
5167  expr_->SetMax(-mi);
5168  }
5169  }
5170  }
5171 
5172  void Range(int64* mi, int64* ma) override {
5173  int64 emin = 0;
5174  int64 emax = 0;
5175  expr_->Range(&emin, &emax);
5176  if (emin >= 0) {
5177  *mi = emin;
5178  *ma = emax;
5179  } else if (emax <= 0) {
5180  *mi = -emax;
5181  *ma = -emin;
5182  } else {
5183  *mi = 0;
5184  *ma = std::max(-emin, emax);
5185  }
5186  }
5187 
5188  bool Bound() const override { return expr_->Bound(); }
5189 
5190  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5191 
5192  std::string name() const override {
5193  return absl::StrFormat("IntAbs(%s)", expr_->name());
5194  }
5195 
5196  std::string DebugString() const override {
5197  return absl::StrFormat("IntAbs(%s)", expr_->DebugString());
5198  }
5199 
5200  void Accept(ModelVisitor* const visitor) const override {
5201  visitor->BeginVisitIntegerExpression(ModelVisitor::kAbs, this);
5202  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5203  expr_);
5204  visitor->EndVisitIntegerExpression(ModelVisitor::kAbs, this);
5205  }
5206 
5207  IntVar* CastToVar() override {
5208  int64 min_value = 0;
5209  int64 max_value = 0;
5210  Range(&min_value, &max_value);
5211  Solver* const s = solver();
5212  const std::string name = absl::StrFormat("AbsVar(%s)", expr_->name());
5213  IntVar* const target = s->MakeIntVar(min_value, max_value, name);
5214  CastConstraint* const ct =
5215  s->RevAlloc(new IntAbsConstraint(s, expr_->Var(), target));
5216  s->AddCastConstraint(ct, target, this);
5217  return target;
5218  }
5219 
5220  private:
5221  IntExpr* const expr_;
5222 };
5223 
5224 // ----- Square -----
5225 
5226 // TODO(user): shouldn't we compare to kint32max^2 instead of kint64max?
5227 class IntSquare : public BaseIntExpr {
5228  public:
5229  IntSquare(Solver* const s, IntExpr* const e) : BaseIntExpr(s), expr_(e) {}
5230  ~IntSquare() override {}
5231 
5232  int64 Min() const override {
5233  const int64 emin = expr_->Min();
5234  if (emin >= 0) {
5235  return emin >= kint32max ? kint64max : emin * emin;
5236  }
5237  const int64 emax = expr_->Max();
5238  if (emax < 0) {
5239  return emax <= -kint32max ? kint64max : emax * emax;
5240  }
5241  return 0LL;
5242  }
5243  void SetMin(int64 m) override {
5244  if (m <= 0) {
5245  return;
5246  }
5247  // TODO(user): What happens if m is kint64max?
5248  const int64 emin = expr_->Min();
5249  const int64 emax = expr_->Max();
5250  const int64 root = static_cast<int64>(ceil(sqrt(static_cast<double>(m))));
5251  if (emin >= 0) {
5252  expr_->SetMin(root);
5253  } else if (emax <= 0) {
5254  expr_->SetMax(-root);
5255  } else if (expr_->IsVar()) {
5256  reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5257  }
5258  }
5259  int64 Max() const override {
5260  const int64 emax = expr_->Max();
5261  const int64 emin = expr_->Min();
5262  if (emax >= kint32max || emin <= -kint32max) {
5263  return kint64max;
5264  }
5265  return std::max(emin * emin, emax * emax);
5266  }
5267  void SetMax(int64 m) override {
5268  if (m < 0) {
5269  solver()->Fail();
5270  }
5271  if (m == kint64max) {
5272  return;
5273  }
5274  const int64 root = static_cast<int64>(floor(sqrt(static_cast<double>(m))));
5275  expr_->SetRange(-root, root);
5276  }
5277  bool Bound() const override { return expr_->Bound(); }
5278  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5279  std::string name() const override {
5280  return absl::StrFormat("IntSquare(%s)", expr_->name());
5281  }
5282  std::string DebugString() const override {
5283  return absl::StrFormat("IntSquare(%s)", expr_->DebugString());
5284  }
5285 
5286  void Accept(ModelVisitor* const visitor) const override {
5287  visitor->BeginVisitIntegerExpression(ModelVisitor::kSquare, this);
5288  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5289  expr_);
5290  visitor->EndVisitIntegerExpression(ModelVisitor::kSquare, this);
5291  }
5292 
5293  IntExpr* expr() const { return expr_; }
5294 
5295  protected:
5296  IntExpr* const expr_;
5297 };
5298 
5299 class PosIntSquare : public IntSquare {
5300  public:
5301  PosIntSquare(Solver* const s, IntExpr* const e) : IntSquare(s, e) {}
5302  ~PosIntSquare() override {}
5303 
5304  int64 Min() const override {
5305  const int64 emin = expr_->Min();
5306  return emin >= kint32max ? kint64max : emin * emin;
5307  }
5308  void SetMin(int64 m) override {
5309  if (m <= 0) {
5310  return;
5311  }
5312  const int64 root = static_cast<int64>(ceil(sqrt(static_cast<double>(m))));
5313  expr_->SetMin(root);
5314  }
5315  int64 Max() const override {
5316  const int64 emax = expr_->Max();
5317  return emax >= kint32max ? kint64max : emax * emax;
5318  }
5319  void SetMax(int64 m) override {
5320  if (m < 0) {
5321  solver()->Fail();
5322  }
5323  if (m == kint64max) {
5324  return;
5325  }
5326  const int64 root = static_cast<int64>(floor(sqrt(static_cast<double>(m))));
5327  expr_->SetMax(root);
5328  }
5329 };
5330 
5331 // ----- EvenPower -----
5332 
// Returns value^power using binary exponentiation (O(log power) products).
//
// Precondition: power >= 0. power == 0 yields 1; a negative power is not
// meaningful for integers and also yields 1. Overflow is NOT detected:
// callers must keep |value| below OverflowLimit(power) when the result has
// to fit in an int64. The squaring step is skipped after the last bit, so no
// intermediate product exceeds |value|^power in magnitude.
int64 IntPower(int64 value, int64 power) {
  int64 result = 1;
  int64 base = value;
  while (power > 0) {
    if (power & 1) {
      result *= base;
    }
    power >>= 1;
    if (power > 0) {
      base *= base;
    }
  }
  return result;
}
5341 
// Returns L = floor(kint64max^(1/power)), computed in floating point. Up to
// rounding, |v|^power fits in an int64 whenever |v| < L.
//
// For power <= 1 every int64 value already "fits", so kint64max is returned
// directly. The floating-point path would produce a double equal to 2^63
// (power == 1) or +inf (power == 0), and converting either to int64 is
// undefined behavior.
int64 OverflowLimit(int64 power) {
  if (power <= 1) {
    return kint64max;
  }
  return static_cast<int64>(
      floor(exp(log(static_cast<double>(kint64max)) / power)));
}
5346 
5347 class BasePower : public BaseIntExpr {
5348  public:
5349  BasePower(Solver* const s, IntExpr* const e, int64 n)
5350  : BaseIntExpr(s), expr_(e), pow_(n), limit_(OverflowLimit(n)) {
5351  CHECK_GT(n, 0);
5352  }
5353 
5354  ~BasePower() override {}
5355 
5356  bool Bound() const override { return expr_->Bound(); }
5357 
5358  IntExpr* expr() const { return expr_; }
5359 
5360  int64 exponant() const { return pow_; }
5361 
5362  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5363 
5364  std::string name() const override {
5365  return absl::StrFormat("IntPower(%s, %d)", expr_->name(), pow_);
5366  }
5367 
5368  std::string DebugString() const override {
5369  return absl::StrFormat("IntPower(%s, %d)", expr_->DebugString(), pow_);
5370  }
5371 
5372  void Accept(ModelVisitor* const visitor) const override {
5373  visitor->BeginVisitIntegerExpression(ModelVisitor::kPower, this);
5374  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5375  expr_);
5376  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, pow_);
5377  visitor->EndVisitIntegerExpression(ModelVisitor::kPower, this);
5378  }
5379 
5380  protected:
5381  int64 Pown(int64 value) const {
5382  if (value >= limit_) {
5383  return kint64max;
5384  }
5385  if (value <= -limit_) {
5386  if (pow_ % 2 == 0) {
5387  return kint64max;
5388  } else {
5389  return kint64min;
5390  }
5391  }
5392  return IntPower(value, pow_);
5393  }
5394 
5395  int64 SqrnDown(int64 value) const {
5396  if (value == kint64min) {
5397  return kint64min;
5398  }
5399  if (value == kint64max) {
5400  return kint64max;
5401  }
5402  int64 res = 0;
5403  const double d_value = static_cast<double>(value);
5404  if (value >= 0) {
5405  const double sq = exp(log(d_value) / pow_);
5406  res = static_cast<int64>(floor(sq));
5407  } else {
5408  CHECK_EQ(1, pow_ % 2);
5409  const double sq = exp(log(-d_value) / pow_);
5410  res = -static_cast<int64>(ceil(sq));
5411  }
5412  const int64 pow_res = Pown(res + 1);
5413  if (pow_res <= value) {
5414  return res + 1;
5415  } else {
5416  return res;
5417  }
5418  }
5419 
5420  int64 SqrnUp(int64 value) const {
5421  if (value == kint64min) {
5422  return kint64min;
5423  }
5424  if (value == kint64max) {
5425  return kint64max;
5426  }
5427  int64 res = 0;
5428  const double d_value = static_cast<double>(value);
5429  if (value >= 0) {
5430  const double sq = exp(log(d_value) / pow_);
5431  res = static_cast<int64>(ceil(sq));
5432  } else {
5433  CHECK_EQ(1, pow_ % 2);
5434  const double sq = exp(log(-d_value) / pow_);
5435  res = -static_cast<int64>(floor(sq));
5436  }
5437  const int64 pow_res = Pown(res - 1);
5438  if (pow_res >= value) {
5439  return res - 1;
5440  } else {
5441  return res;
5442  }
5443  }
5444 
5445  IntExpr* const expr_;
5446  const int64 pow_;
5447  const int64 limit_;
5448 };
5449 
5450 class IntEvenPower : public BasePower {
5451  public:
5452  IntEvenPower(Solver* const s, IntExpr* const e, int64 n)
5453  : BasePower(s, e, n) {
5454  CHECK_EQ(0, n % 2);
5455  }
5456 
5457  ~IntEvenPower() override {}
5458 
5459  int64 Min() const override {
5460  int64 emin = 0;
5461  int64 emax = 0;
5462  expr_->Range(&emin, &emax);
5463  if (emin >= 0) {
5464  return Pown(emin);
5465  }
5466  if (emax < 0) {
5467  return Pown(emax);
5468  }
5469  return 0LL;
5470  }
5471  void SetMin(int64 m) override {
5472  if (m <= 0) {
5473  return;
5474  }
5475  int64 emin = 0;
5476  int64 emax = 0;
5477  expr_->Range(&emin, &emax);
5478  const int64 root = SqrnUp(m);
5479  if (emin > -root) {
5480  expr_->SetMin(root);
5481  } else if (emax < root) {
5482  expr_->SetMax(-root);
5483  } else if (expr_->IsVar()) {
5484  reinterpret_cast<IntVar*>(expr_)->RemoveInterval(-root + 1, root - 1);
5485  }
5486  }
5487 
5488  int64 Max() const override {
5489  return std::max(Pown(expr_->Min()), Pown(expr_->Max()));
5490  }
5491 
5492  void SetMax(int64 m) override {
5493  if (m < 0) {
5494  solver()->Fail();
5495  }
5496  if (m == kint64max) {
5497  return;
5498  }
5499  const int64 root = SqrnDown(m);
5500  expr_->SetRange(-root, root);
5501  }
5502 };
5503 
5504 class PosIntEvenPower : public BasePower {
5505  public:
5506  PosIntEvenPower(Solver* const s, IntExpr* const e, int64 pow)
5507  : BasePower(s, e, pow) {
5508  CHECK_EQ(0, pow % 2);
5509  }
5510 
5511  ~PosIntEvenPower() override {}
5512 
5513  int64 Min() const override { return Pown(expr_->Min()); }
5514 
5515  void SetMin(int64 m) override {
5516  if (m <= 0) {
5517  return;
5518  }
5519  expr_->SetMin(SqrnUp(m));
5520  }
5521  int64 Max() const override { return Pown(expr_->Max()); }
5522 
5523  void SetMax(int64 m) override {
5524  if (m < 0) {
5525  solver()->Fail();
5526  }
5527  if (m == kint64max) {
5528  return;
5529  }
5530  expr_->SetMax(SqrnDown(m));
5531  }
5532 };
5533 
5534 class IntOddPower : public BasePower {
5535  public:
5536  IntOddPower(Solver* const s, IntExpr* const e, int64 n) : BasePower(s, e, n) {
5537  CHECK_EQ(1, n % 2);
5538  }
5539 
5540  ~IntOddPower() override {}
5541 
5542  int64 Min() const override { return Pown(expr_->Min()); }
5543 
5544  void SetMin(int64 m) override { expr_->SetMin(SqrnUp(m)); }
5545 
5546  int64 Max() const override { return Pown(expr_->Max()); }
5547 
5548  void SetMax(int64 m) override { expr_->SetMax(SqrnDown(m)); }
5549 };
5550 
5551 // ----- Min(expr, expr) -----
5552 
5553 class MinIntExpr : public BaseIntExpr {
5554  public:
5555  MinIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5556  : BaseIntExpr(s), left_(l), right_(r) {}
5557  ~MinIntExpr() override {}
5558  int64 Min() const override {
5559  const int64 lmin = left_->Min();
5560  const int64 rmin = right_->Min();
5561  return std::min(lmin, rmin);
5562  }
5563  void SetMin(int64 m) override {
5564  left_->SetMin(m);
5565  right_->SetMin(m);
5566  }
5567  int64 Max() const override {
5568  const int64 lmax = left_->Max();
5569  const int64 rmax = right_->Max();
5570  return std::min(lmax, rmax);
5571  }
5572  void SetMax(int64 m) override {
5573  if (left_->Min() > m) {
5574  right_->SetMax(m);
5575  }
5576  if (right_->Min() > m) {
5577  left_->SetMax(m);
5578  }
5579  }
5580  std::string name() const override {
5581  return absl::StrFormat("MinIntExpr(%s, %s)", left_->name(), right_->name());
5582  }
5583  std::string DebugString() const override {
5584  return absl::StrFormat("MinIntExpr(%s, %s)", left_->DebugString(),
5585  right_->DebugString());
5586  }
5587  void WhenRange(Demon* d) override {
5588  left_->WhenRange(d);
5589  right_->WhenRange(d);
5590  }
5591 
5592  void Accept(ModelVisitor* const visitor) const override {
5593  visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5594  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5595  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5596  right_);
5597  visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5598  }
5599 
5600  private:
5601  IntExpr* const left_;
5602  IntExpr* const right_;
5603 };
5604 
5605 // ----- Min(expr, constant) -----
5606 
5607 class MinCstIntExpr : public BaseIntExpr {
5608  public:
5609  MinCstIntExpr(Solver* const s, IntExpr* const e, int64 v)
5610  : BaseIntExpr(s), expr_(e), value_(v) {}
5611 
5612  ~MinCstIntExpr() override {}
5613 
5614  int64 Min() const override { return std::min(expr_->Min(), value_); }
5615 
5616  void SetMin(int64 m) override {
5617  if (m > value_) {
5618  solver()->Fail();
5619  }
5620  expr_->SetMin(m);
5621  }
5622 
5623  int64 Max() const override { return std::min(expr_->Max(), value_); }
5624 
5625  void SetMax(int64 m) override {
5626  if (value_ > m) {
5627  expr_->SetMax(m);
5628  }
5629  }
5630 
5631  bool Bound() const override {
5632  return (expr_->Bound() || expr_->Min() >= value_);
5633  }
5634 
5635  std::string name() const override {
5636  return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->name(), value_);
5637  }
5638 
5639  std::string DebugString() const override {
5640  return absl::StrFormat("MinCstIntExpr(%s, %d)", expr_->DebugString(),
5641  value_);
5642  }
5643 
5644  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5645 
5646  void Accept(ModelVisitor* const visitor) const override {
5647  visitor->BeginVisitIntegerExpression(ModelVisitor::kMin, this);
5648  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5649  expr_);
5650  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5651  visitor->EndVisitIntegerExpression(ModelVisitor::kMin, this);
5652  }
5653 
5654  private:
5655  IntExpr* const expr_;
5656  const int64 value_;
5657 };
5658 
5659 // ----- Max(expr, expr) -----
5660 
5661 class MaxIntExpr : public BaseIntExpr {
5662  public:
5663  MaxIntExpr(Solver* const s, IntExpr* const l, IntExpr* const r)
5664  : BaseIntExpr(s), left_(l), right_(r) {}
5665 
5666  ~MaxIntExpr() override {}
5667 
5668  int64 Min() const override { return std::max(left_->Min(), right_->Min()); }
5669 
5670  void SetMin(int64 m) override {
5671  if (left_->Max() < m) {
5672  right_->SetMin(m);
5673  } else {
5674  if (right_->Max() < m) {
5675  left_->SetMin(m);
5676  }
5677  }
5678  }
5679 
5680  int64 Max() const override { return std::max(left_->Max(), right_->Max()); }
5681 
5682  void SetMax(int64 m) override {
5683  left_->SetMax(m);
5684  right_->SetMax(m);
5685  }
5686 
5687  std::string name() const override {
5688  return absl::StrFormat("MaxIntExpr(%s, %s)", left_->name(), right_->name());
5689  }
5690 
5691  std::string DebugString() const override {
5692  return absl::StrFormat("MaxIntExpr(%s, %s)", left_->DebugString(),
5693  right_->DebugString());
5694  }
5695 
5696  void WhenRange(Demon* d) override {
5697  left_->WhenRange(d);
5698  right_->WhenRange(d);
5699  }
5700 
5701  void Accept(ModelVisitor* const visitor) const override {
5702  visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5703  visitor->VisitIntegerExpressionArgument(ModelVisitor::kLeftArgument, left_);
5704  visitor->VisitIntegerExpressionArgument(ModelVisitor::kRightArgument,
5705  right_);
5706  visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5707  }
5708 
5709  private:
5710  IntExpr* const left_;
5711  IntExpr* const right_;
5712 };
5713 
5714 // ----- Max(expr, constant) -----
5715 
5716 class MaxCstIntExpr : public BaseIntExpr {
5717  public:
5718  MaxCstIntExpr(Solver* const s, IntExpr* const e, int64 v)
5719  : BaseIntExpr(s), expr_(e), value_(v) {}
5720 
5721  ~MaxCstIntExpr() override {}
5722 
5723  int64 Min() const override { return std::max(expr_->Min(), value_); }
5724 
5725  void SetMin(int64 m) override {
5726  if (value_ < m) {
5727  expr_->SetMin(m);
5728  }
5729  }
5730 
5731  int64 Max() const override { return std::max(expr_->Max(), value_); }
5732 
5733  void SetMax(int64 m) override {
5734  if (m < value_) {
5735  solver()->Fail();
5736  }
5737  expr_->SetMax(m);
5738  }
5739 
5740  bool Bound() const override {
5741  return (expr_->Bound() || expr_->Max() <= value_);
5742  }
5743 
5744  std::string name() const override {
5745  return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->name(), value_);
5746  }
5747 
5748  std::string DebugString() const override {
5749  return absl::StrFormat("MaxCstIntExpr(%s, %d)", expr_->DebugString(),
5750  value_);
5751  }
5752 
5753  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5754 
5755  void Accept(ModelVisitor* const visitor) const override {
5756  visitor->BeginVisitIntegerExpression(ModelVisitor::kMax, this);
5757  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5758  expr_);
5759  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument, value_);
5760  visitor->EndVisitIntegerExpression(ModelVisitor::kMax, this);
5761  }
5762 
5763  private:
5764  IntExpr* const expr_;
5765  const int64 value_;
5766 };
5767 
5768 // ----- Convex Piecewise -----
5769 
5770 // This class is a very simple convex piecewise linear function. The
5771 // argument of the function is the expression. Between early_date and
5772 // late_date, the value of the function is 0. Before early date, it
5773 // is affine and the cost is early_cost * (early_date - x). After
5774 // late_date, the cost is late_cost * (x - late_date).
5775 
5776 class SimpleConvexPiecewiseExpr : public BaseIntExpr {
5777  public:
5778  SimpleConvexPiecewiseExpr(Solver* const s, IntExpr* const e, int64 ec,
5779  int64 ed, int64 ld, int64 lc)
5780  : BaseIntExpr(s),
5781  expr_(e),
5782  early_cost_(ec),
5783  early_date_(ec == 0 ? kint64min : ed),
5784  late_date_(lc == 0 ? kint64max : ld),
5785  late_cost_(lc) {
5786  DCHECK_GE(ec, int64{0});
5787  DCHECK_GE(lc, int64{0});
5788  DCHECK_GE(ld, ed);
5789 
5790  // If the penalty is 0, we can push the "confort zone or zone
5791  // of no cost towards infinity.
5792  }
5793 
5794  ~SimpleConvexPiecewiseExpr() override {}
5795 
5796  int64 Min() const override {
5797  const int64 vmin = expr_->Min();
5798  const int64 vmax = expr_->Max();
5799  if (vmin >= late_date_) {
5800  return (vmin - late_date_) * late_cost_;
5801  } else if (vmax <= early_date_) {
5802  return (early_date_ - vmax) * early_cost_;
5803  } else {
5804  return 0LL;
5805  }
5806  }
5807 
5808  void SetMin(int64 m) override {
5809  if (m <= 0) {
5810  return;
5811  }
5812  int64 vmin = 0;
5813  int64 vmax = 0;
5814  expr_->Range(&vmin, &vmax);
5815 
5816  const int64 rb =
5817  (late_cost_ == 0 ? vmax : late_date_ + PosIntDivUp(m, late_cost_) - 1);
5818  const int64 lb =
5819  (early_cost_ == 0 ? vmin
5820  : early_date_ - PosIntDivUp(m, early_cost_) + 1);
5821 
5822  if (expr_->IsVar()) {
5823  expr_->Var()->RemoveInterval(lb, rb);
5824  }
5825  }
5826 
5827  int64 Max() const override {
5828  const int64 vmin = expr_->Min();
5829  const int64 vmax = expr_->Max();
5830  const int64 mr = vmax > late_date_ ? (vmax - late_date_) * late_cost_ : 0;
5831  const int64 ml =
5832  vmin < early_date_ ? (early_date_ - vmin) * early_cost_ : 0;
5833  return std::max(mr, ml);
5834  }
5835 
5836  void SetMax(int64 m) override {
5837  if (m < 0) {
5838  solver()->Fail();
5839  }
5840  if (late_cost_ != 0LL) {
5841  const int64 rb = late_date_ + PosIntDivDown(m, late_cost_);
5842  if (early_cost_ != 0LL) {
5843  const int64 lb = early_date_ - PosIntDivDown(m, early_cost_);
5844  expr_->SetRange(lb, rb);
5845  } else {
5846  expr_->SetMax(rb);
5847  }
5848  } else {
5849  if (early_cost_ != 0LL) {
5850  const int64 lb = early_date_ - PosIntDivDown(m, early_cost_);
5851  expr_->SetMin(lb);
5852  }
5853  }
5854  }
5855 
5856  std::string name() const override {
5857  return absl::StrFormat(
5858  "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5859  expr_->name(), early_cost_, early_date_, late_date_, late_cost_);
5860  }
5861 
5862  std::string DebugString() const override {
5863  return absl::StrFormat(
5864  "ConvexPiecewiseExpr(%s, ec = %d, ed = %d, ld = %d, lc = %d)",
5865  expr_->DebugString(), early_cost_, early_date_, late_date_, late_cost_);
5866  }
5867 
5868  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5869 
5870  void Accept(ModelVisitor* const visitor) const override {
5871  visitor->BeginVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5872  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5873  expr_);
5874  visitor->VisitIntegerArgument(ModelVisitor::kEarlyCostArgument,
5875  early_cost_);
5876  visitor->VisitIntegerArgument(ModelVisitor::kEarlyDateArgument,
5877  early_date_);
5878  visitor->VisitIntegerArgument(ModelVisitor::kLateCostArgument, late_cost_);
5879  visitor->VisitIntegerArgument(ModelVisitor::kLateDateArgument, late_date_);
5880  visitor->EndVisitIntegerExpression(ModelVisitor::kConvexPiecewise, this);
5881  }
5882 
5883  private:
5884  IntExpr* const expr_;
5885  const int64 early_cost_;
5886  const int64 early_date_;
5887  const int64 late_date_;
5888  const int64 late_cost_;
5889 };
5890 
5891 // ----- Semi Continuous -----
5892 
5893 class SemiContinuousExpr : public BaseIntExpr {
5894  public:
5895  SemiContinuousExpr(Solver* const s, IntExpr* const e, int64 fixed_charge,
5896  int64 step)
5897  : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge), step_(step) {
5898  DCHECK_GE(fixed_charge, int64{0});
5899  DCHECK_GT(step, int64{0});
5900  }
5901 
5902  ~SemiContinuousExpr() override {}
5903 
5904  int64 Value(int64 x) const {
5905  if (x <= 0) {
5906  return 0;
5907  } else {
5908  return CapAdd(fixed_charge_, CapProd(x, step_));
5909  }
5910  }
5911 
5912  int64 Min() const override { return Value(expr_->Min()); }
5913 
5914  void SetMin(int64 m) override {
5915  if (m >= CapAdd(fixed_charge_, step_)) {
5916  const int64 y = PosIntDivUp(CapSub(m, fixed_charge_), step_);
5917  expr_->SetMin(y);
5918  } else if (m > 0) {
5919  expr_->SetMin(1);
5920  }
5921  }
5922 
5923  int64 Max() const override { return Value(expr_->Max()); }
5924 
5925  void SetMax(int64 m) override {
5926  if (m < 0) {
5927  solver()->Fail();
5928  }
5929  if (m == kint64max) {
5930  return;
5931  }
5932  if (m < CapAdd(fixed_charge_, step_)) {
5933  expr_->SetMax(0);
5934  } else {
5935  const int64 y = PosIntDivDown(CapSub(m, fixed_charge_), step_);
5936  expr_->SetMax(y);
5937  }
5938  }
5939 
5940  std::string name() const override {
5941  return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5942  expr_->name(), fixed_charge_, step_);
5943  }
5944 
5945  std::string DebugString() const override {
5946  return absl::StrFormat("SemiContinuous(%s, fixed_charge = %d, step = %d)",
5947  expr_->DebugString(), fixed_charge_, step_);
5948  }
5949 
5950  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
5951 
5952  void Accept(ModelVisitor* const visitor) const override {
5953  visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5954  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
5955  expr_);
5956  visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
5957  fixed_charge_);
5958  visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, step_);
5959  visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
5960  }
5961 
5962  private:
5963  IntExpr* const expr_;
5964  const int64 fixed_charge_;
5965  const int64 step_;
5966 };
5967 
5968 class SemiContinuousStepOneExpr : public BaseIntExpr {
5969  public:
5970  SemiContinuousStepOneExpr(Solver* const s, IntExpr* const e,
5971  int64 fixed_charge)
5972  : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
5973  DCHECK_GE(fixed_charge, int64{0});
5974  }
5975 
5976  ~SemiContinuousStepOneExpr() override {}
5977 
5978  int64 Value(int64 x) const {
5979  if (x <= 0) {
5980  return 0;
5981  } else {
5982  return fixed_charge_ + x;
5983  }
5984  }
5985 
5986  int64 Min() const override { return Value(expr_->Min()); }
5987 
5988  void SetMin(int64 m) override {
5989  if (m >= fixed_charge_ + 1) {
5990  expr_->SetMin(m - fixed_charge_);
5991  } else if (m > 0) {
5992  expr_->SetMin(1);
5993  }
5994  }
5995 
5996  int64 Max() const override { return Value(expr_->Max()); }
5997 
5998  void SetMax(int64 m) override {
5999  if (m < 0) {
6000  solver()->Fail();
6001  }
6002  if (m < fixed_charge_ + 1) {
6003  expr_->SetMax(0);
6004  } else {
6005  expr_->SetMax(m - fixed_charge_);
6006  }
6007  }
6008 
6009  std::string name() const override {
6010  return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6011  expr_->name(), fixed_charge_);
6012  }
6013 
6014  std::string DebugString() const override {
6015  return absl::StrFormat("SemiContinuousStepOne(%s, fixed_charge = %d)",
6016  expr_->DebugString(), fixed_charge_);
6017  }
6018 
6019  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6020 
6021  void Accept(ModelVisitor* const visitor) const override {
6022  visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6023  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6024  expr_);
6025  visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6026  fixed_charge_);
6027  visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 1);
6028  visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6029  }
6030 
6031  private:
6032  IntExpr* const expr_;
6033  const int64 fixed_charge_;
6034 };
6035 
6036 class SemiContinuousStepZeroExpr : public BaseIntExpr {
6037  public:
6038  SemiContinuousStepZeroExpr(Solver* const s, IntExpr* const e,
6039  int64 fixed_charge)
6040  : BaseIntExpr(s), expr_(e), fixed_charge_(fixed_charge) {
6041  DCHECK_GT(fixed_charge, int64{0});
6042  }
6043 
6044  ~SemiContinuousStepZeroExpr() override {}
6045 
6046  int64 Value(int64 x) const {
6047  if (x <= 0) {
6048  return 0;
6049  } else {
6050  return fixed_charge_;
6051  }
6052  }
6053 
6054  int64 Min() const override { return Value(expr_->Min()); }
6055 
6056  void SetMin(int64 m) override {
6057  if (m >= fixed_charge_) {
6058  solver()->Fail();
6059  } else if (m > 0) {
6060  expr_->SetMin(1);
6061  }
6062  }
6063 
6064  int64 Max() const override { return Value(expr_->Max()); }
6065 
6066  void SetMax(int64 m) override {
6067  if (m < 0) {
6068  solver()->Fail();
6069  }
6070  if (m < fixed_charge_) {
6071  expr_->SetMax(0);
6072  }
6073  }
6074 
6075  std::string name() const override {
6076  return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6077  expr_->name(), fixed_charge_);
6078  }
6079 
6080  std::string DebugString() const override {
6081  return absl::StrFormat("SemiContinuousStepZero(%s, fixed_charge = %d)",
6082  expr_->DebugString(), fixed_charge_);
6083  }
6084 
6085  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
6086 
6087  void Accept(ModelVisitor* const visitor) const override {
6088  visitor->BeginVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6089  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6090  expr_);
6091  visitor->VisitIntegerArgument(ModelVisitor::kFixedChargeArgument,
6092  fixed_charge_);
6093  visitor->VisitIntegerArgument(ModelVisitor::kStepArgument, 0);
6094  visitor->EndVisitIntegerExpression(ModelVisitor::kSemiContinuous, this);
6095  }
6096 
6097  private:
6098  IntExpr* const expr_;
6099  const int64 fixed_charge_;
6100 };
6101 
6102 // This constraints links an expression and the variable it is casted into
6103 class LinkExprAndVar : public CastConstraint {
6104  public:
6105  LinkExprAndVar(Solver* const s, IntExpr* const expr, IntVar* const var)
6106  : CastConstraint(s, var), expr_(expr) {}
6107 
6108  ~LinkExprAndVar() override {}
6109 
6110  void Post() override {
6111  Solver* const s = solver();
6112  Demon* d = s->MakeConstraintInitialPropagateCallback(this);
6113  expr_->WhenRange(d);
6114  target_var_->WhenRange(d);
6115  }
6116 
6117  void InitialPropagate() override {
6118  expr_->SetRange(target_var_->Min(), target_var_->Max());
6119  int64 l, u;
6120  expr_->Range(&l, &u);
6121  target_var_->SetRange(l, u);
6122  }
6123 
6124  std::string DebugString() const override {
6125  return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6126  target_var_->DebugString());
6127  }
6128 
6129  void Accept(ModelVisitor* const visitor) const override {
6130  visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6131  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6132  expr_);
6133  visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6134  target_var_);
6135  visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6136  }
6137 
6138  private:
6139  IntExpr* const expr_;
6140 };
6141 
6142 // ----- Conditional Expression -----
6143 
6144 class ExprWithEscapeValue : public BaseIntExpr {
6145  public:
6146  ExprWithEscapeValue(Solver* const s, IntVar* const c, IntExpr* const e,
6147  int64 unperformed_value)
6148  : BaseIntExpr(s),
6149  condition_(c),
6150  expression_(e),
6151  unperformed_value_(unperformed_value) {}
6152 
6153  ~ExprWithEscapeValue() override {}
6154 
6155  int64 Min() const override {
6156  if (condition_->Min() == 1) {
6157  return expression_->Min();
6158  } else if (condition_->Max() == 1) {
6159  return std::min(unperformed_value_, expression_->Min());
6160  } else {
6161  return unperformed_value_;
6162  }
6163  }
6164 
6165  void SetMin(int64 m) override {
6166  if (m > unperformed_value_) {
6167  condition_->SetValue(1);
6168  expression_->SetMin(m);
6169  } else if (condition_->Min() == 1) {
6170  expression_->SetMin(m);
6171  } else if (m > expression_->Max()) {
6172  condition_->SetValue(0);
6173  }
6174  }
6175 
6176  int64 Max() const override {
6177  if (condition_->Min() == 1) {
6178  return expression_->Max();
6179  } else if (condition_->Max() == 1) {
6180  return std::max(unperformed_value_, expression_->Max());
6181  } else {
6182  return unperformed_value_;
6183  }
6184  }
6185 
6186  void SetMax(int64 m) override {
6187  if (m < unperformed_value_) {
6188  condition_->SetValue(1);
6189  expression_->SetMax(m);
6190  } else if (condition_->Min() == 1) {
6191  expression_->SetMax(m);
6192  } else if (m < expression_->Min()) {
6193  condition_->SetValue(0);
6194  }
6195  }
6196 
6197  void SetRange(int64 mi, int64 ma) override {
6198  if (ma < unperformed_value_ || mi > unperformed_value_) {
6199  condition_->SetValue(1);
6200  expression_->SetRange(mi, ma);
6201  } else if (condition_->Min() == 1) {
6202  expression_->SetRange(mi, ma);
6203  } else if (ma < expression_->Min() || mi > expression_->Max()) {
6204  condition_->SetValue(0);
6205  }
6206  }
6207 
6208  void SetValue(int64 v) override {
6209  if (v != unperformed_value_) {
6210  condition_->SetValue(1);
6211  expression_->SetValue(v);
6212  } else if (condition_->Min() == 1) {
6213  expression_->SetValue(v);
6214  } else if (v < expression_->Min() || v > expression_->Max()) {
6215  condition_->SetValue(0);
6216  }
6217  }
6218 
6219  bool Bound() const override {
6220  return condition_->Max() == 0 || expression_->Bound();
6221  }
6222 
6223  void WhenRange(Demon* d) override {
6224  expression_->WhenRange(d);
6225  condition_->WhenBound(d);
6226  }
6227 
6228  std::string DebugString() const override {
6229  return absl::StrFormat("ConditionExpr(%s, %s, %d)",
6230  condition_->DebugString(),
6231  expression_->DebugString(), unperformed_value_);
6232  }
6233 
6234  void Accept(ModelVisitor* const visitor) const override {
6235  visitor->BeginVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6236  visitor->VisitIntegerExpressionArgument(ModelVisitor::kVariableArgument,
6237  condition_);
6238  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6239  expression_);
6240  visitor->VisitIntegerArgument(ModelVisitor::kValueArgument,
6241  unperformed_value_);
6242  visitor->EndVisitIntegerExpression(ModelVisitor::kConditionalExpr, this);
6243  }
6244 
6245  private:
6246  IntVar* const condition_;
6247  IntExpr* const expression_;
6248  const int64 unperformed_value_;
6249  DISALLOW_COPY_AND_ASSIGN(ExprWithEscapeValue);
6250 };
6251 
6252 // ----- This is a specialized case when the variable exact type is known -----
6253 class LinkExprAndDomainIntVar : public CastConstraint {
6254  public:
6255  LinkExprAndDomainIntVar(Solver* const s, IntExpr* const expr,
6256  DomainIntVar* const var)
6257  : CastConstraint(s, var),
6258  expr_(expr),
6259  cached_min_(kint64min),
6260  cached_max_(kint64max),
6261  fail_stamp_(uint64_t{0}) {}
6262 
6263  ~LinkExprAndDomainIntVar() override {}
6264 
6265  DomainIntVar* var() const {
6266  return reinterpret_cast<DomainIntVar*>(target_var_);
6267  }
6268 
6269  void Post() override {
6270  Solver* const s = solver();
6271  Demon* const d = s->MakeConstraintInitialPropagateCallback(this);
6272  expr_->WhenRange(d);
6273  Demon* const target_var_demon = MakeConstraintDemon0(
6274  solver(), this, &LinkExprAndDomainIntVar::Propagate, "Propagate");
6275  target_var_->WhenRange(target_var_demon);
6276  }
6277 
6278  void InitialPropagate() override {
6279  expr_->SetRange(var()->min_.Value(), var()->max_.Value());
6280  expr_->Range(&cached_min_, &cached_max_);
6281  var()->DomainIntVar::SetRange(cached_min_, cached_max_);
6282  }
6283 
6284  void Propagate() {
6285  if (var()->min_.Value() > cached_min_ ||
6286  var()->max_.Value() < cached_max_ ||
6287  solver()->fail_stamp() != fail_stamp_) {
6288  InitialPropagate();
6289  fail_stamp_ = solver()->fail_stamp();
6290  }
6291  }
6292 
6293  std::string DebugString() const override {
6294  return absl::StrFormat("cast(%s, %s)", expr_->DebugString(),
6295  target_var_->DebugString());
6296  }
6297 
6298  void Accept(ModelVisitor* const visitor) const override {
6299  visitor->BeginVisitConstraint(ModelVisitor::kLinkExprVar, this);
6300  visitor->VisitIntegerExpressionArgument(ModelVisitor::kExpressionArgument,
6301  expr_);
6302  visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
6303  target_var_);
6304  visitor->EndVisitConstraint(ModelVisitor::kLinkExprVar, this);
6305  }
6306 
6307  private:
6308  IntExpr* const expr_;
6309  int64 cached_min_;
6310  int64 cached_max_;
6311  uint64 fail_stamp_;
6312 };
6313 } // namespace
6314 
6315 // ----- Misc -----
6316 
6317 IntVarIterator* BooleanVar::MakeHoleIterator(bool reversible) const {
6318  return COND_REV_ALLOC(reversible, new EmptyIterator());
6319 }
6320 IntVarIterator* BooleanVar::MakeDomainIterator(bool reversible) const {
6321  return COND_REV_ALLOC(reversible, new RangeIterator(this));
6322 }
6323 
6324 // ----- API -----
6325 
6327  DCHECK_EQ(DOMAIN_INT_VAR, var->VarType());
6328  DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6329  dvar->CleanInProcess();
6330 }
6331 
6332 Constraint* SetIsEqual(IntVar* const var, const std::vector<int64>& values,
6333  const std::vector<IntVar*>& vars) {
6334  DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6335  CHECK(dvar != nullptr);
6336  return dvar->SetIsEqual(values, vars);
6337 }
6338 
6340  const std::vector<int64>& values,
6341  const std::vector<IntVar*>& vars) {
6342  DomainIntVar* const dvar = reinterpret_cast<DomainIntVar*>(var);
6343  CHECK(dvar != nullptr);
6344  return dvar->SetIsGreaterOrEqual(values, vars);
6345 }
6346 
6348  DCHECK_EQ(BOOLEAN_VAR, var->VarType());
6349  BooleanVar* const boolean_var = reinterpret_cast<BooleanVar*>(var);
6350  boolean_var->RestoreValue();
6351 }
6352 
6353 // ----- API -----
6354 
6355 IntVar* Solver::MakeIntVar(int64 min, int64 max, const std::string& name) {
6356  if (min == max) {
6357  return MakeIntConst(min, name);
6358  }
6359  if (min == 0 && max == 1) {
6360  return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6361  } else if (CapSub(max, min) == 1) {
6362  const std::string inner_name = "inner_" + name;
6363  return RegisterIntVar(
6364  MakeSum(RevAlloc(new ConcreteBooleanVar(this, inner_name)), min)
6365  ->VarWithName(name));
6366  } else {
6367  return RegisterIntVar(RevAlloc(new DomainIntVar(this, min, max, name)));
6368  }
6369 }
6370 
6371 IntVar* Solver::MakeIntVar(int64 min, int64 max) {
6372  return MakeIntVar(min, max, "");
6373 }
6374 
6375 IntVar* Solver::MakeBoolVar(const std::string& name) {
6376  return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, name)));
6377 }
6378 
6379 IntVar* Solver::MakeBoolVar() {
6380  return RegisterIntVar(RevAlloc(new ConcreteBooleanVar(this, "")));
6381 }
6382 
6383 IntVar* Solver::MakeIntVar(const std::vector<int64>& values,
6384  const std::string& name) {
6385  DCHECK(!values.empty());
6386  // Fast-track the case where we have a single value.
6387  if (values.size() == 1) return MakeIntConst(values[0], name);
6388  // Sort and remove duplicates.
6389  std::vector<int64> unique_sorted_values = values;
6390  gtl::STLSortAndRemoveDuplicates(&unique_sorted_values);
6391  // Case when we have a single value, after clean-up.
6392  if (unique_sorted_values.size() == 1) return MakeIntConst(values[0], name);
6393  // Case when the values are a dense interval of integers.
6394  if (unique_sorted_values.size() ==
6395  unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6396  return MakeIntVar(unique_sorted_values.front(), unique_sorted_values.back(),
6397  name);
6398  }
6399  // Compute the GCD: if it's not 1, we can express the variable's domain as
6400  // the product of the GCD and of a domain with smaller values.
6401  int64 gcd = 0;
6402  for (const int64 v : unique_sorted_values) {
6403  if (gcd == 0) {
6404  gcd = std::abs(v);
6405  } else {
6406  gcd = MathUtil::GCD64(gcd, std::abs(v)); // Supports v==0.
6407  }
6408  if (gcd == 1) {
6409  // If it's 1, though, we can't do anything special, so we
6410  // immediately return a new DomainIntVar.
6411  return RegisterIntVar(
6412  RevAlloc(new DomainIntVar(this, unique_sorted_values, name)));
6413  }
6414  }
6415  DCHECK_GT(gcd, 1);
6416  for (int64& v : unique_sorted_values) {
6417  DCHECK_EQ(0, v % gcd);
6418  v /= gcd;
6419  }
6420  const std::string new_name = name.empty() ? "" : "inner_" + name;
6421  // Catch the case where the divided values are a dense set of integers.
6422  IntVar* inner_intvar = nullptr;
6423  if (unique_sorted_values.size() ==
6424  unique_sorted_values.back() - unique_sorted_values.front() + 1) {
6425  inner_intvar = MakeIntVar(unique_sorted_values.front(),
6426  unique_sorted_values.back(), new_name);
6427  } else {
6428  inner_intvar = RegisterIntVar(
6429  RevAlloc(new DomainIntVar(this, unique_sorted_values, new_name)));
6430  }
6431  return MakeProd(inner_intvar, gcd)->Var();
6432 }
6433 
6434 IntVar* Solver::MakeIntVar(const std::vector<int64>& values) {
6435  return MakeIntVar(values, "");
6436 }
6437 
6438 IntVar* Solver::MakeIntVar(const std::vector<int>& values,
6439  const std::string& name) {
6440  return MakeIntVar(ToInt64Vector(values), name);
6441 }
6442 
6443 IntVar* Solver::MakeIntVar(const std::vector<int>& values) {
6444  return MakeIntVar(values, "");
6445 }
6446 
6447 IntVar* Solver::MakeIntConst(int64 val, const std::string& name) {
6448  // If IntConst is going to be named after its creation,
6449  // cp_share_int_consts should be set to false otherwise names can potentially
6450  // be overwritten.
6451  if (absl::GetFlag(FLAGS_cp_share_int_consts) && name.empty() &&
6452  val >= MIN_CACHED_INT_CONST && val <= MAX_CACHED_INT_CONST) {
6453  return cached_constants_[val - MIN_CACHED_INT_CONST];
6454  }
6455  return RevAlloc(new IntConst(this, val, name));
6456 }
6457 
6458 IntVar* Solver::MakeIntConst(int64 val) { return MakeIntConst(val, ""); }
6459 
6460 // ----- Int Var and associated methods -----
6461 
namespace {
// Builds the name of the index-th element of a variable array: `prefix`
// immediately followed by the decimal digits of `index` ("x", 3 -> "x3").
// `max_index` does not affect the result; no zero-padding is applied.
std::string IndexedName(const std::string& prefix, int index, int max_index) {
  std::string indexed = prefix;
  indexed += std::to_string(index);
  return indexed;
}
}  // namespace
6478 
6479 void Solver::MakeIntVarArray(int var_count, int64 vmin, int64 vmax,
6480  const std::string& name,
6481  std::vector<IntVar*>* vars) {
6482  for (int i = 0; i < var_count; ++i) {
6483  vars->push_back(MakeIntVar(vmin, vmax, IndexedName(name, i, var_count)));
6484  }
6485 }
6486 
6487 void Solver::MakeIntVarArray(int var_count, int64 vmin, int64 vmax,
6488  std::vector<IntVar*>* vars) {
6489  for (int i = 0; i < var_count; ++i) {
6490  vars->push_back(MakeIntVar(vmin, vmax));
6491  }
6492 }
6493 
6494 IntVar** Solver::MakeIntVarArray(int var_count, int64 vmin, int64 vmax,
6495  const std::string& name) {
6496  IntVar** vars = new IntVar*[var_count];
6497  for (int i = 0; i < var_count; ++i) {
6498  vars[i] = MakeIntVar(vmin, vmax, IndexedName(name, i, var_count));
6499  }
6500  return vars;
6501 }
6502 
6503 void Solver::MakeBoolVarArray(int var_count, const std::string& name,
6504  std::vector<IntVar*>* vars) {
6505  for (int i = 0; i < var_count; ++i) {
6506  vars->push_back(MakeBoolVar(IndexedName(name, i, var_count)));
6507  }
6508 }
6509 
6510 void Solver::MakeBoolVarArray(int var_count, std::vector<IntVar*>* vars) {
6511  for (int i = 0; i < var_count; ++i) {
6512  vars->push_back(MakeBoolVar());
6513  }
6514 }
6515 
6516 IntVar** Solver::MakeBoolVarArray(int var_count, const std::string& name) {
6517  IntVar** vars = new IntVar*[var_count];
6518  for (int i = 0; i < var_count; ++i) {
6519  vars[i] = MakeBoolVar(IndexedName(name, i, var_count));
6520  }
6521  return vars;
6522 }
6523 
6524 void Solver::InitCachedIntConstants() {
6525  for (int i = MIN_CACHED_INT_CONST; i <= MAX_CACHED_INT_CONST; ++i) {
6526  cached_constants_[i - MIN_CACHED_INT_CONST] =
6527  RevAlloc(new IntConst(this, i, "")); // note the empty name
6528  }
6529 }
6530 
6531 IntExpr* Solver::MakeSum(IntExpr* const left, IntExpr* const right) {
6532  CHECK_EQ(this, left->solver());
6533  CHECK_EQ(this, right->solver());
6534  if (right->Bound()) {
6535  return MakeSum(left, right->Min());
6536  }
6537  if (left->Bound()) {
6538  return MakeSum(right, left->Min());
6539  }
6540  if (left == right) {
6541  return MakeProd(left, 2);
6542  }
6543  IntExpr* cache = model_cache_->FindExprExprExpression(
6544  left, right, ModelCache::EXPR_EXPR_SUM);
6545  if (cache == nullptr) {
6546  cache = model_cache_->FindExprExprExpression(right, left,
6547  ModelCache::EXPR_EXPR_SUM);
6548  }
6549  if (cache != nullptr) {
6550  return cache;
6551  } else {
6552  IntExpr* const result =
6553  AddOverflows(left->Max(), right->Max()) ||
6554  AddOverflows(left->Min(), right->Min())
6555  ? RegisterIntExpr(RevAlloc(new SafePlusIntExpr(this, left, right)))
6556  : RegisterIntExpr(RevAlloc(new PlusIntExpr(this, left, right)));
6557  model_cache_->InsertExprExprExpression(result, left, right,
6558  ModelCache::EXPR_EXPR_SUM);
6559  return result;
6560  }
6561 }
6562 
6563 IntExpr* Solver::MakeSum(IntExpr* const expr, int64 value) {
6564  CHECK_EQ(this, expr->solver());
6565  if (expr->Bound()) {
6566  return MakeIntConst(expr->Min() + value);
6567  }
6568  if (value == 0) {
6569  return expr;
6570  }
6571  IntExpr* result = Cache()->FindExprConstantExpression(
6572  expr, value, ModelCache::EXPR_CONSTANT_SUM);
6573  if (result == nullptr) {
6574  if (expr->IsVar() && !AddOverflows(value, expr->Max()) &&
6575  !AddOverflows(value, expr->Min())) {
6576  IntVar* const var = expr->Var();
6577  switch (var->VarType()) {
6578  case DOMAIN_INT_VAR: {
6579  result = RegisterIntExpr(RevAlloc(new PlusCstDomainIntVar(
6580  this, reinterpret_cast<DomainIntVar*>(var), value)));
6581  break;
6582  }
6583  case CONST_VAR: {
6584  result = RegisterIntExpr(MakeIntConst(var->Min() + value));
6585  break;
6586  }
6587  case VAR_ADD_CST: {
6588  PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6589  IntVar* const sub_var = add_var->SubVar();
6590  const int64 new_constant = value + add_var->Constant();
6591  if (new_constant == 0) {
6592  result = sub_var;
6593  } else {
6594  if (sub_var->VarType() == DOMAIN_INT_VAR) {
6595  DomainIntVar* const dvar =
6596  reinterpret_cast<DomainIntVar*>(sub_var);
6597  result = RegisterIntExpr(
6598  RevAlloc(new PlusCstDomainIntVar(this, dvar, new_constant)));
6599  } else {
6600  result = RegisterIntExpr(
6601  RevAlloc(new PlusCstIntVar(this, sub_var, new_constant)));
6602  }
6603  }
6604  break;
6605  }
6606  case CST_SUB_VAR: {
6607  SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6608  IntVar* const sub_var = add_var->SubVar();
6609  const int64 new_constant = value + add_var->Constant();
6610  result = RegisterIntExpr(
6611  RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6612  break;
6613  }
6614  case OPP_VAR: {
6615  OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6616  IntVar* const sub_var = add_var->SubVar();
6617  result =
6618  RegisterIntExpr(RevAlloc(new SubCstIntVar(this, sub_var, value)));
6619  break;
6620  }
6621  default:
6622  result =
6623  RegisterIntExpr(RevAlloc(new PlusCstIntVar(this, var, value)));
6624  }
6625  } else {
6626  result = RegisterIntExpr(RevAlloc(new PlusIntCstExpr(this, expr, value)));
6627  }
6628  Cache()->InsertExprConstantExpression(result, expr, value,
6629  ModelCache::EXPR_CONSTANT_SUM);
6630  }
6631  return result;
6632 }
6633 
6634 IntExpr* Solver::MakeDifference(IntExpr* const left, IntExpr* const right) {
6635  CHECK_EQ(this, left->solver());
6636  CHECK_EQ(this, right->solver());
6637  if (left->Bound()) {
6638  return MakeDifference(left->Min(), right);
6639  }
6640  if (right->Bound()) {
6641  return MakeSum(left, -right->Min());
6642  }
6643  IntExpr* sub_left = nullptr;
6644  IntExpr* sub_right = nullptr;
6645  int64 left_coef = 1;
6646  int64 right_coef = 1;
6647  if (IsProduct(left, &sub_left, &left_coef) &&
6648  IsProduct(right, &sub_right, &right_coef)) {
6649  const int64 abs_gcd =
6650  MathUtil::GCD64(std::abs(left_coef), std::abs(right_coef));
6651  if (abs_gcd != 0 && abs_gcd != 1) {
6652  return MakeProd(MakeDifference(MakeProd(sub_left, left_coef / abs_gcd),
6653  MakeProd(sub_right, right_coef / abs_gcd)),
6654  abs_gcd);
6655  }
6656  }
6657 
6658  IntExpr* result = Cache()->FindExprExprExpression(
6659  left, right, ModelCache::EXPR_EXPR_DIFFERENCE);
6660  if (result == nullptr) {
6661  if (!SubOverflows(left->Min(), right->Max()) &&
6662  !SubOverflows(left->Max(), right->Min())) {
6663  result = RegisterIntExpr(RevAlloc(new SubIntExpr(this, left, right)));
6664  } else {
6665  result = RegisterIntExpr(RevAlloc(new SafeSubIntExpr(this, left, right)));
6666  }
6667  Cache()->InsertExprExprExpression(result, left, right,
6668  ModelCache::EXPR_EXPR_DIFFERENCE);
6669  }
6670  return result;
6671 }
6672 
6673 // warning: this is 'value - expr'.
6674 IntExpr* Solver::MakeDifference(int64 value, IntExpr* const expr) {
6675  CHECK_EQ(this, expr->solver());
6676  if (expr->Bound()) {
6677  return MakeIntConst(value - expr->Min());
6678  }
6679  if (value == 0) {
6680  return MakeOpposite(expr);
6681  }
6682  IntExpr* result = Cache()->FindExprConstantExpression(
6683  expr, value, ModelCache::EXPR_CONSTANT_DIFFERENCE);
6684  if (result == nullptr) {
6685  if (expr->IsVar() && expr->Min() != kint64min &&
6686  !SubOverflows(value, expr->Min()) &&
6687  !SubOverflows(value, expr->Max())) {
6688  IntVar* const var = expr->Var();
6689  switch (var->VarType()) {
6690  case VAR_ADD_CST: {
6691  PlusCstVar* const add_var = reinterpret_cast<PlusCstVar*>(var);
6692  IntVar* const sub_var = add_var->SubVar();
6693  const int64 new_constant = value - add_var->Constant();
6694  if (new_constant == 0) {
6695  result = sub_var;
6696  } else {
6697  result = RegisterIntExpr(
6698  RevAlloc(new SubCstIntVar(this, sub_var, new_constant)));
6699  }
6700  break;
6701  }
6702  case CST_SUB_VAR: {
6703  SubCstIntVar* const add_var = reinterpret_cast<SubCstIntVar*>(var);
6704  IntVar* const sub_var = add_var->SubVar();
6705  const int64 new_constant = value - add_var->Constant();
6706  result = MakeSum(sub_var, new_constant);
6707  break;
6708  }
6709  case OPP_VAR: {
6710  OppIntVar* const add_var = reinterpret_cast<OppIntVar*>(var);
6711  IntVar* const sub_var = add_var->SubVar();
6712  result = MakeSum(sub_var, value);
6713  break;
6714  }
6715  default:
6716  result =
6717  RegisterIntExpr(RevAlloc(new SubCstIntVar(this, var, value)));
6718  }
6719  } else {
6720  result = RegisterIntExpr(RevAlloc(new SubIntCstExpr(this, expr, value)));
6721  }
6722  Cache()->InsertExprConstantExpression(result, expr, value,
6723  ModelCache::EXPR_CONSTANT_DIFFERENCE);
6724  }
6725  return result;
6726 }
6727 
6728 IntExpr* Solver::MakeOpposite(IntExpr* const expr) {
6729  CHECK_EQ(this, expr->solver());
6730  if (expr->Bound()) {
6731  return MakeIntConst(-expr->Min());
6732  }
6733  IntExpr* result =
6734  Cache()->FindExprExpression(expr, ModelCache::EXPR_OPPOSITE);
6735  if (result == nullptr) {
6736  if (expr->IsVar()) {
6737  result = RegisterIntVar(RevAlloc(new OppIntExpr(this, expr))->Var());
6738  } else {
6739  result = RegisterIntExpr(RevAlloc(new OppIntExpr(this, expr)));
6740  }
6741  Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_OPPOSITE);
6742  }
6743  return result;
6744 }
6745 
6746 IntExpr* Solver::MakeProd(IntExpr* const expr, int64 value) {
6747  CHECK_EQ(this, expr->solver());
6748  IntExpr* result = Cache()->FindExprConstantExpression(
6749  expr, value, ModelCache::EXPR_CONSTANT_PROD);
6750  if (result != nullptr) {
6751  return result;
6752  } else {
6753  IntExpr* m_expr = nullptr;
6754  int64 coefficient = 1;
6755  if (IsProduct(expr, &m_expr, &coefficient)) {
6756  coefficient *= value;
6757  } else {
6758  m_expr = expr;
6759  coefficient = value;
6760  }
6761  if (m_expr->Bound()) {
6762  return MakeIntConst(coefficient * m_expr->Min());
6763  } else if (coefficient == 1) {
6764  return m_expr;
6765  } else if (coefficient == -1) {
6766  return MakeOpposite(m_expr);
6767  } else if (coefficient > 0) {
6768  if (m_expr->Max() > kint64max / coefficient ||
6769  m_expr->Min() < kint64min / coefficient) {
6770  result = RegisterIntExpr(
6771  RevAlloc(new SafeTimesPosIntCstExpr(this, m_expr, coefficient)));
6772  } else {
6773  result = RegisterIntExpr(
6774  RevAlloc(new TimesPosIntCstExpr(this, m_expr, coefficient)));
6775  }
6776  } else if (coefficient == 0) {
6777  result = MakeIntConst(0);
6778  } else { // coefficient < 0.
6779  result = RegisterIntExpr(
6780  RevAlloc(new TimesIntNegCstExpr(this, m_expr, coefficient)));
6781  }
6782  if (m_expr->IsVar() &&
6783  !absl::GetFlag(FLAGS_cp_disable_expression_optimization)) {
6784  result = result->Var();
6785  }
6786  Cache()->InsertExprConstantExpression(result, expr, value,
6787  ModelCache::EXPR_CONSTANT_PROD);
6788  return result;
6789  }
6790 }
6791 
6792 namespace {
6793 void ExtractPower(IntExpr** const expr, int64* const exponant) {
6794  if (dynamic_cast<BasePower*>(*expr) != nullptr) {
6795  BasePower* const power = dynamic_cast<BasePower*>(*expr);
6796  *expr = power->expr();
6797  *exponant = power->exponant();
6798  }
6799  if (dynamic_cast<IntSquare*>(*expr) != nullptr) {
6800  IntSquare* const power = dynamic_cast<IntSquare*>(*expr);
6801  *expr = power->expr();
6802  *exponant = 2;
6803  }
6804  if ((*expr)->IsVar()) {
6805  IntVar* const var = (*expr)->Var();
6806  IntExpr* const sub = var->solver()->CastExpression(var);
6807  if (sub != nullptr && dynamic_cast<BasePower*>(sub) != nullptr) {
6808  BasePower* const power = dynamic_cast<BasePower*>(sub);
6809  *expr = power->expr();
6810  *exponant = power->exponant();
6811  }
6812  if (sub != nullptr && dynamic_cast<IntSquare*>(sub) != nullptr) {
6813  IntSquare* const power = dynamic_cast<IntSquare*>(sub);
6814  *expr = power->expr();
6815  *exponant = 2;
6816  }
6817  }
6818 }
6819 
6820 void ExtractProduct(IntExpr** const expr, int64* const coefficient,
6821  bool* modified) {
6822  if (dynamic_cast<TimesCstIntVar*>(*expr) != nullptr) {
6823  TimesCstIntVar* const left_prod = dynamic_cast<TimesCstIntVar*>(*expr);
6824  *coefficient *= left_prod->Constant();
6825  *expr = left_prod->SubVar();
6826  *modified = true;
6827  } else if (dynamic_cast<TimesIntCstExpr*>(*expr) != nullptr) {
6828  TimesIntCstExpr* const left_prod = dynamic_cast<TimesIntCstExpr*>(*expr);
6829  *coefficient *= left_prod->Constant();
6830  *expr = left_prod->Expr();
6831  *modified = true;
6832  }
6833 }
6834 } // namespace
6835 
6836 IntExpr* Solver::MakeProd(IntExpr* const left, IntExpr* const right) {
6837  if (left->Bound()) {
6838  return MakeProd(right, left->Min());
6839  }
6840 
6841  if (right->Bound()) {
6842  return MakeProd(left, right->Min());
6843  }
6844 
6845  // ----- Discover squares and powers -----
6846 
6847  IntExpr* m_left = left;
6848  IntExpr* m_right = right;
6849  int64 left_exponant = 1;
6850  int64 right_exponant = 1;
6851  ExtractPower(&m_left, &left_exponant);
6852  ExtractPower(&m_right, &right_exponant);
6853 
6854  if (m_left == m_right) {
6855  return MakePower(m_left, left_exponant + right_exponant);
6856  }
6857 
6858  // ----- Discover nested products -----
6859 
6860  m_left = left;
6861  m_right = right;
6862  int64 coefficient = 1;
6863  bool modified = false;
6864 
6865  ExtractProduct(&m_left, &coefficient, &modified);
6866  ExtractProduct(&m_right, &coefficient, &modified);
6867  if (modified) {
6868  return MakeProd(MakeProd(m_left, m_right), coefficient);
6869  }
6870 
6871  // ----- Standard build -----
6872 
6873  CHECK_EQ(this, left->solver());
6874  CHECK_EQ(this, right->solver());
6875  IntExpr* result = model_cache_->FindExprExprExpression(
6876  left, right, ModelCache::EXPR_EXPR_PROD);
6877  if (result == nullptr) {
6878  result = model_cache_->FindExprExprExpression(right, left,
6879  ModelCache::EXPR_EXPR_PROD);
6880  }
6881  if (result != nullptr) {
6882  return result;
6883  }
6884  if (left->IsVar() && left->Var()->VarType() == BOOLEAN_VAR) {
6885  if (right->Min() >= 0) {
6886  result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6887  this, reinterpret_cast<BooleanVar*>(left), right)));
6888  } else {
6889  result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6890  this, reinterpret_cast<BooleanVar*>(left), right)));
6891  }
6892  } else if (right->IsVar() &&
6893  reinterpret_cast<IntVar*>(right)->VarType() == BOOLEAN_VAR) {
6894  if (left->Min() >= 0) {
6895  result = RegisterIntExpr(RevAlloc(new TimesBooleanPosIntExpr(
6896  this, reinterpret_cast<BooleanVar*>(right), left)));
6897  } else {
6898  result = RegisterIntExpr(RevAlloc(new TimesBooleanIntExpr(
6899  this, reinterpret_cast<BooleanVar*>(right), left)));
6900  }
6901  } else if (left->Min() >= 0 && right->Min() >= 0) {
6902  if (CapProd(left->Max(), right->Max()) ==
6903  kint64max) { // Potential overflow.
6904  result =
6905  RegisterIntExpr(RevAlloc(new SafeTimesPosIntExpr(this, left, right)));
6906  } else {
6907  result =
6908  RegisterIntExpr(RevAlloc(new TimesPosIntExpr(this, left, right)));
6909  }
6910  } else {
6911  result = RegisterIntExpr(RevAlloc(new TimesIntExpr(this, left, right)));
6912  }
6913  model_cache_->InsertExprExprExpression(result, left, right,
6914  ModelCache::EXPR_EXPR_PROD);
6915  return result;
6916 }
6917 
6918 IntExpr* Solver::MakeDiv(IntExpr* const numerator, IntExpr* const denominator) {
6919  CHECK(numerator != nullptr);
6920  CHECK(denominator != nullptr);
6921  if (denominator->Bound()) {
6922  return MakeDiv(numerator, denominator->Min());
6923  }
6924  IntExpr* result = model_cache_->FindExprExprExpression(
6925  numerator, denominator, ModelCache::EXPR_EXPR_DIV);
6926  if (result != nullptr) {
6927  return result;
6928  }
6929 
6930  if (denominator->Min() <= 0 && denominator->Max() >= 0) {
6931  AddConstraint(MakeNonEquality(denominator, 0));
6932  }
6933 
6934  if (denominator->Min() >= 0) {
6935  if (numerator->Min() >= 0) {
6936  result = RevAlloc(new DivPosPosIntExpr(this, numerator, denominator));
6937  } else {
6938  result = RevAlloc(new DivPosIntExpr(this, numerator, denominator));
6939  }
6940  } else if (denominator->Max() <= 0) {
6941  if (numerator->Max() <= 0) {
6942  result = RevAlloc(new DivPosPosIntExpr(this, MakeOpposite(numerator),
6943  MakeOpposite(denominator)));
6944  } else {
6945  result = MakeOpposite(RevAlloc(
6946  new DivPosIntExpr(this, numerator, MakeOpposite(denominator))));
6947  }
6948  } else {
6949  result = RevAlloc(new DivIntExpr(this, numerator, denominator));
6950  }
6951  model_cache_->InsertExprExprExpression(result, numerator, denominator,
6952  ModelCache::EXPR_EXPR_DIV);
6953  return result;
6954 }
6955 
6956 IntExpr* Solver::MakeDiv(IntExpr* const expr, int64 value) {
6957  CHECK(expr != nullptr);
6958  CHECK_EQ(this, expr->solver());
6959  if (expr->Bound()) {
6960  return MakeIntConst(expr->Min() / value);
6961  } else if (value == 1) {
6962  return expr;
6963  } else if (value == -1) {
6964  return MakeOpposite(expr);
6965  } else if (value > 0) {
6966  return RegisterIntExpr(RevAlloc(new DivPosIntCstExpr(this, expr, value)));
6967  } else if (value == 0) {
6968  LOG(FATAL) << "Cannot divide by 0";
6969  return nullptr;
6970  } else {
6971  return RegisterIntExpr(
6972  MakeOpposite(RevAlloc(new DivPosIntCstExpr(this, expr, -value))));
6973  // TODO(user) : implement special case.
6974  }
6975 }
6976 
6977 Constraint* Solver::MakeAbsEquality(IntVar* const var, IntVar* const abs_var) {
6978  if (Cache()->FindExprExpression(var, ModelCache::EXPR_ABS) == nullptr) {
6979  Cache()->InsertExprExpression(abs_var, var, ModelCache::EXPR_ABS);
6980  }
6981  return RevAlloc(new IntAbsConstraint(this, var, abs_var));
6982 }
6983 
6984 IntExpr* Solver::MakeAbs(IntExpr* const e) {
6985  CHECK_EQ(this, e->solver());
6986  if (e->Min() >= 0) {
6987  return e;
6988  } else if (e->Max() <= 0) {
6989  return MakeOpposite(e);
6990  }
6991  IntExpr* result = Cache()->FindExprExpression(e, ModelCache::EXPR_ABS);
6992  if (result == nullptr) {
6993  int64 coefficient = 1;
6994  IntExpr* expr = nullptr;
6995  if (IsProduct(e, &expr, &coefficient)) {
6996  result = MakeProd(MakeAbs(expr), std::abs(coefficient));
6997  } else {
6998  result = RegisterIntExpr(RevAlloc(new IntAbs(this, e)));
6999  }
7000  Cache()->InsertExprExpression(result, e, ModelCache::EXPR_ABS);
7001  }
7002  return result;
7003 }
7004 
7005 IntExpr* Solver::MakeSquare(IntExpr* const expr) {
7006  CHECK_EQ(this, expr->solver());
7007  if (expr->Bound()) {
7008  const int64 v = expr->Min();
7009  return MakeIntConst(v * v);
7010  }
7011  IntExpr* result = Cache()->FindExprExpression(expr, ModelCache::EXPR_SQUARE);
7012  if (result == nullptr) {
7013  if (expr->Min() >= 0) {
7014  result = RegisterIntExpr(RevAlloc(new PosIntSquare(this, expr)));
7015  } else {
7016  result = RegisterIntExpr(RevAlloc(new IntSquare(this, expr)));
7017  }
7018  Cache()->InsertExprExpression(result, expr, ModelCache::EXPR_SQUARE);
7019  }
7020  return result;
7021 }
7022 
7023 IntExpr* Solver::MakePower(IntExpr* const expr, int64 n) {
7024  CHECK_EQ(this, expr->solver());
7025  CHECK_GE(n, 0);
7026  if (expr->Bound()) {
7027  const int64 v = expr->Min();
7028  if (v >= OverflowLimit(n)) { // Overflow.
7029  return MakeIntConst(kint64max);
7030  }
7031  return MakeIntConst(IntPower(v, n));
7032  }
7033  switch (n) {
7034  case 0:
7035  return MakeIntConst(1);
7036  case 1:
7037  return expr;
7038  case 2:
7039  return MakeSquare(expr);
7040  default: {
7041  IntExpr* result = nullptr;
7042  if (n % 2 == 0) { // even.
7043  if (expr->Min() >= 0) {
7044  result =
7045  RegisterIntExpr(RevAlloc(new PosIntEvenPower(this, expr, n)));
7046  } else {
7047  result = RegisterIntExpr(RevAlloc(new IntEvenPower(this, expr, n)));
7048  }
7049  } else {
7050  result = RegisterIntExpr(RevAlloc(new IntOddPower(this, expr, n)));
7051  }
7052  return result;
7053  }
7054  }
7055 }
7056 
7057 IntExpr* Solver::MakeMin(IntExpr* const left, IntExpr* const right) {
7058  CHECK_EQ(this, left->solver());
7059  CHECK_EQ(this, right->solver());
7060  if (left->Bound()) {
7061  return MakeMin(right, left->Min());
7062  }
7063  if (right->Bound()) {
7064  return MakeMin(left, right->Min());
7065  }
7066  if (left->Min() >= right->Max()) {
7067  return right;
7068  }
7069  if (right->Min() >= left->Max()) {
7070  return left;
7071  }
7072  return RegisterIntExpr(RevAlloc(new MinIntExpr(this, left, right)));
7073 }
7074 
7075 IntExpr* Solver::MakeMin(IntExpr* const expr, int64 value) {
7076  CHECK_EQ(this, expr->solver());
7077  if (value <= expr->Min()) {
7078  return MakeIntConst(value);
7079  }
7080  if (expr->Bound()) {
7081  return MakeIntConst(std::min(expr->Min(), value));
7082  }
7083  if (expr->Max() <= value) {
7084  return expr;
7085  }
7086  return RegisterIntExpr(RevAlloc(new MinCstIntExpr(this, expr, value)));
7087 }
7088 
7089 IntExpr* Solver::MakeMin(IntExpr* const expr, int value) {
7090  return MakeMin(expr, static_cast<int64>(value));
7091 }
7092 
7093 IntExpr* Solver::MakeMax(IntExpr* const left, IntExpr* const right) {
7094  CHECK_EQ(this, left->solver());
7095  CHECK_EQ(this, right->solver());
7096  if (left->Bound()) {
7097  return MakeMax(right, left->Min());
7098  }
7099  if (right->Bound()) {
7100  return MakeMax(left, right->Min());
7101  }
7102  if (left->Min() >= right->Max()) {
7103  return left;
7104  }
7105  if (right->Min() >= left->Max()) {
7106  return right;
7107  }
7108  return RegisterIntExpr(RevAlloc(new MaxIntExpr(this, left, right)));
7109 }
7110 
7111 IntExpr* Solver::MakeMax(IntExpr* const expr, int64 value) {
7112  CHECK_EQ(this, expr->solver());
7113  if (expr->Bound()) {
7114  return MakeIntConst(std::max(expr->Min(), value));
7115  }
7116  if (value <= expr->Min()) {
7117  return expr;
7118  }
7119  if (expr->Max() <= value) {
7120  return MakeIntConst(value);
7121  }
7122  return RegisterIntExpr(RevAlloc(new MaxCstIntExpr(this, expr, value)));
7123 }
7124 
7125 IntExpr* Solver::MakeMax(IntExpr* const expr, int value) {
7126  return MakeMax(expr, static_cast<int64>(value));
7127 }
7128 
7129 IntExpr* Solver::MakeConvexPiecewiseExpr(IntExpr* expr, int64 early_cost,
7130  int64 early_date, int64 late_date,
7131  int64 late_cost) {
7132  return RegisterIntExpr(RevAlloc(new SimpleConvexPiecewiseExpr(
7133  this, expr, early_cost, early_date, late_date, late_cost)));
7134 }
7135 
7136 IntExpr* Solver::MakeSemiContinuousExpr(IntExpr* const expr, int64 fixed_charge,
7137  int64 step) {
7138  if (step == 0) {
7139  if (fixed_charge == 0) {
7140  return MakeIntConst(int64{0});
7141  } else {
7142  return RegisterIntExpr(
7143  RevAlloc(new SemiContinuousStepZeroExpr(this, expr, fixed_charge)));
7144  }
7145  } else if (step == 1) {
7146  return RegisterIntExpr(
7147  RevAlloc(new SemiContinuousStepOneExpr(this, expr, fixed_charge)));
7148  } else {
7149  return RegisterIntExpr(
7150  RevAlloc(new SemiContinuousExpr(this, expr, fixed_charge, step)));
7151  }
7152  // TODO(user) : benchmark with virtualization of
7153  // PosIntDivDown and PosIntDivUp - or function pointers.
7154 }
7155 
7156 // ----- Piecewise Linear -----
7157 
7159  public:
7161  const PiecewiseLinearFunction& f)
7162  : BaseIntExpr(solver), expr_(expr), f_(f) {}
7163  ~PiecewiseLinearExpr() override {}
7164  int64 Min() const override {
7165  return f_.GetMinimum(expr_->Min(), expr_->Max());
7166  }
7167  void SetMin(int64 m) override {
7168  const auto& range =
7169  f_.GetSmallestRangeGreaterThanValue(expr_->Min(), expr_->Max(), m);
7170  expr_->SetRange(range.first, range.second);
7171  }
7172 
7173  int64 Max() const override {
7174  return f_.GetMaximum(expr_->Min(), expr_->Max());
7175  }
7176 
7177  void SetMax(int64 m) override {
7178  const auto& range =
7179  f_.GetSmallestRangeLessThanValue(expr_->Min(), expr_->Max(), m);
7180  expr_->SetRange(range.first, range.second);
7181  }
7182 
7183  void SetRange(int64 l, int64 u) override {
7184  const auto& range =
7185  f_.GetSmallestRangeInValueRange(expr_->Min(), expr_->Max(), l, u);
7186  expr_->SetRange(range.first, range.second);
7187  }
7188  std::string name() const override {
7189  return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->name(),
7190  f_.DebugString());
7191  }
7192 
7193  std::string DebugString() const override {
7194  return absl::StrFormat("PiecewiseLinear(%s, f = %s)", expr_->DebugString(),
7195  f_.DebugString());
7196  }
7197 
7198  void WhenRange(Demon* d) override { expr_->WhenRange(d); }
7199 
7200  void Accept(ModelVisitor* const visitor) const override {
7201  // TODO(user): Implement visitor.
7202  }
7203 
7204  private:
7205  IntExpr* const expr_;
7206  const PiecewiseLinearFunction f_;
7207 };
7208 
7209 IntExpr* Solver::MakePiecewiseLinearExpr(IntExpr* expr,
7210  const PiecewiseLinearFunction& f) {
7211  return RegisterIntExpr(RevAlloc(new PiecewiseLinearExpr(this, expr, f)));
7212 }
7213 
7214 // ----- Conditional Expression -----
7215 
7216 IntExpr* Solver::MakeConditionalExpression(IntVar* const condition,
7217  IntExpr* const expr,
7218  int64 unperformed_value) {
7219  if (condition->Min() == 1) {
7220  return expr;
7221  } else if (condition->Max() == 0) {
7222  return MakeIntConst(unperformed_value);
7223  } else {
7224  IntExpr* cache = Cache()->FindExprExprConstantExpression(
7225  condition, expr, unperformed_value,
7226  ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7227  if (cache == nullptr) {
7228  cache = RevAlloc(
7229  new ExprWithEscapeValue(this, condition, expr, unperformed_value));
7230  Cache()->InsertExprExprConstantExpression(
7231  cache, condition, expr, unperformed_value,
7232  ModelCache::EXPR_EXPR_CONSTANT_CONDITIONAL);
7233  }
7234  return cache;
7235  }
7236 }
7237 
7238 // ----- Modulo -----
7239 
7240 IntExpr* Solver::MakeModulo(IntExpr* const x, int64 mod) {
7241  IntVar* const result =
7242  MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7243  if (mod >= 0) {
7244  AddConstraint(MakeBetweenCt(result, 0, mod - 1));
7245  } else {
7246  AddConstraint(MakeBetweenCt(result, mod + 1, 0));
7247  }
7248  return result;
7249 }
7250 
7251 IntExpr* Solver::MakeModulo(IntExpr* const x, IntExpr* const mod) {
7252  if (mod->Bound()) {
7253  return MakeModulo(x, mod->Min());
7254  }
7255  IntVar* const result =
7256  MakeDifference(x, MakeProd(MakeDiv(x, mod), mod))->Var();
7257  AddConstraint(MakeLess(result, MakeAbs(mod)));
7258  AddConstraint(MakeGreater(result, MakeOpposite(MakeAbs(mod))));
7259  return result;
7260 }
7261 
7262 // --------- IntVar ---------
7263 
7264 int IntVar::VarType() const { return UNSPECIFIED; }
7265 
7266 void IntVar::RemoveValues(const std::vector<int64>& values) {
7267  // TODO(user): Check and maybe inline this code.
7268  const int size = values.size();
7269  DCHECK_GE(size, 0);
7270  switch (size) {
7271  case 0: {
7272  return;
7273  }
7274  case 1: {
7275  RemoveValue(values[0]);
7276  return;
7277  }
7278  case 2: {
7279  RemoveValue(values[0]);
7280  RemoveValue(values[1]);
7281  return;
7282  }
7283  case 3: {
7284  RemoveValue(values[0]);
7285  RemoveValue(values[1]);
7286  RemoveValue(values[2]);
7287  return;
7288  }
7289  default: {
7290  // 4 values, let's start doing some more clever things.
7291  // TODO(user) : Sort values!
7292  int start_index = 0;
7293  int64 new_min = Min();
7294  if (values[start_index] <= new_min) {
7295  while (start_index < size - 1 &&
7296  values[start_index + 1] == values[start_index] + 1) {
7297  new_min = values[start_index + 1] + 1;
7298  start_index++;
7299  }
7300  }
7301  int end_index = size - 1;
7302  int64 new_max = Max();
7303  if (values[end_index] >= new_max) {
7304  while (end_index > start_index + 1 &&
7305  values[end_index - 1] == values[end_index] - 1) {
7306  new_max = values[end_index - 1] - 1;
7307  end_index--;
7308  }
7309  }
7310  SetRange(new_min, new_max);
7311  for (int i = start_index; i <= end_index; ++i) {
7312  RemoveValue(values[i]);
7313  }
7314  }
7315  }
7316 }
7317 
7318 void IntVar::Accept(ModelVisitor* const visitor) const {
7319  IntExpr* const casted = solver()->CastExpression(this);
7320  visitor->VisitIntegerVariable(this, casted);
7321 }
7322 
7323 void IntVar::SetValues(const std::vector<int64>& values) {
7324  switch (values.size()) {
7325  case 0: {
7326  solver()->Fail();
7327  break;
7328  }
7329  case 1: {
7330  SetValue(values.back());
7331  break;
7332  }
7333  case 2: {
7334  if (Contains(values[0])) {
7335  if (Contains(values[1])) {
7336  const int64 l = std::min(values[0], values[1]);
7337  const int64 u = std::max(values[0], values[1]);
7338  SetRange(l, u);
7339  if (u > l + 1) {
7340  RemoveInterval(l + 1, u - 1);
7341  }
7342  } else {
7343  SetValue(values[0]);
7344  }
7345  } else {
7346  SetValue(values[1]);
7347  }
7348  break;
7349  }
7350  default: {
7351  // TODO(user): use a clean and safe SortedUniqueCopy() class
7352  // that uses a global, static shared (and locked) storage.
7353  // TODO(user): [optional] consider porting
7354  // STLSortAndRemoveDuplicates from ortools/base/stl_util.h to the
7355  // existing open_source/base/stl_util.h and using it here.
7356  // TODO(user): We could filter out values not in the var.
7357  std::vector<int64>& tmp = solver()->tmp_vector_;
7358  tmp.clear();
7359  tmp.insert(tmp.end(), values.begin(), values.end());
7360  std::sort(tmp.begin(), tmp.end());
7361  tmp.erase(std::unique(tmp.begin(), tmp.end()), tmp.end());
7362  const int size = tmp.size();
7363  const int64 vmin = Min();
7364  const int64 vmax = Max();
7365  int first = 0;
7366  int last = size - 1;
7367  if (tmp.front() > vmax || tmp.back() < vmin) {
7368  solver()->Fail();
7369  }
7370  // TODO(user) : We could find the first position >= vmin by dichotomy.
7371  while (tmp[first] < vmin || !Contains(tmp[first])) {
7372  ++first;
7373  if (first > last || tmp[first] > vmax) {
7374  solver()->Fail();
7375  }
7376  }
7377  while (last > first && (tmp[last] > vmax || !Contains(tmp[last]))) {
7378  // Note that last >= first implies tmp[last] >= vmin.
7379  --last;
7380  }
7381  DCHECK_GE(last, first);
7382  SetRange(tmp[first], tmp[last]);
7383  while (first < last) {
7384  const int64 start = tmp[first] + 1;
7385  const int64 end = tmp[first + 1] - 1;
7386  if (start <= end) {
7387  RemoveInterval(start, end);
7388  }
7389  first++;
7390  }
7391  }
7392  }
7393 }
7394 // ---------- BaseIntExpr ---------
7395 
7396 void LinkVarExpr(Solver* const s, IntExpr* const expr, IntVar* const var) {
7397  if (!var->Bound()) {
7398  if (var->VarType() == DOMAIN_INT_VAR) {
7399  DomainIntVar* dvar = reinterpret_cast<DomainIntVar*>(var);
7400  s->AddCastConstraint(
7401  s->RevAlloc(new LinkExprAndDomainIntVar(s, expr, dvar)), dvar, expr);
7402  } else {
7403  s->AddCastConstraint(s->RevAlloc(new LinkExprAndVar(s, expr, var)), var,
7404  expr);
7405  }
7406  }
7407 }
7408 
7409 IntVar* BaseIntExpr::Var() {
7410  if (var_ == nullptr) {
7411  solver()->SaveValue(reinterpret_cast<void**>(&var_));
7412  var_ = CastToVar();
7413  }
7414  return var_;
7415 }
7416 
7417 IntVar* BaseIntExpr::CastToVar() {
7418  int64 vmin, vmax;
7419  Range(&vmin, &vmax);
7420  IntVar* const var = solver()->MakeIntVar(vmin, vmax);
7421  LinkVarExpr(solver(), this, var);
7422  return var;
7423 }
7424 
7425 // Discovery methods
7426 bool Solver::IsADifference(IntExpr* expr, IntExpr** const left,
7427  IntExpr** const right) {
7428  if (expr->IsVar()) {
7429  IntVar* const expr_var = expr->Var();
7430  expr = CastExpression(expr_var);
7431  }
7432  // This is a dynamic cast to check the type of expr.
7433  // It returns nullptr is expr is not a subclass of SubIntExpr.
7434  SubIntExpr* const sub_expr = dynamic_cast<SubIntExpr*>(expr);
7435  if (sub_expr != nullptr) {
7436  *left = sub_expr->left();
7437  *right = sub_expr->right();
7438  return true;
7439  }
7440  return false;
7441 }
7442 
7443 bool Solver::IsBooleanVar(IntExpr* const expr, IntVar** inner_var,
7444  bool* is_negated) const {
7445  if (expr->IsVar() && expr->Var()->VarType() == BOOLEAN_VAR) {
7446  *inner_var = expr->Var();
7447  *is_negated = false;
7448  return true;
7449  } else if (expr->IsVar() && expr->Var()->VarType() == CST_SUB_VAR) {
7450  SubCstIntVar* const sub_var = reinterpret_cast<SubCstIntVar*>(expr);
7451  if (sub_var != nullptr && sub_var->Constant() == 1 &&
7452  sub_var->SubVar()->VarType() == BOOLEAN_VAR) {
7453  *is_negated = true;
7454  *inner_var = sub_var->SubVar();
7455  return true;
7456  }
7457  }
7458  return false;
7459 }
7460 
7461 bool Solver::IsProduct(IntExpr* const expr, IntExpr** inner_expr,
7462  int64* coefficient) {
7463  if (dynamic_cast<TimesCstIntVar*>(expr) != nullptr) {
7464  TimesCstIntVar* const var = dynamic_cast<TimesCstIntVar*>(expr);
7465  *coefficient = var->Constant();
7466  *inner_expr = var->SubVar();
7467  return true;
7468  } else if (dynamic_cast<TimesIntCstExpr*>(expr) != nullptr) {
7469  TimesIntCstExpr* const prod = dynamic_cast<TimesIntCstExpr*>(expr);
7470  *coefficient = prod->Constant();
7471  *inner_expr = prod->Expr();
7472  return true;
7473  }
7474  *inner_expr = expr;
7475  *coefficient = 1;
7476  return false;
7477 }
7478 
7479 #undef COND_REV_ALLOC
7480 
7481 } // namespace operations_research
int64 min
Definition: alldiff_cst.cc:138
int64 max
Definition: alldiff_cst.cc:139
#define CHECK(condition)
Definition: base/logging.h:495
#define DCHECK_LE(val1, val2)
Definition: base/logging.h:887
#define DCHECK_NE(val1, val2)
Definition: base/logging.h:886
#define CHECK_LT(val1, val2)
Definition: base/logging.h:700
#define CHECK_EQ(val1, val2)
Definition: base/logging.h:697
#define CHECK_GE(val1, val2)
Definition: base/logging.h:701
#define CHECK_GT(val1, val2)
Definition: base/logging.h:702
#define DCHECK_GE(val1, val2)
Definition: base/logging.h:889
#define CHECK_NE(val1, val2)
Definition: base/logging.h:698
#define DCHECK_GT(val1, val2)
Definition: base/logging.h:890
#define DCHECK_LT(val1, val2)
Definition: base/logging.h:888
#define LOG(severity)
Definition: base/logging.h:420
#define DCHECK(condition)
Definition: base/logging.h:884
#define DCHECK_EQ(val1, val2)
Definition: base/logging.h:885
This is the base class for all expressions that are not variables.
A BaseObject is the root of all reversibly allocated objects.
IntVar * IsDifferent(int64 constant) override
Definition: expressions.cc:143
void RemoveInterval(int64 l, int64 u) override
This method removes the interval 'l' .. 'u' from the domain of the variable.
Definition: expressions.cc:103
void SetMax(int64 m) override
Definition: expressions.cc:74
IntVar * IsLessOrEqual(int64 constant) override
Definition: expressions.cc:164
void WhenBound(Demon *d) override
This method attaches a demon that will be awakened when the variable is bound.
Definition: expressions.cc:114
bool Contains(int64 v) const override
This method returns whether the value 'v' is in the domain of the variable.
Definition: expressions.cc:128
SimpleRevFIFO< Demon * > delayed_bound_demons_
void SetRange(int64 mi, int64 ma) override
This method sets both the min and the max of the expression.
Definition: expressions.cc:80
void RemoveValue(int64 v) override
This method removes the value 'v' from the domain of the variable.
Definition: expressions.cc:91
IntVar * IsEqual(int64 constant) override
IsEqual.
Definition: expressions.cc:132
IntVar * IsGreaterOrEqual(int64 constant) override
Definition: expressions.cc:154
SimpleRevFIFO< Demon * > bound_demons_
void SetMin(int64 m) override
Definition: expressions.cc:68
std::string DebugString() const override
Definition: expressions.cc:174
uint64 Size() const override
This method returns the number of values in the domain of the variable.
Definition: expressions.cc:124
A constraint is the main modeling object.
A Demon is the base element of a propagation queue.
virtual Solver::DemonPriority priority() const
This method returns the priority of the demon.
The class IntExpr is the base of all integer expressions in constraint programming.
virtual IntVar * Var()=0
Creates a variable from the expression.
virtual void SetValue(int64 v)
This method sets the value of the expression.
virtual bool Bound() const
Returns true if the min and the max of the expression are equal.
virtual bool IsVar() const
Returns true if the expression is indeed a variable.
virtual int64 Max() const =0
IntVar * VarWithName(const std::string &name)
Creates a variable from the expression and set the name of the resulting var.
Definition: expressions.cc:49
virtual int64 Min() const =0
The class IntVar is a subset of IntExpr.
IntVar * Var() override
Creates a variable from the expression.
IntVar(Solver *const s)
Definition: expressions.cc:57
virtual int VarType() const
The class Iterator has two direct subclasses.
virtual void VisitIntegerVariable(const IntVar *const variable, IntExpr *const delegate)
void SetRange(int64 l, int64 u) override
This method sets both the min and the max of the expression.
PiecewiseLinearExpr(Solver *solver, IntExpr *expr, const PiecewiseLinearFunction &f)
void WhenRange(Demon *d) override
Attach a demon that will watch the min or the max of the expression.
void Accept(ModelVisitor *const visitor) const override
Accepts the given visitor.
std::string name() const override
Object naming.
std::string DebugString() const override
virtual std::string name() const
Object naming.
void SetValue(Solver *const s, const T &val)
DemonPriority
This enum represents the three possible priorities for a demon in the Solver queue.
@ VAR_PRIORITY
VAR_PRIORITY is between DELAYED_PRIORITY and NORMAL_PRIORITY.
@ DELAYED_PRIORITY
DELAYED_PRIORITY is the lowest priority: Demons will be processed after VAR_PRIORITY and NORMAL_PRIOR...
@ OUTSIDE_SEARCH
Before search, after search.
IntExpr * MakeDifference(IntExpr *const left, IntExpr *const right)
left - right
T * RevAlloc(T *object)
Registers the given object as being reversible.
void AddCastConstraint(CastConstraint *const constraint, IntVar *const target_var, IntExpr *const expr)
Adds 'constraint' to the solver and marks it as a cast constraint, that is, a constraint created call...
void Fail()
Abandon the current branch in the search tree. A backtrack will follow.
IntVar * MakeIntConst(int64 val, const std::string &name)
IntConst will create a constant expression.
std::vector< IntVarIterator * > holes_
const std::string name
const Constraint * ct
int64 value
IntVar *const expr_
Definition: element.cc:85
IntVar * var
Definition: expr_array.cc:1858
const int64 limit_
#define COND_REV_ALLOC(rev, alloc)
Solver *const solver_
Definition: expressions.cc:274
const int64 pow_
const int64 cst_
ABSL_FLAG(bool, cp_disable_expression_optimization, false, "Disable special optimization when creating expressions.")
IntVarIterator *const iterator_
static const int64 kint64max
int64_t int64
static const int32 kint32max
uint64_t uint64
static const int64 kint64min
const int64 offset_
Definition: interval.cc:2076
Handler handler_
Definition: interval.cc:420
bool in_process_
Definition: interval.cc:419
const int FATAL
Definition: log_severity.h:32
#define DISALLOW_COPY_AND_ASSIGN(TypeName)
Definition: macros.h:29
int RemoveAt(RepeatedType *array, const IndexContainer &indices)
Definition: protobuf_util.h:40
const Collection::value_type::second_type FindPtrOrNull(const Collection &collection, const typename Collection::value_type::first_type &key)
Definition: map_util.h:70
void STLSortAndRemoveDuplicates(T *v, const LessFunc &less_func)
Definition: stl_util.h:58
std::function< int64(const Model &)> Value(IntegerVariable v)
Definition: integer.h:1487
The vehicle routing library lets one model and solve generic vehicle routing problems ranging from th...
int LeastSignificantBitPosition64(uint64 n)
Definition: bitset.h:127
void InternalSaveBooleanVarValue(Solver *const solver, IntVar *const var)
uint32 BitPos64(uint64 pos)
Definition: bitset.h:330
int64 CapSub(int64 x, int64 y)
void CleanVariableOnFail(IntVar *const var)
int MostSignificantBitPosition64(uint64 n)
Definition: bitset.h:231
Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)
uint64 BitCount64(uint64 n)
Definition: bitset.h:42
int64 SubOverflows(int64 x, int64 y)
int64 PosIntDivUp(int64 e, int64 v)
int64 UnsafeLeastSignificantBitPosition64(const uint64 *const bitset, uint64 start, uint64 end)
uint64 BitOffset64(uint64 pos)
Definition: bitset.h:334
bool IsBitSet64(const uint64 *const bitset, uint64 pos)
Definition: bitset.h:346
Constraint * SetIsGreaterOrEqual(IntVar *const var, const std::vector< int64 > &values, const std::vector< IntVar * > &vars)
static const uint64 kAllBits64
Definition: bitset.h:33
int64 PosIntDivDown(int64 e, int64 v)
void RegisterDemon(Solver *const solver, Demon *const demon, DemonProfiler *const monitor)
int64 CapAdd(int64 x, int64 y)
void RestoreBoolValue(IntVar *const var)
uint64 OneRange64(uint64 s, uint64 e)
Definition: bitset.h:285
Constraint * SetIsEqual(IntVar *const var, const std::vector< int64 > &values, const std::vector< IntVar * > &vars)
int64 CapProd(int64 x, int64 y)
uint64 BitCountRange64(const uint64 *const bitset, uint64 start, uint64 end)
std::vector< int64 > ToInt64Vector(const std::vector< int > &input)
Definition: utilities.cc:822
void LinkVarExpr(Solver *const s, IntExpr *const expr, IntVar *const var)
int64 UnsafeMostSignificantBitPosition64(const uint64 *const bitset, uint64 start, uint64 end)
uint64 OneBit64(int pos)
Definition: bitset.h:38
uint64 BitLength64(uint64 size)
Definition: bitset.h:338
bool AddOverflows(int64 x, int64 y)
int index
Definition: pack.cc:508
int64 coefficient
IntervalVar *const target_var_
int64 step_
Definition: search.cc:2952
const int64 stamp_
Definition: search.cc:3039
int64 current_
Definition: search.cc:2953