From 3530c90684339f7c01381369276dca5daf35c776 Mon Sep 17 00:00:00 2001
From: Richard van Velzen <richard@letsgetdigital.com>
Date: Tue, 13 May 2025 21:14:54 +0200
Subject: [PATCH] Make Iterator::current() and ::key() nullable

---
 src/Analyser/NodeScopeResolver.php       | 25 +++++++++++--
 stubs/iterable.stub                      | 13 +++++--
 tests/PHPStan/Analyser/nsrt/bug-3674.php | 45 ++++++++++++++++++++++++
 tests/PHPStan/Analyser/nsrt/bug-7519.php |  8 ++---
 4 files changed, 83 insertions(+), 8 deletions(-)
 create mode 100644 tests/PHPStan/Analyser/nsrt/bug-3674.php

diff --git a/src/Analyser/NodeScopeResolver.php b/src/Analyser/NodeScopeResolver.php
index 58aa455c4a..cfd1bce4d1 100644
--- a/src/Analyser/NodeScopeResolver.php
+++ b/src/Analyser/NodeScopeResolver.php
@@ -5,6 +5,7 @@
 use ArrayAccess;
 use Closure;
 use DivisionByZeroError;
+use Iterator;
 use PhpParser\Comment\Doc;
 use PhpParser\Modifiers;
 use PhpParser\Node;
@@ -1184,6 +1185,14 @@ private function processStmtNode(
 				$stmt->expr,
 				new Array_([]),
 			);
+			$exprType = $scope->getType($stmt->expr);
+			$iteratorValidExpr = null;
+			if ((new ObjectType(Iterator::class))->isSuperTypeOf($exprType)->yes()) {
+				$iteratorValidExpr = new BinaryOp\Identical(
+					new MethodCall($stmt->expr, 'valid'),
+					new ConstFetch(new Name\FullyQualified('true')),
+				);
+			}
 			if ($stmt->expr instanceof Variable && is_string($stmt->expr->name)) {
 				$scope = $this->processVarAnnotation($scope, [$stmt->expr->name], $stmt);
 			}
