Skip to content

Commit 77b8574

Browse files
authored
Add support for using __bool__ method literal value in union narrowing in if statements (#9297)
This adds support for union narrowing in if statements when condition value has defined literal annotations in `__bool__` method. Value is narrowed based on the `__bool__` method return annotation and this works even if multiple instances defines the same literal value for `__bool__` method return type. This PR also works well with #9288 and makes below example to work as expected: ```python class A: def __bool__(self) -> Literal[True]: ... class B: def __bool__(self) -> Literal[False]: ... def get_thing() -> Union[A, B]: ... if x := get_thing(): reveal_type(x) # Revealed type is '__main__.A' else: reveal_type(x) # Revealed type is '__main__.B' ``` Partially fixes #9220
1 parent ea913ac commit 77b8574

File tree

2 files changed

+94
-2
lines changed

2 files changed

+94
-2
lines changed

mypy/typeops.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,19 @@ def make_simplified_union(items: Sequence[Type],
378378
return UnionType.make_union(simplified_set, line, column)
379379

380380

381+
def get_type_special_method_bool_ret_type(t: Type) -> Optional[Type]:
382+
t = get_proper_type(t)
383+
384+
if isinstance(t, Instance):
385+
bool_method = t.type.names.get("__bool__", None)
386+
if bool_method:
387+
callee = get_proper_type(bool_method.type)
388+
if isinstance(callee, CallableType):
389+
return callee.ret_type
390+
391+
return None
392+
393+
381394
def true_only(t: Type) -> ProperType:
382395
"""
383396
Restricted version of t with only True-ish values
@@ -393,8 +406,16 @@ def true_only(t: Type) -> ProperType:
393406
elif isinstance(t, UnionType):
394407
# The true version of a union type is the union of the true versions of its components
395408
new_items = [true_only(item) for item in t.items]
396-
return make_simplified_union(new_items, line=t.line, column=t.column)
409+
can_be_true_items = [item for item in new_items if item.can_be_true]
410+
return make_simplified_union(can_be_true_items, line=t.line, column=t.column)
397411
else:
412+
ret_type = get_type_special_method_bool_ret_type(t)
413+
414+
if ret_type and ret_type.can_be_false and not ret_type.can_be_true:
415+
new_t = copy_type(t)
416+
new_t.can_be_true = False
417+
return new_t
418+
398419
new_t = copy_type(t)
399420
new_t.can_be_false = False
400421
return new_t
@@ -420,8 +441,16 @@ def false_only(t: Type) -> ProperType:
420441
elif isinstance(t, UnionType):
421442
# The false version of a union type is the union of the false versions of its components
422443
new_items = [false_only(item) for item in t.items]
423-
return make_simplified_union(new_items, line=t.line, column=t.column)
444+
can_be_false_items = [item for item in new_items if item.can_be_false]
445+
return make_simplified_union(can_be_false_items, line=t.line, column=t.column)
424446
else:
447+
ret_type = get_type_special_method_bool_ret_type(t)
448+
449+
if ret_type and ret_type.can_be_true and not ret_type.can_be_false:
450+
new_t = copy_type(t)
451+
new_t.can_be_false = False
452+
return new_t
453+
425454
new_t = copy_type(t)
426455
new_t.can_be_true = False
427456
return new_t

test-data/unit/check-literal.test

+63
Original file line numberDiff line numberDiff line change
@@ -3243,3 +3243,66 @@ assert c.a is True
32433243
c.update()
32443244
assert c.a is False
32453245
[builtins fixtures/bool.pyi]
3246+
3247+
[case testConditionalBoolLiteralUnionNarrowing]
3248+
# flags: --warn-unreachable
3249+
3250+
from typing import Union
3251+
from typing_extensions import Literal
3252+
3253+
class Truth:
3254+
def __bool__(self) -> Literal[True]: ...
3255+
3256+
class AlsoTruth:
3257+
def __bool__(self) -> Literal[True]: ...
3258+
3259+
class Lie:
3260+
def __bool__(self) -> Literal[False]: ...
3261+
3262+
class AnyAnswer:
3263+
def __bool__(self) -> bool: ...
3264+
3265+
class NoAnswerSpecified:
3266+
pass
3267+
3268+
x: Union[Truth, Lie]
3269+
3270+
if x:
3271+
reveal_type(x) # N: Revealed type is '__main__.Truth'
3272+
else:
3273+
reveal_type(x) # N: Revealed type is '__main__.Lie'
3274+
3275+
if not x:
3276+
reveal_type(x) # N: Revealed type is '__main__.Lie'
3277+
else:
3278+
reveal_type(x) # N: Revealed type is '__main__.Truth'
3279+
3280+
y: Union[Truth, AlsoTruth, Lie]
3281+
3282+
if y:
3283+
reveal_type(y) # N: Revealed type is 'Union[__main__.Truth, __main__.AlsoTruth]'
3284+
else:
3285+
reveal_type(y) # N: Revealed type is '__main__.Lie'
3286+
3287+
z: Union[Truth, AnyAnswer]
3288+
3289+
if z:
3290+
reveal_type(z) # N: Revealed type is 'Union[__main__.Truth, __main__.AnyAnswer]'
3291+
else:
3292+
reveal_type(z) # N: Revealed type is '__main__.AnyAnswer'
3293+
3294+
q: Union[Truth, NoAnswerSpecified]
3295+
3296+
if q:
3297+
reveal_type(q) # N: Revealed type is 'Union[__main__.Truth, __main__.NoAnswerSpecified]'
3298+
else:
3299+
reveal_type(q) # N: Revealed type is '__main__.NoAnswerSpecified'
3300+
3301+
w: Union[Truth, AlsoTruth]
3302+
3303+
if w:
3304+
reveal_type(w) # N: Revealed type is 'Union[__main__.Truth, __main__.AlsoTruth]'
3305+
else:
3306+
reveal_type(w) # E: Statement is unreachable
3307+
3308+
[builtins fixtures/bool.pyi]

0 commit comments

Comments
 (0)