Skip to content

Commit 6b5bd0d

Browse files
committed
Initial numba edge diffs
1 parent 1cd6ea9 commit 6b5bd0d

File tree

5 files changed

+146
-0
lines changed

5 files changed

+146
-0
lines changed

.github/workflows/tests.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,14 @@ jobs:
126126
python -m pytest -x --cov=tskit --cov-report=xml --cov-branch -n2 --durations=20 tests
127127
fi
128128
129+
- name: Run numba tests
130+
working-directory: python
131+
run: |
132+
source ~/.profile
133+
conda activate anaconda-client-env
134+
pip install numba
135+
python -m pytest -x --only-numba-tests --cov=tskit.numba --cov-report=xml --cov-branch -n2 --durations=20 tests
136+
129137
- name: Upload coverage to Codecov
130138
uses: codecov/[email protected]
131139
with:

python/tests/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,20 @@ def pytest_addoption(parser):
6464
default=False,
6565
help="To help debugging, draw lines around the plotboxes in SVG output files",
6666
)
67+
parser.addoption(
68+
"--only-numba-tests",
69+
action="store_true",
70+
default=False,
71+
help="Only run tests marked with @pytest.mark.numba",
72+
)
6773

6874

6975
def pytest_configure(config):
7076
"""
7177
Add docs on the "slow" marker
7278
"""
7379
config.addinivalue_line("markers", "slow: mark test as slow to run")
80+
config.addinivalue_line("markers", "numba: mark test as a Numba test")
7481

7582

7683
def pytest_collection_modifyitems(config, items):
@@ -79,6 +86,18 @@ def pytest_collection_modifyitems(config, items):
7986
for item in items:
8087
if "slow" in item.keywords:
8188
item.add_marker(skip_slow)
89+
if config.getoption("--only-numba-tests"):
90+
only_numba = pytest.mark.skip(reason="--only-numba-tests specified")
91+
for item in items:
92+
if "numba" not in item.keywords:
93+
item.add_marker(only_numba)
94+
else:
95+
numba_tests_skipped = pytest.mark.skip(
96+
reason="--only-numba-tests not specified, skipping numba tests"
97+
)
98+
for item in items:
99+
if "numba" in item.keywords:
100+
item.add_marker(numba_tests_skipped)
82101

83102

84103
@fixture

python/tests/test_jit.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import itertools
2+
import sys
3+
from unittest.mock import patch
4+
5+
import pytest
6+
7+
import tests.tsutil as tsutil
8+
9+
10+
def test_numba_import_error():
11+
# Mock numba as not available
12+
with patch.dict(sys.modules, {"numba": None}):
13+
with pytest.raises(ImportError, match="pip install numba"):
14+
import tskit.jit.numba # noqa: F401
15+
16+
17+
@pytest.mark.numba
18+
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
19+
def test_correct_trees_forward(ts):
20+
import tskit.jit.numba as jit_numba
21+
22+
numba_ts = jit_numba.numba_tree_sequence(ts)
23+
in_index = ts.indexes_edge_insertion_order
24+
out_index = ts.indexes_edge_removal_order
25+
for numba_edge_diff, edge_diff in itertools.zip_longest(
26+
numba_ts.edge_diffs(), ts.edge_diffs()
27+
):
28+
assert edge_diff.interval == numba_edge_diff.interval
29+
for edge_in_index, edge in itertools.zip_longest(
30+
range(*numba_edge_diff.edges_in_index_range), edge_diff.edges_in
31+
):
32+
assert edge.id == in_index[edge_in_index]
33+
for edge_out_index, edge in itertools.zip_longest(
34+
range(*numba_edge_diff.edges_out_index_range), edge_diff.edges_out
35+
):
36+
assert edge.id == out_index[edge_out_index]

python/tskit/jit/__init__.py

Whitespace-only changes.

python/tskit/jit/numba.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from dataclasses import dataclass
2+
3+
4+
try:
5+
import numba
6+
except ImportError:
7+
raise ImportError(
8+
"Numba is not installed. Please install it with `pip install numba` "
9+
"or `conda install numba` to use the tskit.jit.numba module."
10+
)
11+
12+
13+
# Decorator that makes a jited dataclass by removing certain methods
14+
# that are not compatible with Numba's JIT compilation.
15+
def jitdataclass(cls):
16+
dc_cls = dataclass(cls, eq=False, match_args=False)
17+
del dc_cls.__dataclass_params__
18+
del dc_cls.__dataclass_fields__
19+
del dc_cls.__repr__
20+
del dc_cls.__replace__
21+
return numba.experimental.jitclass(dc_cls)
22+
23+
24+
@jitdataclass
25+
class NumbaEdgeDiff:
26+
interval: numba.types.UniTuple(numba.float64, 2)
27+
edges_in_index_range: numba.types.UniTuple(numba.int32, 2)
28+
edges_out_index_range: numba.types.UniTuple(numba.int32, 2)
29+
30+
31+
@jitdataclass
32+
class NumbaTreeSequence:
33+
num_edges: numba.int64
34+
sequence_length: numba.float64
35+
edges_left: numba.float64[:]
36+
edges_right: numba.float64[:]
37+
indexes_edge_insertion_order: numba.int32[:]
38+
indexes_edge_removal_order: numba.int32[:]
39+
40+
def edge_diffs(self, include_terminal=False):
41+
left = 0.0
42+
j = 0
43+
k = 0
44+
edges_left = self.edges_left
45+
edges_right = self.edges_right
46+
in_order = self.indexes_edge_insertion_order
47+
out_order = self.indexes_edge_removal_order
48+
49+
while j < self.num_edges or left < self.sequence_length:
50+
in_start = j
51+
out_start = k
52+
53+
while k < self.num_edges and edges_right[out_order[k]] == left:
54+
k += 1
55+
while j < self.num_edges and edges_left[in_order[j]] == left:
56+
j += 1
57+
in_end = j
58+
out_end = k
59+
60+
right = self.sequence_length
61+
if j < self.num_edges:
62+
right = min(right, edges_left[in_order[j]])
63+
if k < self.num_edges:
64+
right = min(right, edges_right[out_order[k]])
65+
66+
yield NumbaEdgeDiff((left, right), (in_start, in_end), (out_start, out_end))
67+
68+
left = right
69+
70+
# Handle remaining edges that haven't been processed
71+
if include_terminal:
72+
yield NumbaEdgeDiff((left, right), (j, j), (k, self.num_edges))
73+
74+
75+
def numba_tree_sequence(ts):
76+
return NumbaTreeSequence(
77+
num_edges=ts.num_edges,
78+
sequence_length=ts.sequence_length,
79+
edges_left=ts.edges_left,
80+
edges_right=ts.edges_right,
81+
indexes_edge_insertion_order=ts.indexes_edge_insertion_order,
82+
indexes_edge_removal_order=ts.indexes_edge_removal_order,
83+
)

0 commit comments

Comments
 (0)