diff --git a/ehrql/query_engines/base_sql.py b/ehrql/query_engines/base_sql.py index d4552fa0c..bcf57af97 100644 --- a/ehrql/query_engines/base_sql.py +++ b/ehrql/query_engines/base_sql.py @@ -940,9 +940,16 @@ def get_table_sort_and_filter(self, node): @get_table.register(PickOneRowPerPatientWithColumns) def get_table_pick_one_row_per_patient(self, node): selected_columns = [self.get_expr(c) for c in node.selected_columns] + + sort_conditions = get_sort_conditions(node.source) + # Ensure a unique deterministic result in the case of any ties + # See: ehrql.query_model.transforms.apply_sort_rewrites() + tiebreakers = sorted(node.selected_columns, key=lambda c: c.name) order_clauses = self.get_order_clauses( - get_sort_conditions(node.source), node.position + sort_conditions + tiebreakers, node.position ) + # Some tiebreakers may already be included in the sort conditions + order_clauses = remove_redundant_order_clauses(order_clauses) query = self.get_select_query_for_node_domain(node.source) query = query.add_columns(*selected_columns) @@ -1278,12 +1285,33 @@ def get_table_and_filter_conditions(frame): def get_sort_conditions(frame): """ - Given a sorted frame, return a tuple of Series which gives the sort order + Given a sorted frame, return a list of Series which gives the sort order """ # Sort operations are given to us in order of application which is the reverse of # order of priority (i.e. the most recently applied sort gives us the primary sort # condition) so we reverse them here - return tuple(s.sort_by for s in reversed(get_sorts(frame))) + return [s.sort_by for s in reversed(get_sorts(frame))] + + +def remove_redundant_order_clauses(clauses): + """ + In a list of clauses like: + + ORDER BY a, b, a, c + + The second `a` is redundant and cannot affect the resulting order + """ + seen = set() + result = [] + for clause in clauses: + # We can't use equality to compare SQLAlchemy elements because it's overloaded. + # There is a `compare()` method we can use to determine structural equivalence, + # but compiling to a string is simpler and sufficient for our purposes. + clause_str = str(clause) + if clause_str not in seen: + result.append(clause) + seen.add(clause_str) + return result def get_cyclic_coalescence(columns): diff --git a/ehrql/query_engines/in_memory.py b/ehrql/query_engines/in_memory.py index e10b2eb77..19d7e28e1 100644 --- a/ehrql/query_engines/in_memory.py +++ b/ehrql/query_engines/in_memory.py @@ -191,14 +191,23 @@ def visit_Sort(self, node): return source.sort(sort_column.sort_index()) def visit_PickOneRowPerPatient(self, node): - ix = { - qm.Position.FIRST: 0, - qm.Position.LAST: -1, - }[node.position] + ix = self.index_for_position(node.position) return self.visit(node.source).pick_at_index(ix) def visit_PickOneRowPerPatientWithColumns(self, node): - return self.visit_PickOneRowPerPatient(node) + source = self.visit(node.source) + # Ensure a unique deterministic result in the case of any ties + # See: ehrql.query_model.transforms.apply_sort_rewrites() + selected_columns = [c.name for c in node.selected_columns] + for column in sorted(selected_columns): + tiebreaker_sort = source[column].sort_index() + source = source.sort(tiebreaker_sort, tiebreak_only=True) + + ix = self.index_for_position(node.position) + return source.pick_at_index(ix) + + def index_for_position(self, position): + return {qm.Position.FIRST: 0, qm.Position.LAST: -1}[position] def visit_Exists(self, node): return self.visit(node.source).exists() diff --git a/ehrql/query_engines/in_memory_database.py b/ehrql/query_engines/in_memory_database.py index 9efe0dabd..5d2f307c8 100644 --- a/ehrql/query_engines/in_memory_database.py +++ b/ehrql/query_engines/in_memory_database.py @@ -206,9 +206,12 @@ def filter(self, predicate): # noqa A003 {name: col.filter(predicate) for name, col in self.name_to_col.items()} ) - def sort(self, sort_index): + def sort(self, sort_index, tiebreak_only=False): return EventTable( - {name: col.sort(sort_index) for name, col in self.name_to_col.items()} + { + name: col.sort(sort_index, tiebreak_only=tiebreak_only) + for name, col in self.name_to_col.items() + } ) def pick_at_index(self, ix): @@ -356,9 +359,12 @@ def sort_index(self): {p: rows.sort_index() for p, rows in self.patient_to_rows.items()} ) - def sort(self, sort_index): + def sort(self, sort_index, tiebreak_only=False): return EventColumn( - {p: rows.sort(sort_index[p]) for p, rows in self.patient_to_rows.items()} + { + p: rows.sort(sort_index[p], tiebreak_only=tiebreak_only) + for p, rows in self.patient_to_rows.items() + } ) def pick_at_index(self, ix): @@ -382,6 +388,12 @@ class Rows(dict): values belonging to a single patient in an EventColumn. """ + # Keep track of any sorts which have been applied so we can identify where a given + # row's position is due to its explicit sort order or where it's just arbitrary. + # This allows us to supply tiebreaker sorts _after_ other sorts have already been + # applied. + _sort_index = None + def __repr__(self): return f"Rows({super().__repr__()})" @@ -412,7 +424,9 @@ def filter(self, predicate): # noqa A003 if not isinstance(predicate, Rows): # This branch is hit when an EventSeries is filtered by a literal boolean. predicate = Rows({k: predicate for k in self}) - return Rows({k: v for k, v in self.items() if predicate[k]}) + rows = Rows({k: v for k, v in self.items() if predicate[k]}) + rows._sort_index = self._sort_index + return rows def sort_index(self): """Map each value to its ordinal position in set of unique values. @@ -421,26 +435,40 @@ def sort_index(self): resulting sort_index will overspecify the order and we lose the stability of the sort operation. """ - sorted_values = sorted(set(self.values()), key=nulls_first_order) return Rows({k: sorted_values.index(v) for k, v in self.items()}) - def sort(self, sort_index): + def sort(self, sort_index, tiebreak_only=False): """Sort rows by position in sort_index. - If two values have the same position, their current position is used as a - tiebreaker. This ensures that sorting is stable. + If two values have the same position, their current sort index is used as a + tiebreaker. This ensures that sorting is stable. """ + if self._sort_index is not None: + if not tiebreak_only: + # The standard case: create a combined sort index where we sort + # primarily on the supplied index and use the existing index to break + # any ties + combined_index = { + k: (sort_index[k], self._sort_index[k]) for k in self.keys() + } + else: + # The special tiebreak case: we sort primarily on the existing index to + # retain the existing order and only use the new index to break any ties + combined_index = { + k: (self._sort_index[k], sort_index[k]) for k in self.keys() + } + sorted_values = sorted(set(combined_index.values())) + sort_index = {k: sorted_values.index(v) for k, v in combined_index.items()} - return Rows( + rows = Rows( { k: v - for (_, _, k, v) in sorted( - (sort_index[k], tiebreaker, k, v) - for tiebreaker, (k, v) in enumerate(self.items()) - ) + for (_, k, v) in sorted((sort_index[k], k, v) for k, v in self.items()) } ) + rows._sort_index = sort_index + return rows def pick_at_index(self, ix): """Return element at given position.""" diff --git a/ehrql/query_model/transforms.py b/ehrql/query_model/transforms.py index 72a33cf2a..a660fda63 100644 --- a/ehrql/query_model/transforms.py +++ b/ehrql/query_model/transforms.py @@ -26,12 +26,8 @@ PickOneRowPerPatient, SelectColumn, Series, - Sort, Value, get_input_nodes, - get_series_type, - get_sorts, - has_one_row_per_patient, ) from ehrql.query_model.query_graph_rewriter import QueryGraphRewriter @@ -72,37 +68,95 @@ class Coalesce(Series[T]): def apply_transforms(root_node, skip_optimizations=False): - # Note that we're currently sharing `rewriter`, `nodes` and `reverse_index` across - # transforms. While we only have one this is obviously fine! It _might_ be OK as we - # add more depending on whether they're commutative but we should be careful here - # and might decide we want to restructure things to keep the transforms independent. + root_node = apply_sort_rewrites(root_node) + if not skip_optimizations: + root_node = apply_optimizations(root_node) + return root_node + + +def apply_sort_rewrites(root_node): + """ + Sorting rows and then picking the first or last row for each patient is a common + operation in ehrQL but it's responsible for a slightly weird corner of ehrQL's + semantics. This is because it's easy to have "under-specified" results e.g. you sort + some events by date and pick the first but a patient has multiple events recorded on + that day. + + In that case, there's no one correct answer as to what row should be returned and + this creates two related problems: + + * Some databases (e.g. MSSQL) pick randomly (or effectively randomly) meaning you + can run the same query against the same data and get different results each time. + This is confusing and generally bad for research (and not in a theoretical + sense: we've seen this actually happen) so we want to avoid it. + + * Even those databases which return consistent results each time don't necessarily + return the same results as each other. This prevents our automated generative + testing, which relies on comparing results between databases, from working. + + What we want is to ensure that even in the case of under-specified sorts there is a + single correct result defined by ehrQL. But we want to do this without imposing a + significant performance cost on all our sort queries. + + We do this by defining "tiebreaker sorts" for selecting a winning row in the case of + multiple equal candidates. That is, if the rows are equally positioned when sorting + by all the things the user has specified then sort by these other conditions as + well. + + One simple solution to this would be to sort by every column in the table in lexical + order. That would guarantee a single stable result. Of course there might still be + completely duplicate rows, but in that case it doesn't matter which you pick because + the results are necessarily identical. + + The problem with this solution is that we can have very wide tables with many + columns and now every time we sort we need to specify all of these as tiebreaker + conditions. So it fails our "don't impose significant performance cost" condition. + + Another solution is to use just the columns we're actually going to select from the + results as the tiebreaker conditions. Of course, this doesn't guarantee uniqueness + of rows: if we select columns A and B from a table we might have multiple rows with + the same values for A and B but a different value for C. But it does guarantee + uniqueness of _results_: because we're not selecting column C it doesn't matter + what's in it. + + This does exactly what we need, however it introduces an oddity into ehrQL's + semantics which is that you can no longer evaluate it in a bottom-up fashion. The + value of a pick-first-row operation depends on what columns are _going to be_ + selected from that row. + + This means we need to do some pre-processing of the query graph and annotate each + such operation with the set of columns that are selected from it elsewhere in the + query. It would be nicer not to have to do this, but given the above constraints I + think it's the best practical solution. + """ nodes = all_unique_nodes(root_node) reverse_index = build_reverse_index(nodes) + rewriter = QueryGraphRewriter() + for node in nodes: + if isinstance(node, PickOneRowPerPatient): + rewrite_sorts(rewriter, node, reverse_index) + return rewriter.rewrite(root_node) - # This transform is required for ehrQL's sorting semantics to be respected + +def apply_optimizations(root_node): + # These transforms should not affect behaviour but are just performance + # improvements transforms = [ - (PickOneRowPerPatient, rewrite_sorts), + rewrite_case_to_fixed_value_map, + rewrite_case_to_coalesce, ] - # These transforms should not affect behaviour but are just performance - # improvements. For testing purposes we want to be able to disable them. - if not skip_optimizations: - transforms.extend( - [ - (Case, specialize_case_operations), - ] - ) rewriter = QueryGraphRewriter() - for type_, transform in transforms: - apply_transform(rewriter, type_, transform, nodes, reverse_index) - - return rewriter.rewrite(root_node) + for node in all_unique_nodes(root_node): + original = node + for transform in transforms: + if result := transform(node): + node = result + if node is not original: + rewriter.replace(original, node) -def apply_transform(rewriter, type_, transform, nodes, reverse_index): - for node in nodes: - if isinstance(node, type_): - transform(rewriter, node, reverse_index) + return rewriter.rewrite(root_node) def replace_nodes(root_node, replacements): @@ -118,61 +172,10 @@ def replace_nodes(root_node, replacements): def rewrite_sorts(rewriter, node, reverse_index): - """ - Frames are sorted in order to then pick the first or last row for a patient. Multiple sorts - may be applied to give the desired results. Once a single row has been picked, one or more - columns are then selected. - - This results in a subgraph of QM objects like this: - - SelectColumn(A) -+ - | - SelectColumn(B) -+-> PickOneRowPerPatient -> Sort(A) -> Sort(B) -> SelectTable - | - SelectColumn(C) -+ - - There are two transformations that we need to carry out on this stack. - - 1. We annotate PickOneRowPerPatient with the columns that are going to be selected from it, in - order to allow us to generate the appropriate query more easily. - 2. Add sorts so that we have one for each column that will be selected, in order to ensure that - the sort order (and hence the values of the selected columns) is deterministic. - - For the example above the resulting subgraph would be: - - SelectColumn(A) -+ - | - SelectColumn(B) -+-> PickOneRowPerPatientWithColumns -> Sort(A) -> Sort(B) -> Sort(C) -> SelectTable - | - SelectColumn(C) -+ - - Some notes on the additional sorts are in order. - - * A potential lack of determinism in sort order creeps in when a patient has multiple rows with - the same value of the column(s) being sorted on. In this case some databases may give different - orders on different runs of the same query against the same data. - * When this lack of determinism exists, and we select a column that has not been sorted on, the - value returned may change between runs. - * We add sorts only for columns that don't already have them. Duplicate sorts wouldn't cause a - problem, but are conceptually messy and might have a small performance impact. - * We add the sorts below the existing ones so that they have lower priority and are only used to - break any ties in the user-specified sorts. - * When considering the existing sorts we only attend to those that sort directly on selected - columns, not on expressions derived from a column. Such expressions may not be injective and so - may not be sufficient to fully determine the order. As above, duplicates are not a problem. - * We introduce an arbitrary order for the additional sorts (lexically by column name) to ensure - that their order itself is deterministic. - """ # What columns are select from this patient frame? selected_column_names = { c.name for c in reverse_index[node] if isinstance(c, SelectColumn) } - - add_columns_to_pick(rewriter, node, selected_column_names) - add_extra_sorts(rewriter, node, selected_column_names) - - -def add_columns_to_pick(rewriter, node, selected_column_names): selected_columns = frozenset( SelectColumn(node.source, c) for c in selected_column_names ) @@ -186,63 +189,14 @@ def add_columns_to_pick(rewriter, node, selected_column_names): ) -def add_extra_sorts(rewriter, node, selected_column_names): - all_sorts = get_sorts(node.source) - # Add at the bottom of the stack - lowest_sort = all_sorts[0] - - for column in calculate_sorts_to_add(all_sorts, selected_column_names): - new_sort = Sort( - source=lowest_sort.source, - sort_by=make_sortable(SelectColumn(lowest_sort.source, column)), - ) - rewriter.replace( - lowest_sort, - Sort( - source=new_sort, - sort_by=lowest_sort.sort_by, - ), - ) - lowest_sort = new_sort - - -def calculate_sorts_to_add(all_sorts, selected_column_names): - # Don't duplicate existing direct sorts - direct_sorts = [ - sort - for sort in all_sorts - if isinstance(sort.sort_by, SelectColumn) - # SelectColumn operations only count as direct sorts if they're selected from - # the frame we're sorting, not from some other patient frame - and not has_one_row_per_patient(sort.sort_by.source) - ] - existing_sorted_column_names = {sort.sort_by.name for sort in direct_sorts} - sorts_to_add = selected_column_names - existing_sorted_column_names - - # Arbitrary canonical ordering - return sorted(sorts_to_add) - - -def make_sortable(col): - if get_series_type(col) is bool: - # Some databases can't sort booleans (including SQL Server), so we cast them to - # integers - return Function.CastToInt(col) - return col - - -def specialize_case_operations(rewriter, node, reverse_index): - if replacement := rewrite_case_to_fixed_value_map(node): - rewriter.replace(node, replacement) - elif replacement := rewrite_case_to_coalesce(node): - rewriter.replace(node, replacement) - - def rewrite_case_to_fixed_value_map(node): """ If the supplied Case operation can be represented as a FixedValueMap then return that representation, otherwise return None """ + if not isinstance(node, Case): + return + source = MISSING = object() mapping = {} @@ -299,6 +253,9 @@ def rewrite_case_to_coalesce(node): If the supplied Case operation can be represented as a Coalesce then return that representation, otherwise return None """ + if not isinstance(node, Case): + return + sources = [] # We're looking for Case operations where every case is of the form: diff --git a/tests/integration/test_query_engines.py b/tests/integration/test_query_engines.py index 733e52f79..45c79f03a 100644 --- a/tests/integration/test_query_engines.py +++ b/tests/integration/test_query_engines.py @@ -547,10 +547,84 @@ def test_sql_logging(engine, caplog): assert counts[r] > 0, f"No logs matching {r!r}" -# The fix for this turns out to be not straightforward and it's sufficiently edge-case-y -# that it doesn't affect us in practice. So for now we keep the test in place but -# xfailed. -@pytest.mark.xfail +def test_sort_tiebreaker_semantics(engine): + @table + class events(EventFrame): + a = Series(int) + b = Series(int) + c = Series(int) + d = Series(int) + + engine.populate( + { + events: [ + # Check we get the first row ordered by `a` + {"patient_id": 1, "a": 1, "b": 0, "c": 3, "d": 4}, + {"patient_id": 1, "a": 0, "b": 0, "c": 5, "d": 6}, + # When multiple rows are tied for first place, check that we sort by `c` + # and then `d`, in that specific order + {"patient_id": 2, "a": 0, "b": 0, "c": 3, "d": 2}, + {"patient_id": 2, "a": 0, "b": 0, "c": 2, "d": 5}, + {"patient_id": 2, "a": 0, "b": 0, "c": 2, "d": 4}, + # Check that we don't sort by `b`: even though it's the lexically + # smallest column it doesn't appear in our query and so it shouldn't be + # used + {"patient_id": 3, "a": 0, "b": 0, "c": 2, "d": 2}, + {"patient_id": 3, "a": 0, "b": 1, "c": 1, "d": 1}, + ] + } + ) + dataset = create_dataset() + dataset.define_population(events.exists_for_patient()) + # Sort by `a` and then use columns `c` and `d` but ignore `b` + first_by_a = events.sort_by(events.a).first_for_patient() + dataset.c = first_by_a.c + dataset.d = first_by_a.d + + assert engine.extract(dataset) == [ + {"patient_id": 1, "c": 5, "d": 6}, + {"patient_id": 2, "c": 2, "d": 4}, + {"patient_id": 3, "c": 1, "d": 1}, + ] + + +def test_implicit_sort_on_boolean(engine): + # We don't allow explicit sorting on booleans in ehrQL, but sometimes our + # tiebreaking semantics requires this. We want to ensure all engines return the same + # ordering. + + @table + class events(EventFrame): + a = Series(int) + b = Series(bool) + + engine.populate( + { + events: [ + {"patient_id": 1, "a": 0, "b": True}, + {"patient_id": 1, "a": 0, "b": False}, + {"patient_id": 1, "a": 0, "b": None}, + {"patient_id": 2, "a": 0, "b": True}, + {"patient_id": 2, "a": 0, "b": False}, + {"patient_id": 2, "a": 1, "b": None}, + {"patient_id": 3, "a": 0, "b": True}, + {"patient_id": 3, "a": 1, "b": False}, + {"patient_id": 3, "a": 1, "b": None}, + ] + } + ) + dataset = create_dataset() + dataset.define_population(events.exists_for_patient()) + first_by_a = events.sort_by(events.a).first_for_patient() + dataset.b = first_by_a.b + + assert engine.extract(dataset) == [ + {"patient_id": 1, "b": None}, + {"patient_id": 2, "b": False}, + {"patient_id": 3, "b": True}, + ] + + def test_sort_edge_case(engine): # Regression test for a weird edge case in our sort transformation code identified, # as you'd expect, by Hypothesis. See: @@ -599,6 +673,36 @@ class events(EventFrame): ] +def test_remove_redundant_order_clauses(engine): + if engine.name == "in_memory": + pytest.skip("test does not apply to in-memory engine") + + @table + class events(EventFrame): + col_a = Series(int) + col_b = Series(int) + + dataset = create_dataset() + dataset.define_population(events.exists_for_patient()) + # We sort by columns A and B and also use them in our results. This is the situation + # which can lead to redundant order clauses. + first_row = events.sort_by(events.col_a, events.col_b).first_for_patient() + dataset.col_a = first_row.col_a + dataset.col_b = first_row.col_b + + queries = engine.dump_dataset_sql(dataset) + + partition_clauses = [ + match[0] for q in queries if (match := re.search(r"\(PARTITION BY .+\)", q)) + ] + assert len(partition_clauses) == 1 + partition_clause = partition_clauses[0] + + # Check that we only reference each column once + assert partition_clause.count("col_a") == 1 + assert partition_clause.count("col_b") == 1 + + def build_dataset(*, population, variables=None, events=None): return Dataset( population=population, diff --git a/tests/unit/query_model/test_transforms.py b/tests/unit/query_model/test_transforms.py index dff3900f5..77ae6ba78 100644 --- a/tests/unit/query_model/test_transforms.py +++ b/tests/unit/query_model/test_transforms.py @@ -6,7 +6,6 @@ Case, Column, Dataset, - Filter, Function, Parameter, PickOneRowPerPatient, @@ -81,284 +80,6 @@ def test_pick_one_row_per_patient_transform(): assert transformed.variables == expected -def test_adds_one_selected_column_to_sorts(): - events = SelectTable( - "events", - TableSchema(i1=Column(int), i2=Column(int)), - ) - by_i1 = Sort(events, SelectColumn(events, "i1")) - variable = SelectColumn( - PickOneRowPerPatient(source=by_i1, position=Position.FIRST), - "i2", - ) - - by_i2 = Sort(events, SelectColumn(events, "i2")) - by_i2_then_i1 = Sort(by_i2, SelectColumn(events, "i1")) - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_i2_then_i1, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_i2_then_i1, - name="i2", - ), - } - ), - ), - "i2", - ) - - assert apply_transforms(variable) == expected - - -def test_adds_sorts_at_lowest_priority(): - events = SelectTable( - "events", - TableSchema(i1=Column(int), i2=Column(int), i3=Column(int)), - ) - by_i2 = Sort(events, SelectColumn(events, "i2")) - by_i2_then_i1 = Sort(by_i2, SelectColumn(by_i2, "i1")) - variable = SelectColumn( - PickOneRowPerPatient(source=by_i2_then_i1, position=Position.FIRST), - "i3", - ) - - by_i3 = Sort(events, SelectColumn(events, "i3")) - by_i3_then_i2 = Sort(by_i3, SelectColumn(events, "i2")) - by_i3_then_i2_then_i1 = Sort(by_i3_then_i2, SelectColumn(by_i3_then_i2, "i1")) - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_i3_then_i2_then_i1, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_i3_then_i2_then_i1, - name="i3", - ), - } - ), - ), - "i3", - ) - - assert apply_transforms(variable) == expected - - -def test_copes_with_interleaved_sorts_and_filters(): - events = SelectTable( - "events", - TableSchema(i1=Column(int), i2=Column(int), i3=Column(int)), - ) - by_i2 = Sort(events, SelectColumn(events, "i2")) - by_i2_filtered = Filter(by_i2, Value(True)) - by_i2_then_i1 = Sort(by_i2_filtered, SelectColumn(by_i2_filtered, "i1")) - variable = SelectColumn( - PickOneRowPerPatient(source=by_i2_then_i1, position=Position.FIRST), - "i3", - ) - - by_i3 = Sort(events, SelectColumn(events, "i3")) - by_i3_then_i2 = Sort(by_i3, SelectColumn(events, "i2")) - by_i3_then_i2_filtered = Filter(by_i3_then_i2, Value(True)) - by_i3_then_i2_then_i1 = Sort( - by_i3_then_i2_filtered, SelectColumn(by_i3_then_i2_filtered, "i1") - ) - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_i3_then_i2_then_i1, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_i3_then_i2_then_i1, - name="i3", - ), - } - ), - ), - "i3", - ) - - assert apply_transforms(variable) == expected - - -def test_doesnt_duplicate_existing_sorts(): - events = SelectTable( - "events", - TableSchema(i1=Column(int)), - ) - by_i1 = Sort(events, SelectColumn(events, "i1")) - variable = SelectColumn( - PickOneRowPerPatient(source=by_i1, position=Position.FIRST), - "i1", - ) - - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_i1, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_i1, - name="i1", - ), - } - ), - ), - "i1", - ) - - assert apply_transforms(variable) == expected - - -def test_adds_sorts_in_lexical_order_of_column_names(): - events = SelectTable( - "events", - TableSchema(i1=Column(int), iz=Column(int), ia=Column(int)), - ) - by_i1 = Sort(events, SelectColumn(events, "i1")) - first_initial = PickOneRowPerPatient(source=by_i1, position=Position.FIRST) - dataset = dataset_factory( - z=SelectColumn(first_initial, "iz"), - a=SelectColumn(first_initial, "ia"), - ) - - transformed = apply_transforms(dataset) - - by_iz = Sort(events, SelectColumn(events, "iz")) - by_iz_then_ia = Sort(by_iz, SelectColumn(events, "ia")) - by_iz_then_ia_then_i1 = Sort(by_iz_then_ia, SelectColumn(events, "i1")) - first_with_extra_sorts = PickOneRowPerPatientWithColumns( - by_iz_then_ia_then_i1, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_iz_then_ia_then_i1, - name="iz", - ), - SelectColumn( - source=by_iz_then_ia_then_i1, - name="ia", - ), - } - ), - ) - - expected = dict( - z=SelectColumn(first_with_extra_sorts, "iz"), - a=SelectColumn(first_with_extra_sorts, "ia"), - ) - - assert transformed.variables == expected - - -def test_maps_booleans_to_a_sortable_type(): - events = SelectTable( - "events", - TableSchema(i=Column(int), b=Column(bool)), - ) - by_i = Sort(events, SelectColumn(events, "i")) - variable = SelectColumn( - PickOneRowPerPatient(source=by_i, position=Position.FIRST), - "b", - ) - - b = SelectColumn(events, "b") - by_b = Sort(events, Function.CastToInt(b)) - by_b_then_i = Sort(by_b, SelectColumn(events, "i")) - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_b_then_i, - Position.FIRST, - selected_columns=frozenset( - { - SelectColumn( - source=by_b_then_i, - name="b", - ), - } - ), - ), - "b", - ) - - assert apply_transforms(variable) == expected - - -def test_sorts_by_derived_value_handled_correctly(): - events = SelectTable("events", TableSchema(i=Column(int))) - - by_negative_i = Sort(events, Function.Negate(SelectColumn(events, "i"))) - variable = SelectColumn(PickOneRowPerPatient(by_negative_i, Position.FIRST), "i") - - by_i = Sort(events, SelectColumn(events, "i")) - by_i_then_by_negative_i = Sort(by_i, Function.Negate(SelectColumn(events, "i"))) - expected = SelectColumn( - PickOneRowPerPatientWithColumns( - by_i_then_by_negative_i, - Position.FIRST, - frozenset({SelectColumn(by_i_then_by_negative_i, "i")}), - ), - "i", - ) - - assert apply_transforms(variable) == expected - - -def test_identical_operations_are_not_transformed_differently(): - # Query model nodes are intended to be value objects: that is equality is determined - # by value, not identity and equal objects should be intersubstitutable. Approaches - # to query transformation which involve mutation can violate this principle and - # treat equal but non-identical nodes differently. This tests for a specific - # instance of this problem. - events = SelectTable( - "events", - TableSchema(i1=Column(int), i2=Column(int)), - ) - # Construct two equal but non-identical sort-and-picks - first_by_i1_v1 = PickOneRowPerPatient( - Sort(events, SelectColumn(events, "i1")), position=Position.FIRST - ) - first_by_i1_v2 = PickOneRowPerPatient( - Sort(events, SelectColumn(events, "i1")), position=Position.FIRST - ) - - # Select different columns from each one - dataset = dataset_factory( - i1=SelectColumn(first_by_i1_v1, "i1"), - i2=SelectColumn(first_by_i1_v2, "i2"), - ) - - # We expect i2 to be added at the bottom of the stack of sorts - by_i2_then_i1 = Sort( - source=Sort(source=events, sort_by=SelectColumn(source=events, name="i2")), - sort_by=SelectColumn(source=events, name="i1"), - ) - # We expect the selected columns to include both i1 and i2 - pick_with_columns = PickOneRowPerPatientWithColumns( - source=by_i2_then_i1, - position=Position.FIRST, - selected_columns=frozenset( - { - SelectColumn(by_i2_then_i1, "i1"), - SelectColumn(by_i2_then_i1, "i2"), - } - ), - ) - - expected = dict( - i1=SelectColumn(source=pick_with_columns, name="i1"), - i2=SelectColumn(source=pick_with_columns, name="i2"), - ) - - assert apply_transforms(dataset).variables == expected - - def test_substitute_parameters(): node = Function.Negate(Function.Add(Value(10), Parameter("i", int))) transformed = substitute_parameters(node, i=20)