Skip to content

Commit 76289b0

Browse files
Move add_vestigial_root to new arg_ops module
Create tsinfer/arg_ops.py for ARG topology operations and move add_vestigial_root from matching.py. Create tests/test_arg_ops.py with the corresponding tests. Update all call sites in matching.py, test_python_c.py, and test_lshmm.py.
1 parent ef579ec commit 76289b0

File tree

6 files changed

+90
-67
lines changed

6 files changed

+90
-67
lines changed

tests/test_arg_ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
import tskit
3+
4+
from tsinfer import arg_ops
5+
6+
7+
class TestAddVestigialRoot:
    """Checks that ``arg_ops.add_vestigial_root`` rejects unsupported input."""

    def test_non_discrete_genome(self):
        """A tree sequence without a discrete genome raises ``ValueError``."""
        # A sequence length of 1.5 is used so that the discrete-genome check
        # is expected to fail; one sample node keeps the node table non-empty.
        collection = tskit.TableCollection(sequence_length=1.5)
        collection.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
        non_discrete_ts = collection.tree_sequence()
        with pytest.raises(ValueError, match="discrete genome"):
            arg_ops.add_vestigial_root(non_discrete_ts)

    def test_empty_tree_sequence(self):
        """A tree sequence with no nodes raises ``ValueError``."""
        # NOTE: the pattern intentionally mirrors the spelling of the message
        # raised by add_vestigial_root.
        collection = tskit.TableCollection(sequence_length=1)
        empty_ts = collection.tree_sequence()
        with pytest.raises(ValueError, match="Emtpy trees"):
            arg_ops.add_vestigial_root(empty_ts)

tests/test_lshmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import tskit
1515

1616
import _tsinfer
17-
from tsinfer import matching
17+
from tsinfer import arg_ops, matching
1818

1919

2020
@dataclasses.dataclass
@@ -78,7 +78,7 @@ class MatcherIndexes:
7878
def __init__(self, in_tables, *, vestigial_root=True, num_alleles=None):
7979
ts = in_tables.tree_sequence()
8080
if vestigial_root:
81-
ts = matching.add_vestigial_root(ts)
81+
ts = arg_ops.add_vestigial_root(ts)
8282
tables = ts.dump_tables()
8383

8484
self.sequence_length = tables.sequence_length

tests/test_matching.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from __future__ import annotations
2424

2525
import numpy as np
26-
import pytest
2726
import tskit
2827

2928
from tsinfer import grouping, matching, vcz
@@ -761,26 +760,6 @@ def test_metadata_survives_multiple_cycles(self):
761760
assert ts2.metadata["sequence_intervals"] == [[10, 51]]
762761

763762

764-
# ---------------------------------------------------------------------------
765-
# TestAddVestigialRoot
766-
# ---------------------------------------------------------------------------
767-
768-
769-
class TestAddVestigialRoot:
770-
def test_non_discrete_genome(self):
771-
tables = tskit.TableCollection(sequence_length=1.5)
772-
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0)
773-
ts = tables.tree_sequence()
774-
with pytest.raises(ValueError, match="discrete genome"):
775-
matching.add_vestigial_root(ts)
776-
777-
def test_empty_tree_sequence(self):
778-
tables = tskit.TableCollection(sequence_length=1)
779-
ts = tables.tree_sequence()
780-
with pytest.raises(ValueError, match="Emtpy trees"):
781-
matching.add_vestigial_root(ts)
782-
783-
784763
# ---------------------------------------------------------------------------
785764
# TestAncestorMatcherWrapper
786765
# ---------------------------------------------------------------------------

tests/test_python_c.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import tskit
2929

3030
import _tsinfer
31-
from tsinfer import matching
31+
from tsinfer import arg_ops
3232

3333
IS_WINDOWS = sys.platform == "win32"
3434

