Skip to content

Commit 734a2d5

Browse files
committed
Fix issue #497 by removing assertion from split_disjoint_nodes; and sorting after simplify in preprocess_ts
1 parent f412ec9 commit 734a2d5

File tree

3 files changed

+103
-7
lines changed

3 files changed

+103
-7
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,14 @@
44

55
In development
66

7+
- `preprocess_ts` now always sorts the tree sequence, which may change the order
8+
of mutations
9+
10+
**Bugfixes**
11+
12+
- Removed an assertion in `split_disjoint_nodes` that is not guarenteed
13+
to hold since table sorting changes the order of mutations in tskit 1.0.0 onwards
14+
715
## [0.2.6] - 2026-03-06
816

917
Maintenance release.

tests/test_util.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,15 +193,44 @@ def test_inferred(self):
193193
sequence_length=1e6,
194194
random_seed=1,
195195
)
196+
# use a high mutation rate so as to get >1 mutation per site
196197
ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1)
197198
sample_data = tsinfer.SampleData.from_tree_sequence(ts)
198-
inferred_ts = tsinfer.infer(sample_data).simplify()
199+
inferred_ts = tsinfer.infer(sample_data)
199200
split_ts = tsdate.util.split_disjoint_nodes(inferred_ts)
200201
assert self.has_disjoint_nodes(inferred_ts)
201202
assert not self.has_disjoint_nodes(split_ts)
202203
assert split_ts.num_edges == inferred_ts.num_edges
203204
assert split_ts.num_nodes > inferred_ts.num_nodes
204205

206+
def test_worked_example(self):
207+
"""
208+
The mutation table is reordered such that the mutation above 4 (7 after
209+
splitting) is placed after the mutation above 5.
210+
"""
211+
tables = tskit.TableCollection()
212+
tables.nodes.set_columns(
213+
time=[0, 0, 0, 0, 1, 1, 2],
214+
flags=[tskit.NODE_IS_SAMPLE] * 4 + [0] * 3,
215+
)
216+
tables.edges.set_columns(
217+
left=[0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2],
218+
right=[1, 1, 3, 3, 1, 3, 2, 2, 3, 3, 3],
219+
child=[0, 1, 2, 3, 4, 5, 0, 1, 0, 1, 4],
220+
parent=[4, 4, 5, 5, 6, 6, 6, 6, 4, 4, 6],
221+
)
222+
site_id = tables.sites.add_row(position=2.5, ancestral_state="A")
223+
tables.mutations.add_row(site=site_id, node=4, time=1.5, derived_state="G")
224+
tables.mutations.add_row(site=site_id, node=5, time=1.5, derived_state="G")
225+
tables.sequence_length = 3.0
226+
tables.sort()
227+
tables.build_index()
228+
tables.compute_mutation_parents()
229+
ts = tables.tree_sequence()
230+
np.testing.assert_equal(ts.mutations_node, [4, 5])
231+
ts2 = tsdate.util.split_disjoint_nodes(ts)
232+
np.testing.assert_equal(ts2.mutations_node, [5, 7])
233+
205234

206235
class TestPreprocessTs:
207236
def verify(self, ts, caplog, minimum_gap=None, erase_flanks=None, **kwargs):
@@ -396,6 +425,43 @@ def test_sim_example(self):
396425
# Next assumes no breakpoints before first site or after last
397426
assert ts.num_trees == num_trees + first_empty + last_empty
398427

