Skip to content

Commit c11d0f4

Browse files
authored
PYTHON-5306: [v4.12] - Fix use of public MongoClient attributes before connection (#2285) (#2311)
1 parent f5836b3 commit c11d0f4

File tree

4 files changed

+70
-22
lines changed

4 files changed

+70
-22
lines changed

pymongo/asynchronous/mongo_client.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@
109109
)
110110
from pymongo.read_preferences import ReadPreference, _ServerMode
111111
from pymongo.results import ClientBulkWriteResult
112+
from pymongo.server_description import ServerDescription
112113
from pymongo.server_selectors import writable_server_selector
113114
from pymongo.server_type import SERVER_TYPE
114115
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
@@ -779,7 +780,7 @@ def __init__(
779780
keyword_opts["document_class"] = doc_class
780781
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
781782

782-
seeds = set()
783+
self._seeds = set()
783784
is_srv = False
784785
username = None
785786
password = None
@@ -804,18 +805,18 @@ def __init__(
804805
srv_max_hosts=srv_max_hosts,
805806
)
806807
is_srv = entity.startswith(SRV_SCHEME)
807-
seeds.update(res["nodelist"])
808+
self._seeds.update(res["nodelist"])
808809
username = res["username"] or username
809810
password = res["password"] or password
810811
dbase = res["database"] or dbase
811812
opts = res["options"]
812813
fqdn = res["fqdn"]
813814
else:
814-
seeds.update(split_hosts(entity, self._port))
815-
if not seeds:
815+
self._seeds.update(split_hosts(entity, self._port))
816+
if not self._seeds:
816817
raise ConfigurationError("need to specify at least one host")
817818

818-
for hostname in [node[0] for node in seeds]:
819+
for hostname in [node[0] for node in self._seeds]:
819820
if _detect_external_db(hostname):
820821
break
821822

@@ -838,7 +839,7 @@ def __init__(
838839
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
839840

840841
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
841-
opts = self._normalize_and_validate_options(opts, seeds)
842+
opts = self._normalize_and_validate_options(opts, self._seeds)
842843

843844
# Username and password passed as kwargs override user info in URI.
844845
username = opts.get("username", username)
@@ -857,7 +858,7 @@ def __init__(
857858
"username": username,
858859
"password": password,
859860
"dbase": dbase,
860-
"seeds": seeds,
861+
"seeds": self._seeds,
861862
"fqdn": fqdn,
862863
"srv_service_name": srv_service_name,
863864
"pool_class": pool_class,
@@ -873,8 +874,7 @@ def __init__(
873874
self._options.read_concern,
874875
)
875876

876-
if not is_srv:
877-
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
877+
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
878878

879879
self._opened = False
880880
self._closed = False
@@ -975,6 +975,7 @@ def _init_based_on_options(
975975
srv_service_name=srv_service_name,
976976
srv_max_hosts=srv_max_hosts,
977977
server_monitoring_mode=self._options.server_monitoring_mode,
978+
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
978979
)
979980
if self._options.auto_encryption_opts:
980981
from pymongo.asynchronous.encryption import _Encrypter
@@ -1205,6 +1206,16 @@ def topology_description(self) -> TopologyDescription:
12051206
12061207
.. versionadded:: 4.0
12071208
"""
1209+
if self._topology is None:
1210+
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
1211+
return TopologyDescription(
1212+
TOPOLOGY_TYPE.Unknown,
1213+
servers,
1214+
None,
1215+
None,
1216+
None,
1217+
self._topology_settings,
1218+
)
12081219
return self._topology.description
12091220

12101221
@property
@@ -1218,6 +1229,8 @@ def nodes(self) -> FrozenSet[_Address]:
12181229
to any servers, or a network partition causes it to lose connection
12191230
to all servers.
12201231
"""
1232+
if self._topology is None:
1233+
return frozenset()
12211234
description = self._topology.description
12221235
return frozenset(s.address for s in description.known_servers)
12231236

@@ -1576,6 +1589,8 @@ async def address(self) -> Optional[tuple[str, int]]:
15761589
15771590
.. versionadded:: 3.0
15781591
"""
1592+
if self._topology is None:
1593+
await self._get_topology()
15791594
topology_type = self._topology._description.topology_type
15801595
if (
15811596
topology_type == TOPOLOGY_TYPE.Sharded
@@ -1598,6 +1613,8 @@ async def primary(self) -> Optional[tuple[str, int]]:
15981613
.. versionadded:: 3.0
15991614
AsyncMongoClient gained this property in version 3.0.
16001615
"""
1616+
if self._topology is None:
1617+
await self._get_topology()
16011618
return await self._topology.get_primary() # type: ignore[return-value]
16021619

16031620
@property
@@ -1611,6 +1628,8 @@ async def secondaries(self) -> set[_Address]:
16111628
.. versionadded:: 3.0
16121629
AsyncMongoClient gained this property in version 3.0.
16131630
"""
1631+
if self._topology is None:
1632+
await self._get_topology()
16141633
return await self._topology.get_secondaries()
16151634

16161635
@property
@@ -1621,6 +1640,8 @@ async def arbiters(self) -> set[_Address]:
16211640
connected to a replica set, there are no arbiters, or this client was
16221641
created without the `replicaSet` option.
16231642
"""
1643+
if self._topology is None:
1644+
await self._get_topology()
16241645
return await self._topology.get_arbiters()
16251646

16261647
@property

pymongo/asynchronous/settings.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
srv_service_name: str = common.SRV_SERVICE_NAME,
5252
srv_max_hosts: int = 0,
5353
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
54+
topology_id: Optional[ObjectId] = None,
5455
):
5556
"""Represent MongoClient's configuration.
5657
@@ -78,8 +79,10 @@ def __init__(
7879
self._srv_service_name = srv_service_name
7980
self._srv_max_hosts = srv_max_hosts or 0
8081
self._server_monitoring_mode = server_monitoring_mode
81-
82-
self._topology_id = ObjectId()
82+
if topology_id is not None:
83+
self._topology_id = topology_id
84+
else:
85+
self._topology_id = ObjectId()
8386
# Store the allocation traceback to catch unclosed clients in the
8487
# test suite.
8588
self._stack = "".join(traceback.format_stack()[:-2])

pymongo/synchronous/mongo_client.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
)
102102
from pymongo.read_preferences import ReadPreference, _ServerMode
103103
from pymongo.results import ClientBulkWriteResult
104+
from pymongo.server_description import ServerDescription
104105
from pymongo.server_selectors import writable_server_selector
105106
from pymongo.server_type import SERVER_TYPE
106107
from pymongo.synchronous import client_session, database, uri_parser
@@ -777,7 +778,7 @@ def __init__(
777778
keyword_opts["document_class"] = doc_class
778779
self._resolve_srv_info: dict[str, Any] = {"keyword_opts": keyword_opts}
779780

780-
seeds = set()
781+
self._seeds = set()
781782
is_srv = False
782783
username = None
783784
password = None
@@ -802,18 +803,18 @@ def __init__(
802803
srv_max_hosts=srv_max_hosts,
803804
)
804805
is_srv = entity.startswith(SRV_SCHEME)
805-
seeds.update(res["nodelist"])
806+
self._seeds.update(res["nodelist"])
806807
username = res["username"] or username
807808
password = res["password"] or password
808809
dbase = res["database"] or dbase
809810
opts = res["options"]
810811
fqdn = res["fqdn"]
811812
else:
812-
seeds.update(split_hosts(entity, self._port))
813-
if not seeds:
813+
self._seeds.update(split_hosts(entity, self._port))
814+
if not self._seeds:
814815
raise ConfigurationError("need to specify at least one host")
815816

816-
for hostname in [node[0] for node in seeds]:
817+
for hostname in [node[0] for node in self._seeds]:
817818
if _detect_external_db(hostname):
818819
break
819820

@@ -836,7 +837,7 @@ def __init__(
836837
srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME)
837838

838839
srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts")
839-
opts = self._normalize_and_validate_options(opts, seeds)
840+
opts = self._normalize_and_validate_options(opts, self._seeds)
840841

841842
# Username and password passed as kwargs override user info in URI.
842843
username = opts.get("username", username)
@@ -855,7 +856,7 @@ def __init__(
855856
"username": username,
856857
"password": password,
857858
"dbase": dbase,
858-
"seeds": seeds,
859+
"seeds": self._seeds,
859860
"fqdn": fqdn,
860861
"srv_service_name": srv_service_name,
861862
"pool_class": pool_class,
@@ -871,8 +872,7 @@ def __init__(
871872
self._options.read_concern,
872873
)
873874

874-
if not is_srv:
875-
self._init_based_on_options(seeds, srv_max_hosts, srv_service_name)
875+
self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name)
876876

877877
self._opened = False
878878
self._closed = False
@@ -973,6 +973,7 @@ def _init_based_on_options(
973973
srv_service_name=srv_service_name,
974974
srv_max_hosts=srv_max_hosts,
975975
server_monitoring_mode=self._options.server_monitoring_mode,
976+
topology_id=self._topology_settings._topology_id if self._topology_settings else None,
976977
)
977978
if self._options.auto_encryption_opts:
978979
from pymongo.synchronous.encryption import _Encrypter
@@ -1203,6 +1204,16 @@ def topology_description(self) -> TopologyDescription:
12031204
12041205
.. versionadded:: 4.0
12051206
"""
1207+
if self._topology is None:
1208+
servers = {(host, port): ServerDescription((host, port)) for host, port in self._seeds}
1209+
return TopologyDescription(
1210+
TOPOLOGY_TYPE.Unknown,
1211+
servers,
1212+
None,
1213+
None,
1214+
None,
1215+
self._topology_settings,
1216+
)
12061217
return self._topology.description
12071218

12081219
@property
@@ -1216,6 +1227,8 @@ def nodes(self) -> FrozenSet[_Address]:
12161227
to any servers, or a network partition causes it to lose connection
12171228
to all servers.
12181229
"""
1230+
if self._topology is None:
1231+
return frozenset()
12191232
description = self._topology.description
12201233
return frozenset(s.address for s in description.known_servers)
12211234

@@ -1570,6 +1583,8 @@ def address(self) -> Optional[tuple[str, int]]:
15701583
15711584
.. versionadded:: 3.0
15721585
"""
1586+
if self._topology is None:
1587+
self._get_topology()
15731588
topology_type = self._topology._description.topology_type
15741589
if (
15751590
topology_type == TOPOLOGY_TYPE.Sharded
@@ -1592,6 +1607,8 @@ def primary(self) -> Optional[tuple[str, int]]:
15921607
.. versionadded:: 3.0
15931608
MongoClient gained this property in version 3.0.
15941609
"""
1610+
if self._topology is None:
1611+
self._get_topology()
15951612
return self._topology.get_primary() # type: ignore[return-value]
15961613

15971614
@property
@@ -1605,6 +1622,8 @@ def secondaries(self) -> set[_Address]:
16051622
.. versionadded:: 3.0
16061623
MongoClient gained this property in version 3.0.
16071624
"""
1625+
if self._topology is None:
1626+
self._get_topology()
16081627
return self._topology.get_secondaries()
16091628

16101629
@property
@@ -1615,6 +1634,8 @@ def arbiters(self) -> set[_Address]:
16151634
connected to a replica set, there are no arbiters, or this client was
16161635
created without the `replicaSet` option.
16171636
"""
1637+
if self._topology is None:
1638+
self._get_topology()
16181639
return self._topology.get_arbiters()
16191640

16201641
@property

pymongo/synchronous/settings.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(
5151
srv_service_name: str = common.SRV_SERVICE_NAME,
5252
srv_max_hosts: int = 0,
5353
server_monitoring_mode: str = common.SERVER_MONITORING_MODE,
54+
topology_id: Optional[ObjectId] = None,
5455
):
5556
"""Represent MongoClient's configuration.
5657
@@ -78,8 +79,10 @@ def __init__(
7879
self._srv_service_name = srv_service_name
7980
self._srv_max_hosts = srv_max_hosts or 0
8081
self._server_monitoring_mode = server_monitoring_mode
81-
82-
self._topology_id = ObjectId()
82+
if topology_id is not None:
83+
self._topology_id = topology_id
84+
else:
85+
self._topology_id = ObjectId()
8386
# Store the allocation traceback to catch unclosed clients in the
8487
# test suite.
8588
self._stack = "".join(traceback.format_stack()[:-2])

0 commit comments

Comments
 (0)