@@ -155,7 +155,7 @@ def make_matcher_indexes_and_matcher(num_samples=4):
155155
tables.sites.add_row(position=1, ancestral_state="A")
156156
tables.mutations.add_row(site=0, node=1, derived_state="T")
157157
ts = tables.tree_sequence()
158-
ts = matching.add_vestigial_root(ts)
158+
ts = arg_ops.add_vestigial_root(ts)
159159
ll_tables = _tsinfer.LightweightTableCollection(ts.sequence_length)
160160
ll_tables.fromdict(ts.dump_tables().asdict())
161161
mi = _tsinfer.MatcherIndexes(ll_tables)
@@ -267,7 +267,7 @@ def test_find_path_match_impossible(self):
267267
tables.sites.add_row(position=1, ancestral_state="A")
268268
# No mutations: all nodes carry allele 0
269269
ts = tables.tree_sequence()
270-
ts = matching.add_vestigial_root(ts)
270+
ts = arg_ops.add_vestigial_root(ts)
271271
ll_tables = _tsinfer.LightweightTableCollection(ts.sequence_length)
272272
ll_tables.fromdict(ts.dump_tables().asdict())
273273
mi = _tsinfer.MatcherIndexes(ll_tables)
@@ -299,7 +299,7 @@ def test_get_traceback_bad_site(self):
299299
class TestMatcherIndexes:
300300
def test_single_tree(self):
301301
ts = tskit.Tree.generate_balanced(4).tree_sequence
302-
ts = matching.add_vestigial_root(ts)
302+
ts = arg_ops.add_vestigial_root(ts)
303303
tables = ts.dump_tables()
304304
ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length)
305305
ll_tables.fromdict(tables.asdict())
@@ -323,7 +323,7 @@ def test_num_alleles(self):
323323

324324
def test_print_state(self, tmpdir):
325325
ts = tskit.Tree.generate_balanced(4).tree_sequence
326-
ts = matching.add_vestigial_root(ts)
326+
ts = arg_ops.add_vestigial_root(ts)
327327
tables = ts.dump_tables()
328328
ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length)
329329
ll_tables.fromdict(tables.asdict())
@@ -341,7 +341,7 @@ def test_print_state(self, tmpdir):
341341

342342
def test_print_state_bad_file(self):
343343
ts = tskit.Tree.generate_balanced(4).tree_sequence
344-
ts = matching.add_vestigial_root(ts)
344+
ts = arg_ops.add_vestigial_root(ts)
345345
tables = ts.dump_tables()
346346
ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length)
347347
ll_tables.fromdict(tables.asdict())

