Skip to content

Commit 19bcce4

Browse files
committed
bring back retryable var cuz default of false isn't great and move retryable var to bulk class to match client_bulk
1 parent 12e2f8e commit 19bcce4

11 files changed

+145
-60
lines changed

pymongo/asynchronous/bulk.py

+18-10
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def gen_ordered(
258258
yield run
259259
run = _Run(op_type)
260260
run.add(idx, operation)
261-
self.is_retryable = self.is_retryable and retryable
261+
run.is_retryable = run.is_retryable and retryable
262262
if run is None:
263263
raise InvalidOperation("No operations to execute")
264264
yield run
@@ -276,7 +276,7 @@ def gen_unordered(
276276
retryable = process(request)
277277
(op_type, operation) = self.ops[idx]
278278
operations[op_type].add(idx, operation)
279-
self.is_retryable = self.is_retryable and retryable
279+
operations[op_type].is_retryable = operations[op_type].is_retryable and retryable
280280
if (
281281
len(operations[_INSERT].ops) == 0
282282
and len(operations[_UPDATE].ops) == 0
@@ -517,6 +517,7 @@ async def _execute_command(
517517
session: Optional[AsyncClientSession],
518518
conn: AsyncConnection,
519519
op_id: int,
520+
retryable: bool,
520521
full_result: MutableMapping[str, Any],
521522
validate: bool,
522523
final_write_concern: Optional[WriteConcern] = None,
@@ -536,6 +537,9 @@ async def _execute_command(
536537
last_run = False
537538

538539
while run:
540+
self.is_retryable = run.is_retryable
541+
self.retrying = run.retrying
542+
self.started_retryable_write = run.started_retryable_write
539543
if not self.retrying:
540544
self.next_run = next(generator, None)
541545
if self.next_run is None:
@@ -570,10 +574,13 @@ async def _execute_command(
570574
if session:
571575
# Start a new retryable write unless one was already
572576
# started for this command.
573-
if self.is_retryable and not self.started_retryable_write:
577+
if retryable and self.is_retryable and not self.started_retryable_write:
578+
# print("starting retrayable write")
574579
session._start_retryable_write()
575580
self.started_retryable_write = True
576-
session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn)
581+
session._apply_to(
582+
cmd, retryable and self.is_retryable, ReadPreference.PRIMARY, conn
583+
)
577584
conn.send_cluster_time(cmd, session, client)
578585
conn.add_server_api(cmd)
579586
# CSOT: apply timeout before encoding the command.
@@ -593,12 +600,10 @@ async def _execute_command(
593600
full = copy.deepcopy(full_result)
594601
_merge_command(run, full, run.idx_offset, result)
595602
_raise_bulk_write_error(full)
596-
597603
_merge_command(run, full_result, run.idx_offset, result)
598-
599604
# We're no longer in a retry once a command succeeds.
600-
run.retrying = False
601-
run.started_retryable_write = False
605+
self.retrying = False
606+
self.started_retryable_write = False
602607

603608
if self.ordered and "writeErrors" in result:
604609
break
@@ -636,27 +641,29 @@ async def execute_command(
636641
op_id = _randint()
637642

638643
async def retryable_bulk(
639-
session: Optional[AsyncClientSession],
640-
conn: AsyncConnection,
644+
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool
641645
) -> None:
642646
await self._execute_command(
643647
generator,
644648
write_concern,
645649
session,
646650
conn,
647651
op_id,
652+
retryable,
648653
full_result,
649654
validate=False,
650655
)
651656

652657
client = self.collection.database.client
653658
_ = await client._retryable_write(
659+
self.is_retryable,
654660
retryable_bulk,
655661
session,
656662
operation,
657663
bulk=self, # type: ignore[arg-type]
658664
operation_id=op_id,
659665
)
666+
660667
if full_result["writeErrors"] or full_result["writeConcernErrors"]:
661668
_raise_bulk_write_error(full_result)
662669
return full_result
@@ -730,6 +737,7 @@ async def execute_command_no_results(
730737
None,
731738
conn,
732739
op_id,
740+
False,
733741
full_result,
734742
True,
735743
write_concern,

pymongo/asynchronous/client_bulk.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,7 @@ async def _execute_command(
489489
session: Optional[AsyncClientSession],
490490
conn: AsyncConnection,
491491
op_id: int,
492+
retryable: bool,
492493
full_result: MutableMapping[str, Any],
493494
final_write_concern: Optional[WriteConcern] = None,
494495
) -> None:
@@ -534,10 +535,12 @@ async def _execute_command(
534535
if session:
535536
# Start a new retryable write unless one was already
536537
# started for this command.
537-
if self.is_retryable and not self.started_retryable_write:
538+
if retryable and self.is_retryable and not self.started_retryable_write:
538539
session._start_retryable_write()
539540
self.started_retryable_write = True
540-
session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn)
541+
session._apply_to(
542+
cmd, retryable and self.is_retryable, ReadPreference.PRIMARY, conn
543+
)
541544
conn.send_cluster_time(cmd, session, self.client)
542545
conn.add_server_api(cmd)
543546
# CSOT: apply timeout before encoding the command.
@@ -564,7 +567,11 @@ async def _execute_command(
564567

565568
# Synthesize the full bulk result without modifying the
566569
# current one because this write operation may be retried.
567-
if self.is_retryable and (retryable_top_level_error or retryable_network_error):
570+
if (
571+
retryable
572+
and self.is_retryable
573+
and (retryable_top_level_error or retryable_network_error)
574+
):
568575
full = copy.deepcopy(full_result)
569576
_merge_command(self.ops, self.idx_offset, full, result)
570577
_throw_client_bulk_write_exception(full, self.verbose_results)
@@ -583,7 +590,7 @@ async def _execute_command(
583590
_merge_command(self.ops, self.idx_offset, full_result, result)
584591
break
585592

586-
if self.is_retryable:
593+
if retryable and self.is_retryable:
587594
# Retryable writeConcernErrors halt the execution of this batch.
588595
wce = result.get("writeConcernError", {})
589596
if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
@@ -638,6 +645,7 @@ async def execute_command(
638645
async def retryable_bulk(
639646
session: Optional[AsyncClientSession],
640647
conn: AsyncConnection,
648+
retryable: bool,
641649
) -> None:
642650
if conn.max_wire_version < 25:
643651
raise InvalidOperation(
@@ -648,10 +656,12 @@ async def retryable_bulk(
648656
session,
649657
conn,
650658
op_id,
659+
retryable,
651660
full_result,
652661
)
653662

654663
await self.client._retryable_write(
664+
self.is_retryable,
655665
retryable_bulk,
656666
session,
657667
operation,

pymongo/asynchronous/client_session.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -854,12 +854,13 @@ async def _finish_transaction_with_retry(self, command_name: str) -> dict[str, A
854854
"""
855855

856856
async def func(
857-
_session: Optional[AsyncClientSession],
858-
conn: AsyncConnection,
857+
_session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable: bool
859858
) -> dict[str, Any]:
860859
return await self._finish_transaction(conn, command_name)
861860

862-
return await self._client._retry_internal(func, self, None, operation=_Op.ABORT)
861+
return await self._client._retry_internal(
862+
func, self, None, retryable=True, operation=_Op.ABORT
863+
)
863864

864865
async def _finish_transaction(self, conn: AsyncConnection, command_name: str) -> dict[str, Any]:
865866
self._transaction.attempt += 1

pymongo/asynchronous/collection.py

+22-7
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ async def bulk_write(
785785

786786
write_concern = self._write_concern_for(session)
787787

788-
def process_for_bulk(request: _WriteOp) -> bool:
788+
def process_for_bulk(request: Union[_DocumentType, RawBSONDocument, _WriteOp]) -> bool:
789789
try:
790790
return request._add_to_bulk(blk)
791791
except AttributeError:
@@ -810,27 +810,32 @@ async def _insert_one(
810810
) -> Any:
811811
"""Internal helper for inserting a single document."""
812812
write_concern = write_concern or self.write_concern
813+
acknowledged = write_concern.acknowledged
813814
command = {"insert": self.name, "ordered": ordered, "documents": [doc]}
814815
if comment is not None:
815816
command["comment"] = comment
816817

817818
async def _insert_command(
818-
session: Optional[AsyncClientSession], conn: AsyncConnection
819+
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
819820
) -> None:
820821
if bypass_doc_val is not None:
821822
command["bypassDocumentValidation"] = bypass_doc_val
823+
822824
result = await conn.command(
823825
self._database.name,
824826
command,
825827
write_concern=write_concern,
826828
codec_options=self._write_response_codec_options,
827829
session=session,
828830
client=self._database.client,
831+
retryable_write=retryable_write,
829832
)
830833

831834
_check_write_command_response(result)
832835

833-
await self._database.client._retryable_write(_insert_command, session, operation=_Op.INSERT)
836+
await self._database.client._retryable_write(
837+
acknowledged, _insert_command, session, operation=_Op.INSERT
838+
)
834839

835840
if not isinstance(doc, RawBSONDocument):
836841
return doc.get("_id")
@@ -959,7 +964,7 @@ async def insert_many(
959964
raise TypeError("documents must be a non-empty list")
960965
inserted_ids: list[ObjectId] = []
961966

962-
def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool:
967+
def process_for_bulk(document: Union[_DocumentType, RawBSONDocument, _WriteOp]) -> bool:
963968
"""A generator that validates documents and handles _ids."""
964969
common.validate_is_document_type("document", document)
965970
if not isinstance(document, RawBSONDocument):
@@ -989,6 +994,7 @@ async def _update(
989994
array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
990995
hint: Optional[_IndexKeyHint] = None,
991996
session: Optional[AsyncClientSession] = None,
997+
retryable_write: bool = False,
992998
let: Optional[Mapping[str, Any]] = None,
993999
sort: Optional[Mapping[str, Any]] = None,
9941000
comment: Optional[Any] = None,
@@ -1051,6 +1057,7 @@ async def _update(
10511057
codec_options=self._write_response_codec_options,
10521058
session=session,
10531059
client=self._database.client,
1060+
retryable_write=retryable_write,
10541061
)
10551062
).copy()
10561063
_check_write_command_response(result)
@@ -1090,7 +1097,7 @@ async def _update_retryable(
10901097
"""Internal update / replace helper."""
10911098

10921099
async def _update(
1093-
session: Optional[AsyncClientSession], conn: AsyncConnection
1100+
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
10941101
) -> Optional[Mapping[str, Any]]:
10951102
return await self._update(
10961103
conn,
@@ -1106,12 +1113,14 @@ async def _update(
11061113
array_filters=array_filters,
11071114
hint=hint,
11081115
session=session,
1116+
retryable_write=retryable_write,
11091117
let=let,
11101118
sort=sort,
11111119
comment=comment,
11121120
)
11131121

11141122
return await self._database.client._retryable_write(
1123+
(write_concern or self.write_concern).acknowledged and not multi,
11151124
_update,
11161125
session,
11171126
operation,
@@ -1501,6 +1510,7 @@ async def _delete(
15011510
collation: Optional[_CollationIn] = None,
15021511
hint: Optional[_IndexKeyHint] = None,
15031512
session: Optional[AsyncClientSession] = None,
1513+
retryable_write: bool = False,
15041514
let: Optional[Mapping[str, Any]] = None,
15051515
comment: Optional[Any] = None,
15061516
) -> Mapping[str, Any]:
@@ -1540,6 +1550,7 @@ async def _delete(
15401550
codec_options=self._write_response_codec_options,
15411551
session=session,
15421552
client=self._database.client,
1553+
retryable_write=retryable_write,
15431554
)
15441555
_check_write_command_response(result)
15451556
return result
@@ -1560,7 +1571,7 @@ async def _delete_retryable(
15601571
"""Internal delete helper."""
15611572

15621573
async def _delete(
1563-
session: Optional[AsyncClientSession], conn: AsyncConnection
1574+
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
15641575
) -> Mapping[str, Any]:
15651576
return await self._delete(
15661577
conn,
@@ -1572,11 +1583,13 @@ async def _delete(
15721583
collation=collation,
15731584
hint=hint,
15741585
session=session,
1586+
retryable_write=retryable_write,
15751587
let=let,
15761588
comment=comment,
15771589
)
15781590

15791591
return await self._database.client._retryable_write(
1592+
(write_concern or self.write_concern).acknowledged and not multi,
15801593
_delete,
15811594
session,
15821595
operation=_Op.DELETE,
@@ -3221,7 +3234,7 @@ async def _find_and_modify(
32213234
write_concern = self._write_concern_for_cmd(cmd, session)
32223235

32233236
async def _find_and_modify_helper(
3224-
session: Optional[AsyncClientSession], conn: AsyncConnection
3237+
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
32253238
) -> Any:
32263239
acknowledged = write_concern.acknowledged
32273240
if array_filters is not None:
@@ -3247,13 +3260,15 @@ async def _find_and_modify_helper(
32473260
write_concern=write_concern,
32483261
collation=collation,
32493262
session=session,
3263+
retryable_write=retryable_write,
32503264
user_fields=_FIND_AND_MODIFY_DOC_FIELDS,
32513265
)
32523266
_check_write_command_response(out)
32533267

32543268
return out.get("value")
32553269

32563270
return await self._database.client._retryable_write(
3271+
write_concern.acknowledged,
32573272
_find_and_modify_helper,
32583273
session,
32593274
operation=_Op.FIND_AND_MODIFY,

0 commit comments

Comments
 (0)