428+
@pytest.mark.parametrize("split_disjoint", [True, False])
429+
@pytest.mark.parametrize("keep_unary", [True, False])
430+
def test_mutation_order(self, split_disjoint, keep_unary):
431+
"""
432+
Check that mutations are in sorted order after preprocessing
433+
"""
434+
tables = tskit.TableCollection()
435+
tables.nodes.set_columns(
436+
time=[0, 0, 0, 0, 1, 1.5, 1.25, 2],
437+
flags=[tskit.NODE_IS_SAMPLE] * 4 + [0] * 4,
438+
)
439+
tables.edges.set_columns(
440+
left=[0, 0, 0, 0, 0, 0, 0],
441+
right=[3, 3, 3, 3, 3, 3, 3],
442+
child=[0, 1, 2, 3, 4, 5, 6],
443+
parent=[4, 4, 6, 6, 5, 7, 7],
444+
)
445+
site_id = tables.sites.add_row(position=2.5, ancestral_state="A")
446+
tables.mutations.add_row(
447+
site=site_id, node=5, time=tskit.UNKNOWN_TIME, derived_state="G"
448+
)
449+
tables.mutations.add_row(
450+
site=site_id, node=6, time=tskit.UNKNOWN_TIME, derived_state="G"
451+
)
452+
tables.sequence_length = 3.0
453+
tables.sort()
454+
tables.build_index()
455+
tables.compute_mutation_parents()
456+
ts0 = tables.tree_sequence()
457+
tables.simplify(keep_unary=keep_unary)
458+
tables.sort()
459+
ts1 = tables.tree_sequence()
460+
ts2 = tsdate.preprocess_ts(
461+
ts0, split_disjoint=split_disjoint, keep_unary=keep_unary
462+
)
463+
assert np.array_equal(ts1.mutations_node, ts2.mutations_node)
464+
399465

400466
class TestUnaryNodeCheck:
401467
def test_inferred(self):
@@ -430,3 +496,25 @@ def test_simulated(self):
430496
assert not tsdate.util.contains_unary_nodes(simplified_ts)
431497
with pytest.raises(ValueError, match="contains unary nodes"):
432498
tsdate.date(ts, mutation_rate=1e-8, method="variational_gamma")
499+
500+
501+
class TestInferencePipeline:
502+
"""
503+
Test that tsinfer->preprocess_ts->tsdate runs through
504+
"""
505+
506+
def test_inference_pipeline(self):
507+
ts = msprime.sim_ancestry(
508+
10,
509+
population_size=1e4,
510+
recombination_rate=1e-8,
511+
sequence_length=1e6,
512+
random_seed=1,
513+
)
514+
print(tskit.__version__, tsinfer.__version__)
515+
ts = msprime.sim_mutations(ts, rate=1e-8, random_seed=1)
516+
sample_data = tsinfer.SampleData.from_tree_sequence(ts)
517+
tsdate.date(
518+
tsdate.preprocess_ts(tsinfer.infer(sample_data)),
519+
mutation_rate=1e-8,
520+
)

tsdate/util.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def preprocess_ts(
114114
:param \\**kwargs: All further keyword arguments are passed to the
115115
{meth}`tskit.TreeSequence.simplify` command.
116116
117-
:return: A tree sequence with gaps removed.
117+
:return: A tree sequence with gaps removed and disjoint node segments split.
118118
:rtype: tskit.TreeSequence
119119
"""
120120

@@ -194,8 +194,13 @@ def preprocess_ts(
194194
record_provenance=False,
195195
**kwargs,
196196
)
197+
tables.sort()
197198
if split_disjoint:
198199
ts = split_disjoint_nodes(tables.tree_sequence(), record_provenance=False)
200+
logger.info(
201+
f"Split disjoint node segments from {tables.nodes.num_rows} "
202+
f"nodes into {ts.num_nodes} nodes"
203+
)
199204
tables = ts.dump_tables()
200205
if record_provenance:
201206
provenance.record_provenance(
@@ -582,11 +587,6 @@ def split_disjoint_nodes(ts, *, record_provenance=None):
582587
tables.sort()
583588
tables.build_index()
584589
tables.compute_mutation_parents()
585-
586-
assert np.array_equal(
587-
tables.nodes.time[tables.mutations.node], ts.nodes_time[ts.mutations_node]
588-
)
589-
590590
if record_provenance:
591591
provenance.record_provenance(
592592
tables,

0 commit comments

Comments
 (0)