Skip to content

Commit 407eef1

Browse files
committed
Fix CI, add extra test
1 parent 6b5bd0d commit 407eef1

File tree

6 files changed

+63
-31
lines changed

6 files changed

+63
-31
lines changed

.github/workflows/tests.yml

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,6 @@ jobs:
126126
python -m pytest -x --cov=tskit --cov-report=xml --cov-branch -n2 --durations=20 tests
127127
fi
128128
129-
- name: Run numba tests
130-
working-directory: python
131-
run: |
132-
source ~/.profile
133-
conda activate anaconda-client-env
134-
pip install numba
135-
python -m pytest -x --only-numba-tests --cov=tskit.numba --cov-report=xml --cov-branch -n2 --durations=20 tests
136-
137129
- name: Upload coverage to Codecov
138130
uses: codecov/[email protected]
139131
with:

python/requirements/CI-complete/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ lshmm==0.0.8
66
msgpack==1.1.0
77
msprime==1.3.3
88
networkx==3.2.1
9+
numba==0.61.2
910
portion==2.6.0
1011
pytest==8.3.5
1112
pytest-cov==6.0.0

python/requirements/CI-tests-pip/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ networkx==3.2.1
1111
msgpack==1.1.0
1212
newick==1.10.0
1313
kastore==0.3.3
14-
jsonschema==4.23.0
14+
jsonschema==4.23.0
15+
numba>=0.60.0

python/tests/conftest.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,20 +64,13 @@ def pytest_addoption(parser):
6464
default=False,
6565
help="To help debugging, draw lines around the plotboxes in SVG output files",
6666
)
67-
parser.addoption(
68-
"--only-numba-tests",
69-
action="store_true",
70-
default=False,
71-
help="Only run tests marked with @pytest.mark.numba",
72-
)
7367

7468

7569
def pytest_configure(config):
7670
"""
7771
Add docs on the "slow" marker
7872
"""
7973
config.addinivalue_line("markers", "slow: mark test as slow to run")
80-
config.addinivalue_line("markers", "numba: mark test as a Numba test")
8174

8275

8376
def pytest_collection_modifyitems(config, items):
@@ -86,18 +79,6 @@ def pytest_collection_modifyitems(config, items):
8679
for item in items:
8780
if "slow" in item.keywords:
8881
item.add_marker(skip_slow)
89-
if config.getoption("--only-numba-tests"):
90-
only_numba = pytest.mark.skip(reason="--only-numba-tests specified")
91-
for item in items:
92-
if "numba" not in item.keywords:
93-
item.add_marker(only_numba)
94-
else:
95-
numba_tests_skipped = pytest.mark.skip(
96-
reason="--only-numba-tests not specified, skipping numba tests"
97-
)
98-
for item in items:
99-
if "numba" in item.keywords:
100-
item.add_marker(numba_tests_skipped)
10182

10283

10384
@fixture

python/tests/test_jit.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import sys
33
from unittest.mock import patch
44

5+
import msprime
6+
import numba
7+
import numpy as np
58
import pytest
69

710
import tests.tsutil as tsutil
@@ -14,7 +17,6 @@ def test_numba_import_error():
1417
import tskit.jit.numba # noqa: F401
1518

1619

17-
@pytest.mark.numba
1820
@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
1921
def test_correct_trees_forward(ts):
2022
import tskit.jit.numba as jit_numba
@@ -34,3 +36,49 @@ def test_correct_trees_forward(ts):
3436
range(*numba_edge_diff.edges_out_index_range), edge_diff.edges_out
3537
):
3638
assert edge.id == out_index[edge_out_index]
39+
40+
41+
def test_using_from_jit_function():
42+
"""
43+
Test that we can use the numba jit function from the tskit.jit module.
44+
"""
45+
import tskit.jit.numba as jit_numba
46+
47+
ts = msprime.sim_ancestry(
48+
samples=10, sequence_length=100, recombination_rate=1, random_seed=42
49+
)
50+
51+
@numba.njit
52+
def _coalescent_nodes_numba(numba_ts, num_nodes, edges_parent):
53+
is_coalescent = np.zeros(num_nodes, dtype=np.int8)
54+
num_children = np.zeros(num_nodes, dtype=np.int64)
55+
for tree_pos in numba_ts.edge_diffs():
56+
for j in range(*tree_pos.edges_out_index_range):
57+
e = numba_ts.indexes_edge_removal_order[j]
58+
num_children[edges_parent[e]] -= 1
59+
for j in range(*tree_pos.edges_in_index_range):
60+
e = numba_ts.indexes_edge_insertion_order[j]
61+
p = edges_parent[e]
62+
num_children[p] += 1
63+
if num_children[p] == 2:
64+
is_coalescent[p] = True
65+
return is_coalescent
66+
67+
def coalescent_nodes_python(ts):
68+
is_coalescent = np.zeros(ts.num_nodes, dtype=bool)
69+
num_children = np.zeros(ts.num_nodes, dtype=int)
70+
for _, edges_out, edges_in in ts.edge_diffs():
71+
for e in edges_out:
72+
num_children[e.parent] -= 1
73+
for e in edges_in:
74+
num_children[e.parent] += 1
75+
if num_children[e.parent] == 2:
76+
# Num_children will always be exactly two once, even arity is greater
77+
is_coalescent[e.parent] = True
78+
return is_coalescent
79+
80+
numba_ts = jit_numba.numba_tree_sequence(ts)
81+
C1 = coalescent_nodes_python(ts)
82+
C2 = _coalescent_nodes_numba(numba_ts, ts.num_nodes, ts.edges_parent)
83+
84+
np.testing.assert_array_equal(C1, C2)

python/tskit/jit/numba.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,20 @@
1313
# Decorator that makes a jited dataclass by removing certain methods
1414
# that are not compatible with Numba's JIT compilation.
1515
def jitdataclass(cls):
16-
dc_cls = dataclass(cls, eq=False, match_args=False)
16+
dc_cls = dataclass(cls, eq=False)
1717
del dc_cls.__dataclass_params__
1818
del dc_cls.__dataclass_fields__
1919
del dc_cls.__repr__
20-
del dc_cls.__replace__
20+
try:
21+
del dc_cls.__replace__
22+
except AttributeError:
23+
# __replace__ is not available in Python < 3.10
24+
pass
25+
try:
26+
del dc_cls.__match_args__
27+
except AttributeError:
28+
# __match_args__ is not available in Python < 3.10
29+
pass
2130
return numba.experimental.jitclass(dc_cls)
2231

2332

0 commit comments

Comments
 (0)