diff --git a/src/Analyser/NodeScopeResolver.php b/src/Analyser/NodeScopeResolver.php index 58aa455c4a..cfd1bce4d1 100644 --- a/src/Analyser/NodeScopeResolver.php +++ b/src/Analyser/NodeScopeResolver.php @@ -5,6 +5,7 @@ use ArrayAccess; use Closure; use DivisionByZeroError; +use Iterator; use PhpParser\Comment\Doc; use PhpParser\Modifiers; use PhpParser\Node; @@ -1184,6 +1185,14 @@ private function processStmtNode( $stmt->expr, new Array_([]), ); + $exprType = $scope->getType($stmt->expr); + $iteratorValidExpr = null; + if ((new ObjectType(Iterator::class))->isSuperTypeOf($exprType)->yes()) { + $iteratorValidExpr = new BinaryOp\Identical( + new MethodCall($stmt->expr, 'valid'), + new ConstFetch(new Name\FullyQualified('true')), + ); + } if ($stmt->expr instanceof Variable && is_string($stmt->expr->name)) { $scope = $this->processVarAnnotation($scope, [$stmt->expr->name], $stmt); } @@ -1193,11 +1202,17 @@ private function processStmtNode( if ($context->isTopLevel()) { $originalScope = $this->polluteScopeWithAlwaysIterableForeach ? $scope->filterByTruthyValue($arrayComparisonExpr) : $scope; + if ($iteratorValidExpr !== null) { + $originalScope = $originalScope->filterByTruthyValue($iteratorValidExpr); + } $bodyScope = $this->enterForeach($originalScope, $originalScope, $stmt); $count = 0; do { $prevScope = $bodyScope; $bodyScope = $bodyScope->mergeWith($this->polluteScopeWithAlwaysIterableForeach ? $scope->filterByTruthyValue($arrayComparisonExpr) : $scope); + if ($iteratorValidExpr !== null) { + $bodyScope = $bodyScope->filterByTruthyValue($iteratorValidExpr); + } $bodyScope = $this->enterForeach($bodyScope, $originalScope, $stmt); $bodyScopeResult = $this->processStmtNodes($stmt, $stmt->stmts, $bodyScope, static function (): void { }, $context->enterDeep())->filterOutLoopExitPoints(); @@ -1217,6 +1232,9 @@ private function processStmtNode( } $bodyScope = $bodyScope->mergeWith($this->polluteScopeWithAlwaysIterableForeach ? $scope->filterByTruthyValue($arrayComparisonExpr) : $scope); + if ($iteratorValidExpr !== null) { + $bodyScope = $bodyScope->filterByTruthyValue($iteratorValidExpr); + } $bodyScope = $this->enterForeach($bodyScope, $originalScope, $stmt); $finalScopeResult = $this->processStmtNodes($stmt, $stmt->stmts, $bodyScope, $nodeCallback, $context)->filterOutLoopExitPoints(); $finalScope = $finalScopeResult->getScope(); @@ -1227,7 +1245,6 @@ private function processStmtNode( $finalScope = $breakExitPoint->getScope()->mergeWith($finalScope); } - $exprType = $scope->getType($stmt->expr); $isIterableAtLeastOnce = $exprType->isIterableAtLeastOnce(); if ($exprType->isIterable()->no() || $isIterableAtLeastOnce->maybe()) { $finalScope = $finalScope->mergeWith($scope->filterByTruthyValue(new BooleanOr( @@ -1250,10 +1267,14 @@ private function processStmtNode( $throwPoints = array_merge($throwPoints, $finalScopeResult->getThrowPoints()); $impurePoints = array_merge($impurePoints, $finalScopeResult->getImpurePoints()); } - if (!(new ObjectType(Traversable::class))->isSuperTypeOf($scope->getType($stmt->expr))->no()) { + if (!(new ObjectType(Traversable::class))->isSuperTypeOf($exprType)->no()) { $throwPoints[] = ThrowPoint::createImplicit($scope, $stmt->expr); } + if ($iteratorValidExpr !== null) { + $finalScope = $finalScope->filterByFalseyValue($iteratorValidExpr); + } + return new StatementResult( $finalScope, $finalScopeResult->hasYield() || $condResult->hasYield(), diff --git a/stubs/iterable.stub b/stubs/iterable.stub index 6f78862316..5d438f5b34 100644 --- a/stubs/iterable.stub +++ b/stubs/iterable.stub @@ -34,12 +34,21 @@ interface Iterator extends Traversable { /** - * @return TValue + * @phpstan-assert-if-true =TValue $this->current() + * @phpstan-assert-if-false =null $this->current() + * + * @phpstan-assert-if-true =TKey $this->key() + * @phpstan-assert-if-false =null $this->key() + */ + public function valid(): bool; + + /** + * @return TValue|null */ public function current(); /** - * @return TKey + * @return TKey|null */ public function key(); diff --git a/tests/PHPStan/Analyser/nsrt/bug-3674.php b/tests/PHPStan/Analyser/nsrt/bug-3674.php new file mode 100644 index 0000000000..10b0edd831 --- /dev/null +++ b/tests/PHPStan/Analyser/nsrt/bug-3674.php @@ -0,0 +1,45 @@ + $it + */ +function foo(Iterator $it): void { + assertType('int|null', $it->current()); + + if ($it->valid()) { + assertType('int', $it->current()); + + $it->rewind(); + + assertType('int|null', $it->current()); + + if ($it->valid()) { + assertType('int', $it->current()); + } else { + assertType('null', $it->current()); + } + } else { + assertType('null', $it->current()); + } +} + +/** + * @param Iterator $it + */ +function bar(Iterator $it): void { + assertType('bool', $it->valid()); + assertType('int|null', $it->current()); + + foreach ($it as $v) { + assertType('true', $it->valid()); + assertType('int', $it->current()); + } + + assertType('false', $it->valid()); + assertType('null', $it->current()); +} diff --git a/tests/PHPStan/Analyser/nsrt/bug-7519.php b/tests/PHPStan/Analyser/nsrt/bug-7519.php index 1fd556f0e3..b8201139bb 100644 --- a/tests/PHPStan/Analyser/nsrt/bug-7519.php +++ b/tests/PHPStan/Analyser/nsrt/bug-7519.php @@ -41,8 +41,8 @@ function doFoo() { $iterator = new FooFilterIterator($generator()); - assertType('array{}|bool|stdClass', $iterator->key()); - assertType('array{}|bool|stdClass', $iterator->current()); + assertType('array{}|bool|stdClass|null', $iterator->key()); + assertType('array{}|bool|stdClass|null', $iterator->current()); $generator = static function (): Generator { yield true => true; @@ -51,6 +51,6 @@ function doFoo() { $iterator = new FooFilterIterator($generator()); - assertType('bool', $iterator->key()); - assertType('bool', $iterator->current()); + assertType('bool|null', $iterator->key()); + assertType('bool|null', $iterator->current()); }