Skip to content

Commit 6670461

Browse files
authored
PYTHON-3289 Apply client timeoutMS to every operation (#1011)
1 parent 5c38676 commit 6670461

File tree

5 files changed

+44
-20
lines changed

5 files changed

+44
-20
lines changed

pymongo/_csot.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414

1515
"""Internal helpers for CSOT."""
1616

17+
import functools
1718
import time
1819
from contextvars import ContextVar, Token
19-
from typing import Optional, Tuple
20+
from typing import Any, Callable, Optional, Tuple, TypeVar, cast
2021

2122
TIMEOUT: ContextVar[Optional[float]] = ContextVar("TIMEOUT", default=None)
2223
RTT: ContextVar[float] = ContextVar("RTT", default=0.0)
@@ -83,3 +84,22 @@ def __exit__(self, exc_type, exc_val, exc_tb):
8384
TIMEOUT.reset(timeout_token)
8485
DEADLINE.reset(deadline_token)
8586
RTT.reset(rtt_token)
87+
88+
89+
# See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
90+
F = TypeVar("F", bound=Callable[..., Any])
91+
92+
93+
def apply(func: F) -> F:
94+
"""Apply the client's timeoutMS to this operation."""
95+
96+
@functools.wraps(func)
97+
def csot_wrapper(self, *args, **kwargs):
98+
if get_timeout() is None:
99+
timeout = self._timeout
100+
if timeout is not None:
101+
with _TimeoutContext(timeout):
102+
return func(self, *args, **kwargs)
103+
return func(self, *args, **kwargs)
104+
105+
return cast(F, csot_wrapper)

pymongo/collection.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from bson.raw_bson import RawBSONDocument
3636
from bson.son import SON
3737
from bson.timestamp import Timestamp
38-
from pymongo import ASCENDING, common, helpers, message
38+
from pymongo import ASCENDING, _csot, common, helpers, message
3939
from pymongo.aggregation import (
4040
_CollectionAggregationCommand,
4141
_CollectionRawAggregationCommand,
@@ -217,6 +217,10 @@ def __init__(
217217
self.__database: Database[_DocumentType] = database
218218
self.__name = name
219219
self.__full_name = "%s.%s" % (self.__database.name, self.__name)
220+
self.__write_response_codec_options = self.codec_options._replace(
221+
unicode_decode_error_handler="replace", document_class=dict
222+
)
223+
self._timeout = database.client.options.timeout
220224
encrypted_fields = kwargs.pop("encryptedFields", None)
221225
if create or kwargs or collation:
222226
if encrypted_fields:
@@ -230,10 +234,6 @@ def __init__(
230234
else:
231235
self.__create(name, kwargs, collation, session)
232236

233-
self.__write_response_codec_options = self.codec_options._replace(
234-
unicode_decode_error_handler="replace", document_class=dict
235-
)
236-
237237
def _socket_for_reads(self, session):
238238
return self.__database.client._socket_for_reads(self._read_preference_for(session), session)
239239

@@ -433,6 +433,7 @@ def with_options(
433433
read_concern or self.read_concern,
434434
)
435435

436+
@_csot.apply
436437
def bulk_write(
437438
self,
438439
requests: Sequence[_WriteOp],
@@ -631,6 +632,7 @@ def insert_one(
631632
write_concern.acknowledged,
632633
)
633634

635+
@_csot.apply
634636
def insert_many(
635637
self,
636638
documents: Iterable[_DocumentIn],
@@ -1892,6 +1894,7 @@ def create_indexes(
18921894
kwargs["comment"] = comment
18931895
return self.__create_indexes(indexes, session, **kwargs)
18941896

1897+
@_csot.apply
18951898
def __create_indexes(self, indexes, session, **kwargs):
18961899
"""Internal createIndexes helper.
18971900
@@ -2088,6 +2091,7 @@ def drop_indexes(
20882091
kwargs["comment"] = comment
20892092
self.drop_index("*", session=session, **kwargs)
20902093

2094+
@_csot.apply
20912095
def drop_index(
20922096
self,
20932097
index_or_name: _IndexKeyHint,
@@ -2311,6 +2315,7 @@ def options(
23112315

23122316
return options
23132317

2318+
@_csot.apply
23142319
def _aggregate(
23152320
self,
23162321
aggregation_command,
@@ -2618,6 +2623,7 @@ def watch(
26182623
full_document_before_change,
26192624
)
26202625

2626+
@_csot.apply
26212627
def rename(
26222628
self,
26232629
new_name: str,

pymongo/database.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from bson.dbref import DBRef
3434
from bson.son import SON
3535
from bson.timestamp import Timestamp
36-
from pymongo import common
36+
from pymongo import _csot, common
3737
from pymongo.aggregation import _DatabaseAggregationCommand
3838
from pymongo.change_stream import DatabaseChangeStream
3939
from pymongo.collection import Collection
@@ -138,6 +138,7 @@ def __init__(
138138

139139
self.__name = name
140140
self.__client: MongoClient[_DocumentType] = client
141+
self._timeout = client.options.timeout
141142

142143
@property
143144
def client(self) -> "MongoClient[_DocumentType]":
@@ -290,6 +291,7 @@ def get_collection(
290291
read_concern,
291292
)
292293

294+
@_csot.apply
293295
def create_collection(
294296
self,
295297
name: str,
@@ -690,6 +692,7 @@ def _command(
690692
client=self.__client,
691693
)
692694

695+
@_csot.apply
693696
def command(
694697
self,
695698
command: Union[str, MutableMapping[str, Any]],
@@ -964,6 +967,7 @@ def _drop_helper(self, name, session=None, comment=None):
964967
session=session,
965968
)
966969

970+
@_csot.apply
967971
def drop_collection(
968972
self,
969973
name_or_collection: Union[str, Collection],

pymongo/mongo_client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,7 @@ def target():
838838
from pymongo.encryption import _Encrypter
839839

840840
self._encrypter = _Encrypter(self, self.__options.auto_encryption_opts)
841+
self._timeout = options.timeout
841842

842843
def _duplicate(self, **kwargs):
843844
args = self.__init_kwargs.copy()
@@ -1270,6 +1271,7 @@ def _socket_for_reads(self, read_preference, session):
12701271
def _should_pin_cursor(self, session):
12711272
return self.__options.load_balanced and not (session and session.in_transaction)
12721273

1274+
@_csot.apply
12731275
def _run_operation(self, operation, unpack_res, address=None):
12741276
"""Run a _Query/_GetMore operation and return a Response.
12751277
@@ -1318,6 +1320,7 @@ def _retry_with_session(self, retryable, func, session, bulk):
13181320
)
13191321
return self._retry_internal(retryable, func, session, bulk)
13201322

1323+
@_csot.apply
13211324
def _retry_internal(self, retryable, func, session, bulk):
13221325
"""Internal retryable write helper."""
13231326
max_wire_version = 0
@@ -1384,6 +1387,7 @@ def is_retrying():
13841387
retrying = True
13851388
last_error = exc
13861389

1390+
@_csot.apply
13871391
def _retryable_read(self, func, read_pref, session, address=None, retryable=True):
13881392
"""Execute an operation with at most one consecutive retries
13891393
@@ -1834,6 +1838,7 @@ def list_database_names(
18341838
"""
18351839
return [doc["name"] for doc in self.list_databases(session, nameOnly=True, comment=comment)]
18361840

1841+
@_csot.apply
18371842
def drop_database(
18381843
self,
18391844
name_or_database: Union[str, database.Database],

test/unified_format.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,27 +1140,20 @@ def run_entity_operation(self, spec):
11401140

11411141
if isinstance(target, MongoClient):
11421142
method_name = "_clientOperation_%s" % (opname,)
1143-
client = target
11441143
elif isinstance(target, Database):
11451144
method_name = "_databaseOperation_%s" % (opname,)
1146-
client = target.client
11471145
elif isinstance(target, Collection):
11481146
method_name = "_collectionOperation_%s" % (opname,)
1149-
client = target.database.client
11501147
elif isinstance(target, ChangeStream):
11511148
method_name = "_changeStreamOperation_%s" % (opname,)
1152-
client = target._client
11531149
elif isinstance(target, NonLazyCursor):
11541150
method_name = "_cursor_%s" % (opname,)
1155-
client = target.client
11561151
elif isinstance(target, ClientSession):
11571152
method_name = "_sessionOperation_%s" % (opname,)
1158-
client = target._client
11591153
elif isinstance(target, GridFSBucket):
11601154
raise NotImplementedError
11611155
elif isinstance(target, ClientEncryption):
11621156
method_name = "_clientEncryptionOperation_%s" % (opname,)
1163-
client = target._key_vault_client
11641157
else:
11651158
method_name = "doesNotExist"
11661159

@@ -1175,13 +1168,9 @@ def run_entity_operation(self, spec):
11751168
cmd = functools.partial(method, target)
11761169

11771170
try:
1178-
# TODO: PYTHON-3289 apply inherited timeout by default.
1179-
inherit_timeout = client.options.timeout
11801171
# CSOT: Translate the spec test "timeout" arg into pymongo's context timeout API.
1181-
if "timeout" in arguments or inherit_timeout is not None:
1182-
timeout = arguments.pop("timeout", None)
1183-
if timeout is None:
1184-
timeout = inherit_timeout
1172+
if "timeout" in arguments:
1173+
timeout = arguments.pop("timeout")
11851174
with pymongo.timeout(timeout):
11861175
result = cmd(**dict(arguments))
11871176
else:

0 commit comments

Comments
 (0)