@@ -193,15 +193,44 @@ def test_inferred(self):
193193 sequence_length = 1e6 ,
194194 random_seed = 1 ,
195195 )
196+ # use a high mutation rate so as to get >1 mutation per site
196197 ts = msprime .sim_mutations (ts , rate = 1e-8 , random_seed = 1 )
197198 sample_data = tsinfer .SampleData .from_tree_sequence (ts )
198- inferred_ts = tsinfer .infer (sample_data ). simplify ()
199+ inferred_ts = tsinfer .infer (sample_data )
199200 split_ts = tsdate .util .split_disjoint_nodes (inferred_ts )
200201 assert self .has_disjoint_nodes (inferred_ts )
201202 assert not self .has_disjoint_nodes (split_ts )
202203 assert split_ts .num_edges == inferred_ts .num_edges
203204 assert split_ts .num_nodes > inferred_ts .num_nodes
204205
206+ def test_worked_example (self ):
207+ """
208+ The mutation table is reordered such that the mutation above 4 (7 after
209+ splitting) is placed after the mutation above 5.
210+ """
211+ tables = tskit .TableCollection ()
212+ tables .nodes .set_columns (
213+ time = [0 , 0 , 0 , 0 , 1 , 1 , 2 ],
214+ flags = [tskit .NODE_IS_SAMPLE ] * 4 + [0 ] * 3 ,
215+ )
216+ tables .edges .set_columns (
217+ left = [0 , 0 , 0 , 0 , 0 , 0 , 1 , 1 , 2 , 2 , 2 ],
218+ right = [1 , 1 , 3 , 3 , 1 , 3 , 2 , 2 , 3 , 3 , 3 ],
219+ child = [0 , 1 , 2 , 3 , 4 , 5 , 0 , 1 , 0 , 1 , 4 ],
220+ parent = [4 , 4 , 5 , 5 , 6 , 6 , 6 , 6 , 4 , 4 , 6 ],
221+ )
222+ site_id = tables .sites .add_row (position = 2.5 , ancestral_state = "A" )
223+ tables .mutations .add_row (site = site_id , node = 4 , time = 1.5 , derived_state = "G" )
224+ tables .mutations .add_row (site = site_id , node = 5 , time = 1.5 , derived_state = "G" )
225+ tables .sequence_length = 3.0
226+ tables .sort ()
227+ tables .build_index ()
228+ tables .compute_mutation_parents ()
229+ ts = tables .tree_sequence ()
230+ np .testing .assert_equal (ts .mutations_node , [4 , 5 ])
231+ ts2 = tsdate .util .split_disjoint_nodes (ts )
232+ np .testing .assert_equal (ts2 .mutations_node , [5 , 7 ])
233+
205234
206235class TestPreprocessTs :
207236 def verify (self , ts , caplog , minimum_gap = None , erase_flanks = None , ** kwargs ):
@@ -396,6 +425,43 @@ def test_sim_example(self):
396425 # Next assumes no breakpoints before first site or after last
397426 assert ts .num_trees == num_trees + first_empty + last_empty
398427
428+ @pytest .mark .parametrize ("split_disjoint" , [True , False ])
429+ @pytest .mark .parametrize ("keep_unary" , [True , False ])
430+ def test_mutation_order (self , split_disjoint , keep_unary ):
431+ """
432+ Check that mutations are in sorted order after preprocessing
433+ """
434+ tables = tskit .TableCollection ()
435+ tables .nodes .set_columns (
436+ time = [0 , 0 , 0 , 0 , 1 , 1.5 , 1.25 , 2 ],
437+ flags = [tskit .NODE_IS_SAMPLE ] * 4 + [0 ] * 4 ,
438+ )
439+ tables .edges .set_columns (
440+ left = [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
441+ right = [3 , 3 , 3 , 3 , 3 , 3 , 3 ],
442+ child = [0 , 1 , 2 , 3 , 4 , 5 , 6 ],
443+ parent = [4 , 4 , 6 , 6 , 5 , 7 , 7 ],
444+ )
445+ site_id = tables .sites .add_row (position = 2.5 , ancestral_state = "A" )
446+ tables .mutations .add_row (
447+ site = site_id , node = 5 , time = tskit .UNKNOWN_TIME , derived_state = "G"
448+ )
449+ tables .mutations .add_row (
450+ site = site_id , node = 6 , time = tskit .UNKNOWN_TIME , derived_state = "G"
451+ )
452+ tables .sequence_length = 3.0
453+ tables .sort ()
454+ tables .build_index ()
455+ tables .compute_mutation_parents ()
456+ ts0 = tables .tree_sequence ()
457+ tables .simplify (keep_unary = keep_unary )
458+ tables .sort ()
459+ ts1 = tables .tree_sequence ()
460+ ts2 = tsdate .preprocess_ts (
461+ ts0 , split_disjoint = split_disjoint , keep_unary = keep_unary
462+ )
463+ assert np .array_equal (ts1 .mutations_node , ts2 .mutations_node )
464+
399465
400466class TestUnaryNodeCheck :
401467 def test_inferred (self ):
@@ -430,3 +496,25 @@ def test_simulated(self):
430496 assert not tsdate .util .contains_unary_nodes (simplified_ts )
431497 with pytest .raises (ValueError , match = "contains unary nodes" ):
432498 tsdate .date (ts , mutation_rate = 1e-8 , method = "variational_gamma" )
499+
500+
501+ class TestInferencePipeline :
502+ """
503+ Test that tsinfer->preprocess_ts->tsdate runs through
504+ """
505+
506+ def test_inference_pipeline (self ):
507+ ts = msprime .sim_ancestry (
508+ 10 ,
509+ population_size = 1e4 ,
510+ recombination_rate = 1e-8 ,
511+ sequence_length = 1e6 ,
512+ random_seed = 1 ,
513+ )
514+ print (tskit .__version__ , tsinfer .__version__ )
515+ ts = msprime .sim_mutations (ts , rate = 1e-8 , random_seed = 1 )
516+ sample_data = tsinfer .SampleData .from_tree_sequence (ts )
517+ tsdate .date (
518+ tsdate .preprocess_ts (tsinfer .infer (sample_data )),
519+ mutation_rate = 1e-8 ,
520+ )
0 commit comments