tsinfer/arg_ops.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#
2+
# Copyright (C) 2018-2026 University of Oxford
3+
#
4+
# This file is part of tsinfer.
5+
#
6+
# tsinfer is free software: you can redistribute it and/or modify
7+
# it under the terms of the GNU General Public License as published by
8+
# the Free Software Foundation, either version 3 of the License, or
9+
# (at your option) any later version.
10+
#
11+
# tsinfer is distributed in the hope that it will be useful,
12+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
13+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14+
# GNU General Public License for more details.
15+
#
16+
# You should have received a copy of the GNU General Public License
17+
# along with tsinfer. If not, see <http://www.gnu.org/licenses/>.
18+
#
19+
"""
20+
Operations on ARG (Ancestral Recombination Graph) topology.
21+
"""
22+
23+
import logging
24+
25+
logger = logging.getLogger(__name__)
26+
27+
28+
def add_vestigial_root(ts):
    """
    Return a new tree sequence equal to ``ts`` plus one extra "vestigial"
    root node placed above the root of every tree.

    The new node is given ID 0 and a time one unit greater than the largest
    node time in ``ts``. All existing node IDs are therefore shifted up by
    one, and the node references in the edge and mutation tables are shifted
    to match. If ``ts`` has at least one edge, an edge from node 0 to the
    (shifted) root of each tree is added over that tree's interval; the edge
    table is then squashed and the tables are sorted.

    :param ts: The tree sequence to extend.
    :return: A new tree sequence built from the modified tables.
    :raises ValueError: If ``ts`` does not have a discrete genome, or if it
        contains no nodes.
    """
    if not ts.discrete_genome:
        raise ValueError("Only discrete genome coords supported")
    if ts.num_nodes == 0:
        raise ValueError("Emtpy trees not supported")

    tables = ts.dump_tables()
    # Only the node table has to be rebuilt, so keep a copy of just that
    # table instead of duplicating the whole table collection.
    original_nodes = tables.nodes.copy()

    # The vestigial root has to be node 0: empty the node table, add the new
    # root first, then re-append the original rows after it.
    num_additional_nodes = 1
    tables.nodes.clear()
    tables.nodes.add_row(time=ts.nodes_time.max() + 1)
    for node in original_nodes:
        tables.nodes.append(node)

    # Every pre-existing node moved up by one ID, so shift each column that
    # refers to a node ID by the same amount.
    tables.mutations.node += num_additional_nodes
    tables.edges.child += num_additional_nodes
    tables.edges.parent += num_additional_nodes

    if ts.num_edges > 0:
        for tree in ts.trees():
            # NOTE(review): ``tree.root`` presumes exactly one root per tree;
            # tskit documents that it raises for multi-root trees. Confirm
            # that callers never pass such tree sequences.
            root = tree.root + num_additional_nodes
            tables.edges.add_row(
                tree.interval.left, tree.interval.right, parent=0, child=root
            )
        # Merge the per-tree root edges that are adjacent and identical.
        tables.edges.squash()
        # FIXME probably don't need to sort here most of the time, or at
        # least we could sort only the rows near the end of the table.
        tables.sort()
    return tables.tree_sequence()

tsinfer/matching.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
import _tsinfer
3535

36-
from . import grouping, vcz
36+
from . import arg_ops, grouping, vcz
3737

3838
logger = logging.getLogger(__name__)
3939

@@ -380,48 +380,12 @@ def extend_ts(
380380
return result_ts
381381

382382

383-
def add_vestigial_root(ts):
384-
"""
385-
Adds the nodes and edges required by tsinfer to the specified tree sequence
386-
and returns it.
387-
"""
388-
if not ts.discrete_genome:
389-
raise ValueError("Only discrete genome coords supported")
390-
if ts.num_nodes == 0:
391-
raise ValueError("Emtpy trees not supported")
392-
393-
base_tables = ts.dump_tables()
394-
tables = base_tables.copy()
395-
tables.nodes.clear()
396-
t = max(ts.nodes_time)
397-
tables.nodes.add_row(time=t + 1)
398-
num_additonal_nodes = 1
399-
tables.mutations.node += num_additonal_nodes
400-
tables.edges.child += num_additonal_nodes
401-
tables.edges.parent += num_additonal_nodes
402-
for node in base_tables.nodes:
403-
tables.nodes.append(node)
404-
if ts.num_edges > 0:
405-
for tree in ts.trees():
406-
# if tree.num_roots > 1:
407-
# print(ts.draw_text())
408-
root = tree.root + num_additonal_nodes
409-
tables.edges.add_row(
410-
tree.interval.left, tree.interval.right, parent=0, child=root
411-
)
412-
tables.edges.squash()
413-
# FIXME probably don't need to sort here most of the time, or at least we
414-
# can just sort almost the end of the table.
415-
tables.sort()
416-
return tables.tree_sequence()
417-
418-
419383
class MatcherIndexes(_tsinfer.MatcherIndexes):
420384
"""Wrapper around the C MatcherIndexes, built from a tree sequence."""
421385

422386
def __init__(self, ts, *, vestigial_root=True, num_alleles=None):
423387
if vestigial_root:
424-
ts = add_vestigial_root(ts)
388+
ts = arg_ops.add_vestigial_root(ts)
425389
tables = ts.dump_tables()
426390
ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length)
427391
ll_tables.fromdict(tables.asdict())

0 commit comments

Comments
 (0)