@@ -1193,11 +1202,17 @@ private function processStmtNode(
 
 			if ($context->isTopLevel()) {
 				$originalScope = $this->polluteScopeWithAlwaysIterableForeach ? $scope->filterByTruthyValue($arrayComparisonExpr) : $scope;
+				if ($iteratorValidExpr !== null) {
+					$originalScope = $originalScope->filterByTruthyValue($iteratorValidExpr);
+				}
 				$bodyScope = $this->enterForeach($originalScope, $originalScope, $stmt);
 				$count = 0;
 				do {
 					$prevScope = $bodyScope;
 					$bodyScope = $bodyScope->mergeWith($this->polluteScopeWithAlwaysIterableForeach ? $scope->filterByTruthyValue($arrayComparisonExpr) : $scope);
+					if ($iteratorValidExpr !== null) {
+						$bodyScope = $bodyScope->filterByTruthyValue($iteratorValidExpr);
+					}
 					$bodyScope = $this->enterForeach($bodyScope, $originalScope, $stmt);
 					$bodyScopeResult = $this->processStmtNodes($stmt, $stmt->stmts, $bodyScope, static function (): void {
 					}, $context->enterDeep())->filterOutLoopExitPoints();
@@ -1217,6 +1232,9 @@ private function processStmtNode(
 			}
 
 			$bodyScope = $bodyScope->mergeWith($this->polluteScopeWithAlwaysIterableForeach ? $scope->filterByTruthyValue($arrayComparisonExpr) : $scope);
+			if ($iteratorValidExpr !== null) {
+				$bodyScope = $bodyScope->filterByTruthyValue($iteratorValidExpr);
+			}
 			$bodyScope = $this->enterForeach($bodyScope, $originalScope, $stmt);
 			$finalScopeResult = $this->processStmtNodes($stmt, $stmt->stmts, $bodyScope, $nodeCallback, $context)->filterOutLoopExitPoints();
 			$finalScope = $finalScopeResult->getScope();
@@ -1227,7 +1245,6 @@ private function processStmtNode(
 				$finalScope = $breakExitPoint->getScope()->mergeWith($finalScope);
 			}
 
-			$exprType = $scope->getType($stmt->expr);
 			$isIterableAtLeastOnce = $exprType->isIterableAtLeastOnce();
 			if ($exprType->isIterable()->no() || $isIterableAtLeastOnce->maybe()) {
 				$finalScope = $finalScope->mergeWith($scope->filterByTruthyValue(new BooleanOr(
@@ -1250,10 +1267,14 @@ private function processStmtNode(
 				$throwPoints = array_merge($throwPoints, $finalScopeResult->getThrowPoints());
 				$impurePoints = array_merge($impurePoints, $finalScopeResult->getImpurePoints());
 			}
-			if (!(new ObjectType(Traversable::class))->isSuperTypeOf($scope->getType($stmt->expr))->no()) {
+			if (!(new ObjectType(Traversable::class))->isSuperTypeOf($exprType)->no()) {
 				$throwPoints[] = ThrowPoint::createImplicit($scope, $stmt->expr);
 			}
 
+			if ($iteratorValidExpr !== null) {
+				$finalScope = $finalScope->filterByFalseyValue($iteratorValidExpr);
+			}
+
 			return new StatementResult(
 				$finalScope,
 				$finalScopeResult->hasYield() || $condResult->hasYield(),
diff --git a/stubs/iterable.stub b/stubs/iterable.stub
index 6f78862316..5d438f5b34 100644
--- a/stubs/iterable.stub
+++ b/stubs/iterable.stub
@@ -34,12 +34,21 @@ interface Iterator extends Traversable
 {
 
 	/**
-	 * @return TValue
+	 * @phpstan-assert-if-true =TValue $this->current()
+	 * @phpstan-assert-if-false =null $this->current()
+	 *
+	 * @phpstan-assert-if-true =TKey $this->key()
+	 * @phpstan-assert-if-false =null $this->key()
+	 */
+	public function valid(): bool;
+
+	/**
+	 * @return TValue|null
 	 */
 	public function current();
 
 	/**
-	 * @return TKey
+	 * @return TKey|null
 	 */
 	public function key();
 
diff --git a/tests/PHPStan/Analyser/nsrt/bug-3674.php b/tests/PHPStan/Analyser/nsrt/bug-3674.php
new file mode 100644
index 0000000000..10b0edd831
--- /dev/null
+++ b/tests/PHPStan/Analyser/nsrt/bug-3674.php
@@ -0,0 +1,45 @@
+<?php declare(strict_types = 1);
+
+namespace Bug3674;
+
+use Iterator;
+use function PHPStan\Testing\assertType;
+
+/**
+ * @param Iterator<int> $it
+ */
+function foo(Iterator $it): void {
+	assertType('int|null', $it->current());
+
+	if ($it->valid()) {
+		assertType('int', $it->current());
+
+		$it->rewind();
+
+		assertType('int|null', $it->current());
+
+		if ($it->valid()) {
+			assertType('int', $it->current());
+		} else {
+			assertType('null', $it->current());
+		}
+	} else {
+		assertType('null', $it->current());
+	}
+}
+
+/**
+ * @param Iterator<int> $it
+ */
+function bar(Iterator $it): void {
+	assertType('bool', $it->valid());
+	assertType('int|null', $it->current());
+
+	foreach ($it as $v) {
+		assertType('true', $it->valid());
+		assertType('int', $it->current());
+	}
+
+	assertType('false', $it->valid());
+	assertType('null', $it->current());
+}
diff --git a/tests/PHPStan/Analyser/nsrt/bug-7519.php b/tests/PHPStan/Analyser/nsrt/bug-7519.php
index 1fd556f0e3..b8201139bb 100644
--- a/tests/PHPStan/Analyser/nsrt/bug-7519.php
+++ b/tests/PHPStan/Analyser/nsrt/bug-7519.php
@@ -41,8 +41,8 @@ function doFoo() {
 
 	$iterator = new FooFilterIterator($generator());
 
-	assertType('array{}|bool|stdClass', $iterator->key());
-	assertType('array{}|bool|stdClass', $iterator->current());
+	assertType('array{}|bool|stdClass|null', $iterator->key());
+	assertType('array{}|bool|stdClass|null', $iterator->current());
 
 	$generator = static function (): Generator {
 		yield true => true;
@@ -51,6 +51,6 @@ function doFoo() {
 
 	$iterator = new FooFilterIterator($generator());
 
-	assertType('bool', $iterator->key());
-	assertType('bool', $iterator->current());
+	assertType('bool|null', $iterator->key());
+	assertType('bool|null', $iterator->current());
 }