Skip to content
34 changes: 31 additions & 3 deletions ehrql/query_engines/base_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,9 +940,16 @@ def get_table_sort_and_filter(self, node):
@get_table.register(PickOneRowPerPatientWithColumns)
def get_table_pick_one_row_per_patient(self, node):
selected_columns = [self.get_expr(c) for c in node.selected_columns]

sort_conditions = get_sort_conditions(node.source)
# Ensure a unique deterministic result in the case of any ties
# See: ehrql.query_model.transforms.apply_sort_rewrites()
tiebreakers = sorted(node.selected_columns, key=lambda c: c.name)
order_clauses = self.get_order_clauses(
get_sort_conditions(node.source), node.position
sort_conditions + tiebreakers, node.position
)
# Some tiebreakers may already be included in the sort conditions
order_clauses = remove_redundant_order_clauses(order_clauses)

query = self.get_select_query_for_node_domain(node.source)
query = query.add_columns(*selected_columns)
Expand Down Expand Up @@ -1278,12 +1285,33 @@ def get_table_and_filter_conditions(frame):

def get_sort_conditions(frame):
    """
    Given a sorted frame, return a list of Series which gives the sort order

    The highest priority sort condition comes first. A list (rather than a tuple) is
    returned so that callers can concatenate further conditions, such as
    tiebreakers, onto the result.
    """
    # Sort operations are given to us in order of application which is the reverse of
    # order of priority (i.e. the most recently applied sort gives us the primary sort
    # condition) so we reverse them here
    return [s.sort_by for s in reversed(get_sorts(frame))]


def remove_redundant_order_clauses(clauses):
    """
    Drop repeated clauses from an ORDER BY list, keeping first occurrences

    Given clauses equivalent to:

        ORDER BY a, b, a, c

    the second `a` cannot affect the resulting order, so the result is equivalent to:

        ORDER BY a, b, c
    """
    # SQLAlchemy overloads `==` on its elements, so they can't be compared directly.
    # There is a `compare()` method for structural equivalence, but keying on the
    # compiled string is simpler and sufficient for our purposes. Dicts preserve
    # insertion order and `setdefault` never replaces an existing entry, so the first
    # occurrence of each clause wins and the original ordering is retained.
    first_by_sql = {}
    for clause in clauses:
        first_by_sql.setdefault(str(clause), clause)
    return list(first_by_sql.values())


def get_cyclic_coalescence(columns):
Expand Down
19 changes: 14 additions & 5 deletions ehrql/query_engines/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,23 @@ def visit_Sort(self, node):
return source.sort(sort_column.sort_index())

def visit_PickOneRowPerPatient(self, node):
    """Evaluate `node.source` and pick the row at the index for `node.position`."""
    # The position-to-index mapping lives in `index_for_position` so that it is
    # shared with `visit_PickOneRowPerPatientWithColumns`
    ix = self.index_for_position(node.position)
    return self.visit(node.source).pick_at_index(ix)

def visit_PickOneRowPerPatientWithColumns(self, node):
    """Pick one row per patient, using the selected columns to break sort ties.

    The source is evaluated and a tiebreak-only sort is applied for each selected
    column (taken in column name order) before the row at the index for
    `node.position` is picked.
    """
    source = self.visit(node.source)
    # Ensure a unique deterministic result in the case of any ties
    # See: ehrql.query_model.transforms.apply_sort_rewrites()
    selected_columns = [c.name for c in node.selected_columns]
    for column in sorted(selected_columns):
        tiebreaker_sort = source[column].sort_index()
        # `tiebreak_only` keeps the existing order and only uses this column to
        # separate rows which the existing sorts leave tied
        source = source.sort(tiebreaker_sort, tiebreak_only=True)

    ix = self.index_for_position(node.position)
    return source.pick_at_index(ix)

def index_for_position(self, position):
    """Return the index passed to `pick_at_index` for the given `qm.Position`."""
    # FIRST selects from the front (0) and LAST from the back (-1); any other value
    # raises KeyError
    index_by_position = {
        qm.Position.FIRST: 0,
        qm.Position.LAST: -1,
    }
    return index_by_position[position]

