@@ -641,7 +641,6 @@ def __eq__(self, other: object) -> bool:
641
641
642
642
@_core .enable_ki_protection
643
643
async def _protected_async_gen_fn () -> AsyncGenerator [None , None ]:
644
- return
645
644
yield
646
645
647
646
@@ -652,13 +651,11 @@ async def _protected_async_fn() -> None:
652
651
653
652
@_core .enable_ki_protection
654
653
def _protected_gen_fn () -> Generator [None , None , None ]:
655
- return
656
654
yield
657
655
658
656
659
657
@_core .disable_ki_protection
660
658
async def _unprotected_async_gen_fn () -> AsyncGenerator [None , None ]:
661
- return
662
659
yield
663
660
664
661
@@ -669,20 +666,27 @@ async def _unprotected_async_fn() -> None:
669
666
670
667
@_core .disable_ki_protection
671
668
def _unprotected_gen_fn () -> Generator [None , None , None ]:
672
- return
673
669
yield
674
670
675
671
672
+ async def _consume_async_generator (agen : AsyncGenerator [None , None ]) -> None :
673
+ try :
674
+ with pytest .raises (StopAsyncIteration ):
675
+ while True :
676
+ await agen .asend (None )
677
+ finally :
678
+ await agen .aclose ()
679
+
680
+
676
681
def _consume_function_for_coverage (fn : Callable [..., object ]) -> None :
677
682
result = fn ()
678
683
if inspect .isasyncgen (result ):
679
- with pytest .raises (StopAsyncIteration ):
680
- result .asend (None ).send (None )
681
- return
684
+ result = _consume_async_generator (result )
682
685
683
686
assert inspect .isgenerator (result ) or inspect .iscoroutine (result )
684
687
with pytest .raises (StopIteration ):
685
- result .send (None )
688
+ while True :
689
+ result .send (None )
686
690
687
691
688
692
def test_enable_disable_ki_protection_passes_on_inspect_flags () -> None :
0 commit comments