Skip to content
This repository was archived by the owner on Jul 17, 2024. It is now read-only.

Commit e1dac95

Browse files
feat: add support for fairness constraints (#94)
Co-authored-by: Christopher Chianelli <christopher@timefold.ai>
1 parent 4396e2d commit e1dac95

File tree

6 files changed

+1410
-193
lines changed

6 files changed

+1410
-193
lines changed

tests/test_collectors.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -563,24 +563,39 @@ def define_constraints(constraint_factory: ConstraintFactory):
563563
assert score_manager.explain(problem).score == SimpleScore.of(4)
564564

565565

566-
def test_flatten_last():
566+
def test_load_balance():
567567
@constraint_provider
568568
def define_constraints(constraint_factory: ConstraintFactory):
569569
return [
570570
constraint_factory.for_each(Entity)
571-
.map(lambda entity: (1, 2, 3))
572-
.flatten_last(lambda the_tuple: the_tuple)
573-
.reward(SimpleScore.ONE)
574-
.as_constraint('Count')
571+
.group_by(ConstraintCollectors.load_balance(
572+
lambda entity: entity.value
573+
))
574+
.reward(SimpleScore.ONE,
575+
lambda balance: balance.unfairness().movePointRight(3).intValue())
576+
.as_constraint('Balanced value')
575577
]
576578

577579
score_manager = create_score_manager(define_constraints)
578580

579581
entity_a: Entity = Entity('A')
582+
entity_b: Entity = Entity('B')
583+
entity_c: Entity = Entity('C')
580584

581585
value_1 = Value(1)
586+
value_2 = Value(2)
582587

583-
problem = Solution([entity_a], [value_1])
588+
problem = Solution([entity_a, entity_b], [value_1])
584589
entity_a.value = value_1
590+
entity_b.value = value_1
591+
entity_c.value = value_1
585592

586-
assert score_manager.explain(problem).score == SimpleScore.of(3)
593+
assert score_manager.explain(problem).score == SimpleScore.of(0)
594+
595+
problem = Solution([entity_a, entity_b, entity_c], [value_1, value_2])
596+
597+
assert score_manager.explain(problem).score == SimpleScore.of(0)
598+
599+
entity_c.value = value_2
600+
601+
assert score_manager.explain(problem).score == SimpleScore.of(707)

tests/test_constraint_streams.py

+146-5
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,29 @@ def define_constraints(constraint_factory: ConstraintFactory):
223223
assert score_manager.explain(problem).score.score == 1
224224

225225

226+
def test_flatten_last():
227+
@constraint_provider
228+
def define_constraints(constraint_factory: ConstraintFactory):
229+
return [
230+
constraint_factory.for_each(Entity)
231+
.map(lambda entity: (1, 2, 3))
232+
.flatten_last(lambda the_tuple: the_tuple)
233+
.reward(SimpleScore.ONE)
234+
.as_constraint('Count')
235+
]
236+
237+
score_manager = create_score_manager(define_constraints)
238+
239+
entity_a: Entity = Entity('A')
240+
241+
value_1 = Value(1)
242+
243+
problem = Solution([entity_a], [value_1])
244+
entity_a.value = value_1
245+
246+
assert score_manager.explain(problem).score == SimpleScore.of(3)
247+
248+
226249
def test_join_uni():
227250
@constraint_provider
228251
def define_constraints(constraint_factory: ConstraintFactory):
@@ -265,6 +288,87 @@ def define_constraints(constraint_factory: ConstraintFactory):
265288
assert score_manager.explain(problem).score.score == 8
266289

267290

291+
def test_if_exists_uni():
292+
@constraint_provider
293+
def define_constraints(constraint_factory: ConstraintFactory):
294+
return [
295+
constraint_factory.for_each(Entity)
296+
.if_exists(Entity, Joiners.equal(lambda entity: entity.code))
297+
.reward(SimpleScore.ONE, lambda e1: e1.value.number)
298+
.as_constraint('Count')
299+
]
300+
301+
score_manager = create_score_manager(define_constraints)
302+
entity_a1: Entity = Entity('A')
303+
entity_a2: Entity = Entity('A')
304+
entity_b1: Entity = Entity('B')
305+
entity_b2: Entity = Entity('B')
306+
307+
value_1 = Value(1)
308+
value_2 = Value(2)
309+
310+
problem = Solution([entity_a1, entity_a2, entity_b1, entity_b2], [value_1, value_2])
311+
312+
entity_a1.value = value_1
313+
314+
# With itself
315+
assert score_manager.explain(problem).score.score == 1
316+
317+
entity_a1.value = value_1
318+
entity_a2.value = value_1
319+
320+
entity_b1.value = value_2
321+
entity_b2.value = value_2
322+
323+
# 1 + 2 + 1 + 2
324+
assert score_manager.explain(problem).score.score == 6
325+
326+
entity_a1.value = value_2
327+
entity_b1.value = value_1
328+
329+
# 1 + 2 + 1 + 2
330+
assert score_manager.explain(problem).score.score == 6
331+
332+
333+
def test_if_not_exists_uni():
334+
@constraint_provider
335+
def define_constraints(constraint_factory: ConstraintFactory):
336+
return [
337+
constraint_factory.for_each(Entity)
338+
.if_not_exists(Entity, Joiners.equal(lambda entity: entity.code))
339+
.reward(SimpleScore.ONE, lambda e1: e1.value.number)
340+
.as_constraint('Count')
341+
]
342+
343+
score_manager = create_score_manager(define_constraints)
344+
entity_a1: Entity = Entity('A')
345+
entity_a2: Entity = Entity('A')
346+
entity_b1: Entity = Entity('B')
347+
entity_b2: Entity = Entity('B')
348+
349+
value_1 = Value(1)
350+
value_2 = Value(2)
351+
352+
problem = Solution([entity_a1, entity_a2, entity_b1, entity_b2], [value_1, value_2])
353+
354+
entity_a1.value = value_1
355+
356+
assert score_manager.explain(problem).score.score == 0
357+
358+
entity_a1.value = value_1
359+
entity_a2.value = value_1
360+
361+
entity_b1.value = value_2
362+
entity_b2.value = value_2
363+
364+
assert score_manager.explain(problem).score.score == 0
365+
366+
entity_a1.value = value_2
367+
entity_b1.value = value_1
368+
369+
assert score_manager.explain(problem).score.score == 0
370+
371+
268372
def test_map():
269373
@constraint_provider
270374
def define_constraints(constraint_factory: ConstraintFactory):
@@ -436,6 +540,41 @@ def define_constraints(constraint_factory: ConstraintFactory):
436540

437541
assert score_manager.explain(problem).score.score == 1
438542

543+
def test_complement():
544+
@constraint_provider
545+
def define_constraints(constraint_factory: ConstraintFactory):
546+
return [
547+
constraint_factory.for_each(Entity)
548+
.filter(lambda e: e.value.number == 1)
549+
.complement(Entity)
550+
.reward(SimpleScore.ONE)
551+
.as_constraint('Count')
552+
]
553+
554+
score_manager = create_score_manager(define_constraints)
555+
entity_a: Entity = Entity('A')
556+
entity_b: Entity = Entity('B')
557+
558+
value_1 = Value(1)
559+
value_2 = Value(2)
560+
value_3 = Value(3)
561+
562+
problem = Solution([entity_a, entity_b], [value_1, value_2, value_3])
563+
564+
assert score_manager.explain(problem).score.score == 0
565+
566+
entity_a.value = value_1
567+
568+
assert score_manager.explain(problem).score.score == 1
569+
570+
entity_b.value = value_2
571+
572+
assert score_manager.explain(problem).score.score == 2
573+
574+
entity_b.value = value_3
575+
576+
assert score_manager.explain(problem).score.score == 2
577+
439578

440579
def test_custom_indictments():
441580
@dataclass(unsafe_hash=True)
@@ -630,6 +769,7 @@ def define_constraints(constraint_factory: ConstraintFactory):
630769

631770

632771
def test_has_all_methods():
772+
missing = []
633773
for python_type, java_type in ((UniConstraintStream, JavaUniConstraintStream),
634774
(BiConstraintStream, JavaBiConstraintStream),
635775
(TriConstraintStream, JavaTriConstraintStream),
@@ -641,7 +781,6 @@ def test_has_all_methods():
641781
(Joiners, JavaJoiners),
642782
(ConstraintCollectors, JavaConstraintCollectors),
643783
(ConstraintFactory, JavaConstraintFactory)):
644-
missing = []
645784
for function_name, function_impl in inspect.getmembers(java_type, inspect.isfunction):
646785
if function_name in ignored_java_functions:
647786
continue
@@ -654,8 +793,10 @@ def test_has_all_methods():
654793
# change h_t_t_p -> http
655794
snake_case_name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake_case_name).lower()
656795
if not hasattr(python_type, snake_case_name):
657-
missing.append(snake_case_name)
796+
missing.append((java_type, python_type, snake_case_name))
658797

659-
if missing:
660-
raise AssertionError(f'{python_type} is missing methods ({missing}) '
661-
f'from java_type ({java_type}).)')
798+
if missing:
799+
assertion_msg = ''
800+
for java_type, python_type, snake_case_name in missing:
801+
assertion_msg += f'{python_type} is missing a method ({snake_case_name}) from java_type ({java_type}).)\n'
802+
raise AssertionError(assertion_msg)

0 commit comments

Comments
 (0)