From 5dc0edcbd131d4eaaf330dcd182cb38ebcb4eef4 Mon Sep 17 00:00:00 2001 From: Dave Evans Date: Tue, 28 Apr 2026 11:37:58 +0100 Subject: [PATCH 01/10] Add an integration test for our sort tiebreaking logic We exercise this a bit in the spec tests, but the constraints of the spec tests (in particular the single result column) make it hard to do so fully. --- tests/integration/test_query_engines.py | 41 +++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/integration/test_query_engines.py b/tests/integration/test_query_engines.py index 733e52f79..d6bc9069a 100644 --- a/tests/integration/test_query_engines.py +++ b/tests/integration/test_query_engines.py @@ -547,6 +547,47 @@ def test_sql_logging(engine, caplog): assert counts[r] > 0, f"No logs matching {r!r}" +def test_sort_tiebreaker_semantics(engine): + @table + class events(EventFrame): + a = Series(int) + b = Series(int) + c = Series(int) + d = Series(int) + + engine.populate( + { + events: [ + # Check we get the first row ordered by `a` + {"patient_id": 1, "a": 1, "b": 0, "c": 3, "d": 4}, + {"patient_id": 1, "a": 0, "b": 0, "c": 5, "d": 6}, + # When multiple rows are tied for first place, check that we sort by `c` + # and then `d`, in that specific order + {"patient_id": 2, "a": 0, "b": 0, "c": 3, "d": 2}, + {"patient_id": 2, "a": 0, "b": 0, "c": 2, "d": 5}, + {"patient_id": 2, "a": 0, "b": 0, "c": 2, "d": 4}, + # Check that we don't sort by `b`: even though it's the lexically + # smallest column it doesn't appear in our query and so it shouldn't be + # used + {"patient_id": 3, "a": 0, "b": 0, "c": 2, "d": 2}, + {"patient_id": 3, "a": 0, "b": 1, "c": 1, "d": 1}, + ] + } + ) + dataset = create_dataset() + dataset.define_population(events.exists_for_patient()) + # Sort by `a` and then use columns `c` and `d` but ignore `b` + first_by_a = events.sort_by(events.a).first_for_patient() + dataset.c = first_by_a.c + dataset.d = first_by_a.d + + assert engine.extract(dataset) == [ + {"patient_id": 
1, "c": 5, "d": 6}, + {"patient_id": 2, "c": 2, "d": 4}, + {"patient_id": 3, "c": 1, "d": 1}, + ] + + # The fix for this turns out to be not straightforward and it's sufficiently edge-case-y # that it doesn't affect us in practice. So for now we keep the test in place but # xfailed. From 64e09503d3be37c87ee81c37d5734196cb1c9e1d Mon Sep 17 00:00:00 2001 From: Dave Evans Date: Tue, 28 Apr 2026 11:51:47 +0100 Subject: [PATCH 02/10] Add an integration test for boolean sorts At present we're forced to do the right thing here by the query model validation rules because we handle the tiebreaking logic by rewriting the query itself. But we may not always do that and so we want an integration test to enforce that this works. --- tests/integration/test_query_engines.py | 37 +++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/integration/test_query_engines.py b/tests/integration/test_query_engines.py index d6bc9069a..9c713aa71 100644 --- a/tests/integration/test_query_engines.py +++ b/tests/integration/test_query_engines.py @@ -588,6 +588,43 @@ class events(EventFrame): ] +def test_implicit_sort_on_boolean(engine): + # We don't allow explicit sorting on booleans in ehrQL, but sometimes our + # tiebreaking semantics requires this. We want to ensure all engines return the same + # ordering. 
+ + @table + class events(EventFrame): + a = Series(int) + b = Series(bool) + + engine.populate( + { + events: [ + {"patient_id": 1, "a": 0, "b": True}, + {"patient_id": 1, "a": 0, "b": False}, + {"patient_id": 1, "a": 0, "b": None}, + {"patient_id": 2, "a": 0, "b": True}, + {"patient_id": 2, "a": 0, "b": False}, + {"patient_id": 2, "a": 1, "b": None}, + {"patient_id": 3, "a": 0, "b": True}, + {"patient_id": 3, "a": 1, "b": False}, + {"patient_id": 3, "a": 1, "b": None}, + ] + } + ) + dataset = create_dataset() + dataset.define_population(events.exists_for_patient()) + first_by_a = events.sort_by(events.a).first_for_patient() + dataset.b = first_by_a.b + + assert engine.extract(dataset) == [ + {"patient_id": 1, "b": None}, + {"patient_id": 2, "b": False}, + {"patient_id": 3, "b": True}, + ] + + # The fix for this turns out to be not straightforward and it's sufficiently edge-case-y # that it doesn't affect us in practice. So for now we keep the test in place but # xfailed. From ada7249512d3194e4e65d0dc298f4c1244c4f511 Mon Sep 17 00:00:00 2001 From: Dave Evans Date: Tue, 28 Apr 2026 13:33:55 +0100 Subject: [PATCH 03/10] Split out sort rewrite and optimisation query transforms --- ehrql/query_model/transforms.py | 62 ++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/ehrql/query_model/transforms.py b/ehrql/query_model/transforms.py index 72a33cf2a..a34c096d4 100644 --- a/ehrql/query_model/transforms.py +++ b/ehrql/query_model/transforms.py @@ -72,37 +72,42 @@ class Coalesce(Series[T]): def apply_transforms(root_node, skip_optimizations=False): - # Note that we're currently sharing `rewriter`, `nodes` and `reverse_index` across - # transforms. While we only have one this is obviously fine! It _might_ be OK as we - # add more depending on whether they're commutative but we should be careful here - # and might decide we want to restructure things to keep the transforms independent. 
+ root_node = apply_sort_rewrites(root_node) + if not skip_optimizations: + root_node = apply_optimizations(root_node) + return root_node + + +def apply_sort_rewrites(root_node): + # This transform is required for ehrQL's sorting semantics to be respected nodes = all_unique_nodes(root_node) reverse_index = build_reverse_index(nodes) + rewriter = QueryGraphRewriter() + for node in nodes: + if isinstance(node, PickOneRowPerPatient): + rewrite_sorts(rewriter, node, reverse_index) + return rewriter.rewrite(root_node) - # This transform is required for ehrQL's sorting semantics to be respected + +def apply_optimizations(root_node): + # These transforms should not affect behaviour but are just performance + # improvements transforms = [ - (PickOneRowPerPatient, rewrite_sorts), + rewrite_case_to_fixed_value_map, + rewrite_case_to_coalesce, ] - # These transforms should not affect behaviour but are just performance - # improvements. For testing purposes we want to be able to disable them. - if not skip_optimizations: - transforms.extend( - [ - (Case, specialize_case_operations), - ] - ) rewriter = QueryGraphRewriter() - for type_, transform in transforms: - apply_transform(rewriter, type_, transform, nodes, reverse_index) - - return rewriter.rewrite(root_node) + for node in all_unique_nodes(root_node): + original = node + for transform in transforms: + if result := transform(node): + node = result + if node is not original: + rewriter.replace(original, node) -def apply_transform(rewriter, type_, transform, nodes, reverse_index): - for node in nodes: - if isinstance(node, type_): - transform(rewriter, node, reverse_index) + return rewriter.rewrite(root_node) def replace_nodes(root_node, replacements): @@ -231,18 +236,14 @@ def make_sortable(col): return col -def specialize_case_operations(rewriter, node, reverse_index): - if replacement := rewrite_case_to_fixed_value_map(node): - rewriter.replace(node, replacement) - elif replacement := rewrite_case_to_coalesce(node): - 
rewriter.replace(node, replacement) - - def rewrite_case_to_fixed_value_map(node): """ If the supplied Case operation can be represented as a FixedValueMap then return that representation, otherwise return None """ + if not isinstance(node, Case): + return + source = MISSING = object() mapping = {} @@ -299,6 +300,9 @@ def rewrite_case_to_coalesce(node): If the supplied Case operation can be represented as a Coalesce then return that representation, otherwise return None """ + if not isinstance(node, Case): + return + sources = [] # We're looking for Case operations where every case is of the form: From 134f809580f387a58c5306e6432ce88d16797c34 Mon Sep 17 00:00:00 2001 From: Dave Evans Date: Tue, 28 Apr 2026 14:25:11 +0100 Subject: [PATCH 04/10] Add extended comment explaining sort behaviour This is captured in discussion on issues, and explained a bit in comments elsewhere, but I don't think this was explained properly in one place in the source code before. --- ehrql/query_model/transforms.py | 55 ++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/ehrql/query_model/transforms.py b/ehrql/query_model/transforms.py index a34c096d4..a322754e6 100644 --- a/ehrql/query_model/transforms.py +++ b/ehrql/query_model/transforms.py @@ -79,7 +79,60 @@ def apply_transforms(root_node, skip_optimizations=False): def apply_sort_rewrites(root_node): - # This transform is required for ehrQL's sorting semantics to be respected + """ + Sorting rows and then picking the first or last row for each patient is a common + operation in ehrQL but it's responsible for a slightly weird corner of ehrQL's + semantics. This is because it's easy to have "under-specified" results e.g. you sort + some events by date and pick the first but a patient has multiple events recorded on + that day. + + In that case, there's no one correct answer as to what row should be returned and + this creates two related problems: + + * Some databases (e.g. 
MSSQL) pick randomly (or effectively randomly) meaning you + can run the same query against the same data and get different results each time. + This is confusing and generally bad for research (and not in a theoretical + sense: we've seen this actually happen) so we want to avoid it. + + * Even those databases which return consistent results each time don't necessarily + return the same results as each other. This prevents our automated generative + testing, which relies on comparing results between databases, from working. + + What we want is to ensure that even in the case of under-specified sorts there is a + single correct result defined by ehrQL. But we want to do this without imposing a + significant performance cost on all our sort queries. + + We do this by defining "tiebreaker sorts" for selecting a winning row in the case of + multiple equal candidates. That is, if the rows are equally positioned when sorting + by all the things the user has specified then sort by these other conditions as + well. + + One simple solution to this would be to sort by every column in the table in lexical + order. That would guarantee a single stable result. Of course there might still be + completely duplicate rows, but in that case it doesn't matter which you pick because + the results are necessarily identical. + + The problem with this solution is that we can have very wide tables with many + columns and now every time we sort we need to specify all of these as tiebreaker + conditions. So it fails our "don't impose significant performance cost" condition. + + Another solution is to use just the columns we're actually going to select from the + results as the tiebreaker conditions. Of course, this doesn't guarantee uniqueness + of rows: if we select columns A and B from a table we might have multiple rows with + the same values for A and B but a different value for C. 
But it does guarantee + uniqueness of _results_: because we're not selecting column C it doesn't matter + what's in it. + + This does exactly what we need, however it introduces an oddity into ehrQL's + semantics which is that you can no longer evaluate it in a bottom-up fashion. The + value of a pick-first-row operation depends on what columns are _going to be_ + selected from that row. + + This means we need to do some pre-processing of the query graph and annotate each + such operation with the set of columns that are selected from it elsewhere in the + query. It would be nicer not to have to do this, but given the above constraints I + think it's the best practical solution. + """ nodes = all_unique_nodes(root_node) reverse_index = build_reverse_index(nodes) rewriter = QueryGraphRewriter() From d5da8a1d5abdab4a1c9fd366c4c203698627ac02 Mon Sep 17 00:00:00 2001 From: Dave Evans Date: Tue, 28 Apr 2026 14:37:34 +0100 Subject: [PATCH 05/10] Keep track of sort index in the in-memory database This will allow us to resolve ties after other sorts have already been applied which means we can do it without having to rewrite the query. --- ehrql/query_engines/in_memory_database.py | 32 ++++++++++++++++------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/ehrql/query_engines/in_memory_database.py b/ehrql/query_engines/in_memory_database.py index 9efe0dabd..f6a00de9b 100644 --- a/ehrql/query_engines/in_memory_database.py +++ b/ehrql/query_engines/in_memory_database.py @@ -382,6 +382,12 @@ class Rows(dict): values belonging to a single patient in an EventColumn. """ + # Keep track of any sorts which have been applied so we can identify where a given + # row's position is due to its explicit sort order or where it's just arbitrary. + # This allows us to supply tiebreaker sorts _after_ other sorts have already been + # applied. 
+ _sort_index = None + def __repr__(self): return f"Rows({super().__repr__()})" @@ -412,7 +418,9 @@ def filter(self, predicate): # noqa A003 if not isinstance(predicate, Rows): # This branch is hit when an EventSeries is filtered by a literal boolean. predicate = Rows({k: predicate for k in self}) - return Rows({k: v for k, v in self.items() if predicate[k]}) + rows = Rows({k: v for k, v in self.items() if predicate[k]}) + rows._sort_index = self._sort_index + return rows def sort_index(self): """Map each value to its ordinal position in set of unique values. @@ -421,26 +429,32 @@ def sort_index(self): resulting sort_index will overspecify the order and we lose the stability of the sort operation. """ - sorted_values = sorted(set(self.values()), key=nulls_first_order) return Rows({k: sorted_values.index(v) for k, v in self.items()}) def sort(self, sort_index): """Sort rows by position in sort_index. - If two values have the same position, their current position is used as a - tiebreaker. This ensures that sorting is stable. + If two values have the same position, their current sort index is used as a + tiebreaker. This ensures that sorting is stable. 
""" + if self._sort_index is not None: + # Create a combined sort index where we sort primarily on the supplied index + # and use the existing index to break any ties + combined_index = { + k: (sort_index[k], self._sort_index[k]) for k in self.keys() + } + sorted_values = sorted(set(combined_index.values())) + sort_index = {k: sorted_values.index(v) for k, v in combined_index.items()} - return Rows( + rows = Rows( { k: v - for (_, _, k, v) in sorted( - (sort_index[k], tiebreaker, k, v) - for tiebreaker, (k, v) in enumerate(self.items()) - ) + for (_, k, v) in sorted((sort_index[k], k, v) for k, v in self.items()) } ) + rows._sort_index = sort_index + return rows def pick_at_index(self, ix): """Return element at given position.""" From 1f6600c71965d4dd9cb92699548deb821832b205 Mon Sep 17 00:00:00 2001 From: Dave Evans Date: Tue, 28 Apr 2026 14:39:05 +0100 Subject: [PATCH 06/10] Apply tiebreaker sorts in the in-memory query engine --- ehrql/query_engines/in_memory.py | 19 +++++++++---- ehrql/query_engines/in_memory_database.py | 34 ++++++++++++++++------- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/ehrql/query_engines/in_memory.py b/ehrql/query_engines/in_memory.py index e10b2eb77..19d7e28e1 100644 --- a/ehrql/query_engines/in_memory.py +++ b/ehrql/query_engines/in_memory.py @@ -191,14 +191,23 @@ def visit_Sort(self, node): return source.sort(sort_column.sort_index()) def visit_PickOneRowPerPatient(self, node): - ix = { - qm.Position.FIRST: 0, - qm.Position.LAST: -1, - }[node.position] + ix = self.index_for_position(node.position) return self.visit(node.source).pick_at_index(ix) def visit_PickOneRowPerPatientWithColumns(self, node): - return self.visit_PickOneRowPerPatient(node) + source = self.visit(node.source) + # Ensure a unique deterministic result in the case of any ties + # See: ehrql.query_model.transforms.apply_sort_rewrites() + selected_columns = [c.name for c in node.selected_columns] + for column in sorted(selected_columns): + 
tiebreaker_sort = source[column].sort_index() + source = source.sort(tiebreaker_sort, tiebreak_only=True) + + ix = self.index_for_position(node.position) + return source.pick_at_index(ix) + + def index_for_position(self, position): + return {qm.Position.FIRST: 0, qm.Position.LAST: -1}[position] def visit_Exists(self, node): return self.visit(node.source).exists() diff --git a/ehrql/query_engines/in_memory_database.py b/ehrql/query_engines/in_memory_database.py index f6a00de9b..5d2f307c8 100644 --- a/ehrql/query_engines/in_memory_database.py +++ b/ehrql/query_engines/in_memory_database.py @@ -206,9 +206,12 @@ def filter(self, predicate): # noqa A003 {name: col.filter(predicate) for name, col in self.name_to_col.items()} ) - def sort(self, sort_index): + def sort(self, sort_index, tiebreak_only=False): return EventTable( - {name: col.sort(sort_index) for name, col in self.name_to_col.items()} + { + name: col.sort(sort_index, tiebreak_only=tiebreak_only) + for name, col in self.name_to_col.items() + } ) def pick_at_index(self, ix): @@ -356,9 +359,12 @@ def sort_index(self): {p: rows.sort_index() for p, rows in self.patient_to_rows.items()} ) - def sort(self, sort_index): + def sort(self, sort_index, tiebreak_only=False): return EventColumn( - {p: rows.sort(sort_index[p]) for p, rows in self.patient_to_rows.items()} + { + p: rows.sort(sort_index[p], tiebreak_only=tiebreak_only) + for p, rows in self.patient_to_rows.items() + } ) def pick_at_index(self, ix): @@ -432,18 +438,26 @@ def sort_index(self): sorted_values = sorted(set(self.values()), key=nulls_first_order) return Rows({k: sorted_values.index(v) for k, v in self.items()}) - def sort(self, sort_index): + def sort(self, sort_index, tiebreak_only=False): """Sort rows by position in sort_index. If two values have the same position, their current sort index is used as a tiebreaker. This ensures that sorting is stable. 
""" if self._sort_index is not None: - # Create a combined sort index where we sort primarily on the supplied index - # and use the existing index to break any ties - combined_index = { - k: (sort_index[k], self._sort_index[k]) for k in self.keys() - } + if not tiebreak_only: + # The standard case: create a combined sort index where we sort + # primarily on the supplied index and use the existing index to break + # any ties + combined_index = { + k: (sort_index[k], self._sort_index[k]) for k in self.keys() + } + else: + # The special tiebreak case: we sort primarily on the existing index to + # retain the existing order and only use the new index to break any ties + combined_index = { + k: (self._sort_index[k], sort_index[k]) for k in self.keys() + } sorted_values = sorted(set(combined_index.values())) sort_index = {k: sorted_values.index(v) for k, v in combined_index.items()} From 0e495ddc7f96ce8b2ff0a628d251bd84007de894 Mon Sep 17 00:00:00 2001 From: Dave Evans Date: Wed, 29 Apr 2026 08:42:28 +0100 Subject: [PATCH 07/10] Apply tiebreaker sorts in the SQL query engines --- ehrql/query_engines/base_sql.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ehrql/query_engines/base_sql.py b/ehrql/query_engines/base_sql.py index d4552fa0c..ec8557ee2 100644 --- a/ehrql/query_engines/base_sql.py +++ b/ehrql/query_engines/base_sql.py @@ -940,8 +940,13 @@ def get_table_sort_and_filter(self, node): @get_table.register(PickOneRowPerPatientWithColumns) def get_table_pick_one_row_per_patient(self, node): selected_columns = [self.get_expr(c) for c in node.selected_columns] + + sort_conditions = get_sort_conditions(node.source) + # Ensure a unique deterministic result in the case of any ties + # See: ehrql.query_model.transforms.apply_sort_rewrites() + tiebreakers = sorted(node.selected_columns, key=lambda c: c.name) order_clauses = self.get_order_clauses( - get_sort_conditions(node.source), node.position + sort_conditions + tiebreakers, node.position 
) query = self.get_select_query_for_node_domain(node.source) @@ -1278,12 +1283,12 @@ def get_table_and_filter_conditions(frame): def get_sort_conditions(frame): """ - Given a sorted frame, return a tuple of Series which gives the sort order + Given a sorted frame, return a list of Series which gives the sort order """ # Sort operations are given to us in order of application which is the reverse of # order of priority (i.e. the most recently applied sort gives us the primary sort # condition) so we reverse them here - return tuple(s.sort_by for s in reversed(get_sorts(frame))) + return [s.sort_by for s in reversed(get_sorts(frame))] def get_cyclic_coalescence(columns): From f31c981e6baaa812967f7815f0309a4c5c1ed9cb Mon Sep 17 00:00:00 2001 From: Dave Evans Date: Thu, 30 Apr 2026 13:41:13 +0100 Subject: [PATCH 08/10] Remove redundant order clauses from SQL --- ehrql/query_engines/base_sql.py | 23 +++++++++++++++++++ tests/integration/test_query_engines.py | 30 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/ehrql/query_engines/base_sql.py b/ehrql/query_engines/base_sql.py index ec8557ee2..bcf57af97 100644 --- a/ehrql/query_engines/base_sql.py +++ b/ehrql/query_engines/base_sql.py @@ -948,6 +948,8 @@ def get_table_pick_one_row_per_patient(self, node): order_clauses = self.get_order_clauses( sort_conditions + tiebreakers, node.position ) + # Some tiebreakers may already be included in the sort conditions + order_clauses = remove_redundant_order_clauses(order_clauses) query = self.get_select_query_for_node_domain(node.source) query = query.add_columns(*selected_columns) @@ -1291,6 +1293,27 @@ def get_sort_conditions(frame): return [s.sort_by for s in reversed(get_sorts(frame))] +def remove_redundant_order_clauses(clauses): + """ + In a list of clauses like: + + ORDER BY a, b, a, c + + The second `a` is redundant and cannot affect the resulting order + """ + seen = set() + result = [] + for clause in clauses: + # We can't use equality to compare 
SQLAlchemy elements because it's overloaded. + # There is a `compare()` method we can use to determine structural equivalence, + # but compiling to a string is simpler and sufficient for our purposes. + clause_str = str(clause) + if clause_str not in seen: + result.append(clause) + seen.add(clause_str) + return result + + def get_cyclic_coalescence(columns): """ Given a list of columns, this produces a list of coalescences of all columns with the diff --git a/tests/integration/test_query_engines.py b/tests/integration/test_query_engines.py index 9c713aa71..3fde56554 100644 --- a/tests/integration/test_query_engines.py +++ b/tests/integration/test_query_engines.py @@ -677,6 +677,36 @@ class events(EventFrame): ] +def test_remove_redundant_order_clauses(engine): + if engine.name == "in_memory": + pytest.skip("test does not apply to in-memory engine") + + @table + class events(EventFrame): + col_a = Series(int) + col_b = Series(int) + + dataset = create_dataset() + dataset.define_population(events.exists_for_patient()) + # We sort by columns A and B and also use them in our results. This is the situation + # which can lead to redundant order clauses. 
+ first_row = events.sort_by(events.col_a, events.col_b).first_for_patient() + dataset.col_a = first_row.col_a + dataset.col_b = first_row.col_b + + queries = engine.dump_dataset_sql(dataset) + + partition_clauses = [ + match[0] for q in queries if (match := re.search(r"\(PARTITION BY .+\)", q)) + ] + assert len(partition_clauses) == 1 + partition_clause = partition_clauses[0] + + # Check that we only reference each column once + assert partition_clause.count("col_a") == 1 + assert partition_clause.count("col_b") == 1 + + def build_dataset(*, population, variables=None, events=None): return Dataset( population=population, From 20fa13fe81be6c21758a1a6ca296e05e847fe168 Mon Sep 17 00:00:00 2001 From: Dave Evans Date: Thu, 30 Apr 2026 18:33:39 +0100 Subject: [PATCH 09/10] Delete unused code --- ehrql/query_model/transforms.py | 100 -------- tests/unit/query_model/test_transforms.py | 279 ---------------------- 2 files changed, 379 deletions(-) diff --git a/ehrql/query_model/transforms.py b/ehrql/query_model/transforms.py index a322754e6..a660fda63 100644 --- a/ehrql/query_model/transforms.py +++ b/ehrql/query_model/transforms.py @@ -26,12 +26,8 @@ PickOneRowPerPatient, SelectColumn, Series, - Sort, Value, get_input_nodes, - get_series_type, - get_sorts, - has_one_row_per_patient, ) from ehrql.query_model.query_graph_rewriter import QueryGraphRewriter @@ -176,61 +172,10 @@ def replace_nodes(root_node, replacements): def rewrite_sorts(rewriter, node, reverse_index): - """ - Frames are sorted in order to then pick the first or last row for a patient. Multiple sorts - may be applied to give the desired results. Once a single row has been picked, one or more - columns are then selected. - - This results in a subgraph of QM objects like this: - - SelectColumn(A) -+ - | - SelectColumn(B) -+-> PickOneRowPerPatient -> Sort(A) -> Sort(B) -> SelectTable - | - SelectColumn(C) -+ - - There are two transformations that we need to carry out on this stack. - - 1. 
We annotate PickOneRowPerPatient with the columns that are going to be selected from it, in - order to allow us to generate the appropriate query more easily. - 2. Add sorts so that we have one for each column that will be selected, in order to ensure that - the sort order (and hence the values of the selected columns) is deterministic. - - For the example above the resulting subgraph would be: - - SelectColumn(A) -+ - | - SelectColumn(B) -+-> PickOneRowPerPatientWithColumns -> Sort(A) -> Sort(B) -> Sort(C) -> SelectTable - | - SelectColumn(C) -+ - - Some notes on the additional sorts are in order. - - * A potential lack of determinism in sort order creeps in when a patient has multiple rows with - the same value of the column(s) being sorted on. In this case some databases may give different - orders on different runs of the same query against the same data. - * When this lack of determinism exists, and we select a column that has not been sorted on, the - value returned may change between runs. - * We add sorts only for columns that don't already have them. Duplicate sorts wouldn't cause a - problem, but are conceptually messy and might have a small performance impact. - * We add the sorts below the existing ones so that they have lower priority and are only used to - break any ties in the user-specified sorts. - * When considering the existing sorts we only attend to those that sort directly on selected - columns, not on expressions derived from a column. Such expressions may not be injective and so - may not be sufficient to fully determine the order. As above, duplicates are not a problem. - * We introduce an arbitrary order for the additional sorts (lexically by column name) to ensure - that their order itself is deterministic. - """ # What columns are select from this patient frame? 
selected_column_names = { c.name for c in reverse_index[node] if isinstance(c, SelectColumn) } - - add_columns_to_pick(rewriter, node, selected_column_names) - add_extra_sorts(rewriter, node, selected_column_names) - - -def add_columns_to_pick(rewriter, node, selected_column_names): selected_columns = frozenset( SelectColumn(node.source, c) for c in selected_column_names ) @@ -244,51 +189,6 @@ def add_columns_to_pick(rewriter, node, selected_column_names): ) -def add_extra_sorts(rewriter, node, selected_column_names): - all_sorts = get_sorts(node.source) - # Add at the bottom of the stack - lowest_sort = all_sorts[0] - - for column in calculate_sorts_to_add(all_sorts, selected_column_names): - new_sort = Sort( - source=lowest_sort.source, - sort_by=make_sortable(SelectColumn(lowest_sort.source, column)), - ) - rewriter.replace( - lowest_sort, - Sort( - source=new_sort, - sort_by=lowest_sort.sort_by, - ), - ) - lowest_sort = new_sort - - -def calculate_sorts_to_add(all_sorts, selected_column_names): - # Don't duplicate existing direct sorts - direct_sorts = [ - sort - for sort in all_sorts - if isinstance(sort.sort_by, SelectColumn) - # SelectColumn operations only count as direct sorts if they're selected from - # the frame we're sorting, not from some other patient frame - and not has_one_row_per_patient(sort.sort_by.source) - ] - existing_sorted_column_names = {sort.sort_by.name for sort in direct_sorts} - sorts_to_add = selected_column_names - existing_sorted_column_names - - # Arbitrary canonical ordering - return sorted(sorts_to_add) - - -def make_sortable(col): - if get_series_type(col) is bool: - # Some databases can't sort booleans (including SQL Server), so we cast them to - # integers - return Function.CastToInt(col) - return col - - def rewrite_case_to_fixed_value_map(node): """ If the supplied Case operation can be represented as a FixedValueMap then return diff --git a/tests/unit/query_model/test_transforms.py 
b/tests/unit/query_model/test_transforms.py index dff3900f5..77ae6ba78 100644 --- a/tests/unit/query_model/test_transforms.py +++ b/tests/unit/query_model/test_transforms.py @@ -6,7 +6,6 @@ Case, Column, Dataset, - Filter, Function, Parameter, PickOneRowPerPatient, @@ -81,284 +80,6 @@ def test_pick_one_row_per_patient_transform(): assert transformed.variables == expected -def test_adds_one_selected_column_to_sorts(): - events = SelectTable( - "events", - TableSchema(i1=Column(int), i2=Column(int)), - ) - by_i1 = Sort(events, SelectColumn(events, "i1")) - variable = SelectColumn( - PickOneRowPerPatient(source=by_i1, position=Position.FIRST), - "i2", - ) - - by_i2 = Sort(events, SelectColumn(events, "i2")) - by_i2_then_i1 = Sort(by_i2, SelectColumn(events, "i1")) - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_i2_then_i1, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_i2_then_i1, - name="i2", - ), - } - ), - ), - "i2", - ) - - assert apply_transforms(variable) == expected - - -def test_adds_sorts_at_lowest_priority(): - events = SelectTable( - "events", - TableSchema(i1=Column(int), i2=Column(int), i3=Column(int)), - ) - by_i2 = Sort(events, SelectColumn(events, "i2")) - by_i2_then_i1 = Sort(by_i2, SelectColumn(by_i2, "i1")) - variable = SelectColumn( - PickOneRowPerPatient(source=by_i2_then_i1, position=Position.FIRST), - "i3", - ) - - by_i3 = Sort(events, SelectColumn(events, "i3")) - by_i3_then_i2 = Sort(by_i3, SelectColumn(events, "i2")) - by_i3_then_i2_then_i1 = Sort(by_i3_then_i2, SelectColumn(by_i3_then_i2, "i1")) - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_i3_then_i2_then_i1, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_i3_then_i2_then_i1, - name="i3", - ), - } - ), - ), - "i3", - ) - - assert apply_transforms(variable) == expected - - -def test_copes_with_interleaved_sorts_and_filters(): - events = SelectTable( - "events", - 
TableSchema(i1=Column(int), i2=Column(int), i3=Column(int)), - ) - by_i2 = Sort(events, SelectColumn(events, "i2")) - by_i2_filtered = Filter(by_i2, Value(True)) - by_i2_then_i1 = Sort(by_i2_filtered, SelectColumn(by_i2_filtered, "i1")) - variable = SelectColumn( - PickOneRowPerPatient(source=by_i2_then_i1, position=Position.FIRST), - "i3", - ) - - by_i3 = Sort(events, SelectColumn(events, "i3")) - by_i3_then_i2 = Sort(by_i3, SelectColumn(events, "i2")) - by_i3_then_i2_filtered = Filter(by_i3_then_i2, Value(True)) - by_i3_then_i2_then_i1 = Sort( - by_i3_then_i2_filtered, SelectColumn(by_i3_then_i2_filtered, "i1") - ) - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_i3_then_i2_then_i1, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_i3_then_i2_then_i1, - name="i3", - ), - } - ), - ), - "i3", - ) - - assert apply_transforms(variable) == expected - - -def test_doesnt_duplicate_existing_sorts(): - events = SelectTable( - "events", - TableSchema(i1=Column(int)), - ) - by_i1 = Sort(events, SelectColumn(events, "i1")) - variable = SelectColumn( - PickOneRowPerPatient(source=by_i1, position=Position.FIRST), - "i1", - ) - - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_i1, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_i1, - name="i1", - ), - } - ), - ), - "i1", - ) - - assert apply_transforms(variable) == expected - - -def test_adds_sorts_in_lexical_order_of_column_names(): - events = SelectTable( - "events", - TableSchema(i1=Column(int), iz=Column(int), ia=Column(int)), - ) - by_i1 = Sort(events, SelectColumn(events, "i1")) - first_initial = PickOneRowPerPatient(source=by_i1, position=Position.FIRST) - dataset = dataset_factory( - z=SelectColumn(first_initial, "iz"), - a=SelectColumn(first_initial, "ia"), - ) - - transformed = apply_transforms(dataset) - - by_iz = Sort(events, SelectColumn(events, "iz")) - by_iz_then_ia = Sort(by_iz, SelectColumn(events, "ia")) - 
by_iz_then_ia_then_i1 = Sort(by_iz_then_ia, SelectColumn(events, "i1")) - first_with_extra_sorts = PickOneRowPerPatientWithColumns( - by_iz_then_ia_then_i1, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_iz_then_ia_then_i1, - name="iz", - ), - SelectColumn( - source=by_iz_then_ia_then_i1, - name="ia", - ), - } - ), - ) - - expected = dict( - z=SelectColumn(first_with_extra_sorts, "iz"), - a=SelectColumn(first_with_extra_sorts, "ia"), - ) - - assert transformed.variables == expected - - -def test_maps_booleans_to_a_sortable_type(): - events = SelectTable( - "events", - TableSchema(i=Column(int), b=Column(bool)), - ) - by_i = Sort(events, SelectColumn(events, "i")) - variable = SelectColumn( - PickOneRowPerPatient(source=by_i, position=Position.FIRST), - "b", - ) - - b = SelectColumn(events, "b") - by_b = Sort(events, Function.CastToInt(b)) - by_b_then_i = Sort(by_b, SelectColumn(events, "i")) - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_b_then_i, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_b_then_i, - name="b", - ), - } - ), - ), - "b", - ) - - assert apply_transforms(variable) == expected - - -def test_sorts_by_derived_value_handled_correctly(): - events = SelectTable("events", TableSchema(i=Column(int))) - - by_negative_i = Sort(events, Function.Negate(SelectColumn(events, "i"))) - variable = SelectColumn(PickOneRowPerPatient(by_negative_i, Position.FIRST), "i") - - by_i = Sort(events, SelectColumn(events, "i")) - by_i_then_by_negative_i = Sort(by_i, Function.Negate(SelectColumn(events, "i"))) - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_i_then_by_negative_i, - Position.FIRST, - frozenset({SelectColumn(by_i_then_by_negative_i, "i")}), - ), - "i", - ) - - assert apply_transforms(variable) == expected - - -def test_identical_operations_are_not_transformed_differently(): - # Query model nodes are intended to be value objects: that is equality is determined 
- # by value, not identity and equal objects should be intersubstitutable. Approaches - # to query transformation which involve mutation can violate this principle and - # treat equal but non-identical nodes differently. This tests for a specific - # instance of this problem. - events = SelectTable( - "events", - TableSchema(i1=Column(int), i2=Column(int)), - ) - # Construct two equal but non-identical sort-and-picks - first_by_i1_v1 = PickOneRowPerPatient( - Sort(events, SelectColumn(events, "i1")), position=Position.FIRST - ) - first_by_i1_v2 = PickOneRowPerPatient( - Sort(events, SelectColumn(events, "i1")), position=Position.FIRST - ) - - # Select different columns from each one - dataset = dataset_factory( - i1=SelectColumn(first_by_i1_v1, "i1"), - i2=SelectColumn(first_by_i1_v2, "i2"), - ) - - # We expect i2 to be added at the bottom of the stack of sorts - by_i2_then_i1 = Sort( - source=Sort(source=events, sort_by=SelectColumn(source=events, name="i2")), - sort_by=SelectColumn(source=events, name="i1"), - ) - # We expect the selected columns to include both i1 and i2 - pick_with_columns = PickOneRowPerPatientWithColumns( - source=by_i2_then_i1, - position=Position.FIRST, - selected_columns=frozenset( - { - SelectColumn(by_i2_then_i1, "i1"), - SelectColumn(by_i2_then_i1, "i2"), - } - ), - ) - - expected = dict( - i1=SelectColumn(source=pick_with_columns, name="i1"), - i2=SelectColumn(source=pick_with_columns, name="i2"), - ) - - assert apply_transforms(dataset).variables == expected - - def test_substitute_parameters(): node = Function.Negate(Function.Add(Value(10), Parameter("i", int))) transformed = substitute_parameters(node, i=20) From e916917ca606fe1952e995cd5b6aa8a90d32d061 Mon Sep 17 00:00:00 2001 From: Dave Evans Date: Fri, 1 May 2026 16:07:42 +0100 Subject: [PATCH 10/10] fix: Remove "xfail" from test which now passes --- tests/integration/test_query_engines.py | 4 ---- 1 file changed, 4 deletions(-) diff --git 
a/tests/integration/test_query_engines.py b/tests/integration/test_query_engines.py index 3fde56554..45c79f03a 100644 --- a/tests/integration/test_query_engines.py +++ b/tests/integration/test_query_engines.py @@ -625,10 +625,6 @@ class events(EventFrame): ] -# The fix for this turns out to be not straightforward and it's sufficiently edge-case-y -# that it doesn't affect us in practice. So for now we keep the test in place but -# xfailed. -@pytest.mark.xfail def test_sort_edge_case(engine): # Regression test for a weird edge case in our sort transformation code identified, # as you'd expect, by Hypothesis. See: