|
130 | 130 | from pymongo.server_selectors import Selection, writable_server_selector
|
131 | 131 | from pymongo.server_type import SERVER_TYPE
|
132 | 132 | from pymongo.topology_description import TopologyDescription
|
133 |
| -from pymongo.typings import _Address |
| 133 | +from pymongo.typings import ClusterTime, _Address |
134 | 134 | from pymongo.write_concern import WriteConcern
|
135 | 135 |
|
136 | 136 | JSON_OPTS = json_util.JSONOptions(tz_aware=False)
|
@@ -624,6 +624,18 @@ def get_lsid_for_session(self, session_name):
|
624 | 624 | # session has been closed.
|
625 | 625 | return self._session_lsids[session_name]
|
626 | 626 |
|
| 627 | + def entities(self): |
| 628 | + return self._entities |
| 629 | + |
| 630 | + def advance_cluster_times(self, cluster_time: Optional[ClusterTime] = None): |
| 631 | + if cluster_time is not None: |
| 632 | + self._cluster_time = cluster_time |
| 633 | + elif getattr(self, "_cluster_time", None) is None: |
| 634 | + self._cluster_time = self.test.client.admin.command("ping").get("$clusterTime") |
| 635 | + for entity in self.entities(): |
| 636 | + if isinstance(entity, ClientSession): |
| 637 | + entity.advance_cluster_time(self._cluster_time) |
| 638 | + |
627 | 639 |
|
628 | 640 | binary_types = (Binary, bytes)
|
629 | 641 | long_types = (Int64,)
|
@@ -1511,6 +1523,7 @@ def _testOperation_targetedFailPoint(self, spec):
|
1511 | 1523 |
|
1512 | 1524 | def _testOperation_createEntities(self, spec):
|
1513 | 1525 | self.entity_map.create_entities_from_spec(spec["entities"], uri=self._uri)
|
| 1526 | + self.entity_map.advance_cluster_times() |
1514 | 1527 |
|
1515 | 1528 | def _testOperation_assertSessionTransactionState(self, spec):
|
1516 | 1529 | session = self.entity_map[spec["session"]]
|
@@ -1874,13 +1887,7 @@ def _run_scenario(self, spec, uri=None):
|
1874 | 1887 | # process initialData
|
1875 | 1888 | if "initialData" in self.TEST_SPEC:
|
1876 | 1889 | self.insert_initial_data(self.TEST_SPEC.get("initialData", []))
|
1877 |
| - # advance cluster times of session entities, |
1878 |
| - # to ensure consistency in transactions against a sharded deployment, |
1879 |
| - cluster_time = self.client.admin.command("ping").get("$clusterTime") |
1880 |
| - if cluster_time: |
1881 |
| - for entity in self.entity_map._entities.values(): |
1882 |
| - if isinstance(entity, ClientSession): |
1883 |
| - entity.advance_cluster_time(cluster_time) |
| 1890 | + self.entity_map.advance_cluster_times() |
1884 | 1891 |
|
1885 | 1892 | if "expectLogMessages" in spec:
|
1886 | 1893 | expect_log_messages = spec["expectLogMessages"]
|
|
0 commit comments