def visit_Exists(self, node):
return self.visit(node.source).exists()
Expand Down
56 changes: 42 additions & 14 deletions ehrql/query_engines/in_memory_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,12 @@ def filter(self, predicate): # noqa A003
{name: col.filter(predicate) for name, col in self.name_to_col.items()}
)

def sort(self, sort_index, tiebreak_only=False):
    """Return a new EventTable with every column sorted by `sort_index`.

    `tiebreak_only` is forwarded unchanged to each column's `sort()`.
    """
    return EventTable(
        {
            name: col.sort(sort_index, tiebreak_only=tiebreak_only)
            for name, col in self.name_to_col.items()
        }
    )

def pick_at_index(self, ix):
Expand Down Expand Up @@ -356,9 +359,12 @@ def sort_index(self):
{p: rows.sort_index() for p, rows in self.patient_to_rows.items()}
)

def sort(self, sort_index, tiebreak_only=False):
    """Return a new EventColumn with each patient's rows sorted.

    Each patient's rows are sorted by that patient's entry in `sort_index`;
    `tiebreak_only` is forwarded unchanged to the rows' `sort()`.
    """
    return EventColumn(
        {
            p: rows.sort(sort_index[p], tiebreak_only=tiebreak_only)
            for p, rows in self.patient_to_rows.items()
        }
    )

def pick_at_index(self, ix):
Expand All @@ -382,6 +388,12 @@ class Rows(dict):
values belonging to a single patient in an EventColumn.
"""

# Keep track of any sorts which have been applied so we can tell whether a given
# row's position is due to its explicit sort order or is just arbitrary.
# This allows us to supply tiebreaker sorts _after_ other sorts have already been
# applied.
_sort_index = None

def __repr__(self):
return f"Rows({super().__repr__()})"

Expand Down Expand Up @@ -412,7 +424,9 @@ def filter(self, predicate): # noqa A003
if not isinstance(predicate, Rows):
# This branch is hit when an EventSeries is filtered by a literal boolean.
predicate = Rows({k: predicate for k in self})
return Rows({k: v for k, v in self.items() if predicate[k]})
rows = Rows({k: v for k, v in self.items() if predicate[k]})
rows._sort_index = self._sort_index
return rows

def sort_index(self):
"""Map each value to its ordinal position in set of unique values.
Expand All @@ -421,26 +435,40 @@ def sort_index(self):
resulting sort_index will overspecify the order and we lose the stability of the
sort operation.
"""

sorted_values = sorted(set(self.values()), key=nulls_first_order)
return Rows({k: sorted_values.index(v) for k, v in self.items()})

def sort(self, sort_index, tiebreak_only=False):
    """Sort rows by position in sort_index.

    If two values have the same position, their current sort index is used as a
    tiebreaker. This ensures that sorting is stable.

    If `tiebreak_only` is true the priorities are swapped: the current sort index
    remains the primary ordering and `sort_index` is only used to separate rows
    which the current sort index leaves tied.

    Returns a new Rows whose `_sort_index` records the index it was sorted by, so
    that later sorts can be combined with it.
    """
    if self._sort_index is not None:
        if not tiebreak_only:
            # The standard case: create a combined sort index where we sort
            # primarily on the supplied index and use the existing index to break
            # any ties
            combined_index = {
                k: (sort_index[k], self._sort_index[k]) for k in self.keys()
            }
        else:
            # The special tiebreak case: we sort primarily on the existing index to
            # retain the existing order and only use the new index to break any ties
            combined_index = {
                k: (self._sort_index[k], sort_index[k]) for k in self.keys()
            }
        # Collapse each distinct pair to its ordinal position. Building the lookup
        # once avoids a linear `list.index()` scan for every row.
        ranks = {
            pair: rank
            for rank, pair in enumerate(sorted(set(combined_index.values())))
        }
        sort_index = {k: ranks[pair] for k, pair in combined_index.items()}

    # Keys are unique, so ordering on (index, key) is total and the values
    # themselves never need to be compared
    ordered_items = sorted(
        self.items(), key=lambda item: (sort_index[item[0]], item[0])
    )
    rows = Rows(dict(ordered_items))
    rows._sort_index = sort_index
    return rows

def pick_at_index(self, ix):
"""Return element at given position."""
Expand Down
Loading
Loading