From 7340d8e1a8c56af5b6dbaaa15d5842d49736083d Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Wed, 18 Jun 2025 12:22:46 +0100 Subject: [PATCH 1/9] Initial numba edge diffs --- .github/workflows/tests.yml | 8 ++++ python/tests/conftest.py | 19 +++++++++ python/tests/test_jit.py | 36 ++++++++++++++++ python/tskit/jit/__init__.py | 0 python/tskit/jit/numba.py | 83 ++++++++++++++++++++++++++++++++++++ 5 files changed, 146 insertions(+) create mode 100644 python/tests/test_jit.py create mode 100644 python/tskit/jit/__init__.py create mode 100644 python/tskit/jit/numba.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 30b1c1b5f3..e11a92bf65 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -126,6 +126,14 @@ jobs: python -m pytest -x --cov=tskit --cov-report=xml --cov-branch -n2 --durations=20 tests fi + - name: Run numba tests + working-directory: python + run: | + source ~/.profile + conda activate anaconda-client-env + pip install numba + python -m pytest -x --only-numba-tests --cov=tskit.numba --cov-report=xml --cov-branch -n2 --durations=20 tests + - name: Upload coverage to Codecov uses: codecov/codecov-action@v5.4.0 with: diff --git a/python/tests/conftest.py b/python/tests/conftest.py index d23c019003..1a10f0f5bf 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -64,6 +64,12 @@ def pytest_addoption(parser): default=False, help="To help debugging, draw lines around the plotboxes in SVG output files", ) + parser.addoption( + "--only-numba-tests", + action="store_true", + default=False, + help="Only run tests marked with @pytest.mark.numba", + ) def pytest_configure(config): @@ -71,6 +77,7 @@ def pytest_configure(config): Add docs on the "slow" marker """ config.addinivalue_line("markers", "slow: mark test as slow to run") + config.addinivalue_line("markers", "numba: mark test as a Numba test") def pytest_collection_modifyitems(config, items): @@ -79,6 +86,18 @@ def 
pytest_collection_modifyitems(config, items): for item in items: if "slow" in item.keywords: item.add_marker(skip_slow) + if config.getoption("--only-numba-tests"): + only_numba = pytest.mark.skip(reason="--only-numba-tests specified") + for item in items: + if "numba" not in item.keywords: + item.add_marker(only_numba) + else: + numba_tests_skipped = pytest.mark.skip( + reason="--only-numba-tests not specified, skipping numba tests" + ) + for item in items: + if "numba" in item.keywords: + item.add_marker(numba_tests_skipped) @fixture diff --git a/python/tests/test_jit.py b/python/tests/test_jit.py new file mode 100644 index 0000000000..3111d49e6f --- /dev/null +++ b/python/tests/test_jit.py @@ -0,0 +1,36 @@ +import itertools +import sys +from unittest.mock import patch + +import pytest + +import tests.tsutil as tsutil + + +def test_numba_import_error(): + # Mock numba as not available + with patch.dict(sys.modules, {"numba": None}): + with pytest.raises(ImportError, match="pip install numba"): + import tskit.jit.numba # noqa: F401 + + +@pytest.mark.numba +@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) +def test_correct_trees_forward(ts): + import tskit.jit.numba as jit_numba + + numba_ts = jit_numba.numba_tree_sequence(ts) + in_index = ts.indexes_edge_insertion_order + out_index = ts.indexes_edge_removal_order + for numba_edge_diff, edge_diff in itertools.zip_longest( + numba_ts.edge_diffs(), ts.edge_diffs() + ): + assert edge_diff.interval == numba_edge_diff.interval + for edge_in_index, edge in itertools.zip_longest( + range(*numba_edge_diff.edges_in_index_range), edge_diff.edges_in + ): + assert edge.id == in_index[edge_in_index] + for edge_out_index, edge in itertools.zip_longest( + range(*numba_edge_diff.edges_out_index_range), edge_diff.edges_out + ): + assert edge.id == out_index[edge_out_index] diff --git a/python/tskit/jit/__init__.py b/python/tskit/jit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/python/tskit/jit/numba.py b/python/tskit/jit/numba.py new file mode 100644 index 0000000000..98290bc591 --- /dev/null +++ b/python/tskit/jit/numba.py @@ -0,0 +1,83 @@ +from dataclasses import dataclass + + +try: + import numba +except ImportError: + raise ImportError( + "Numba is not installed. Please install it with `pip install numba` " + "or `conda install numba` to use the tskit.jit.numba module." + ) + + +# Decorator that makes a jited dataclass by removing certain methods +# that are not compatible with Numba's JIT compilation. +def jitdataclass(cls): + dc_cls = dataclass(cls, eq=False, match_args=False) + del dc_cls.__dataclass_params__ + del dc_cls.__dataclass_fields__ + del dc_cls.__repr__ + del dc_cls.__replace__ + return numba.experimental.jitclass(dc_cls) + + +@jitdataclass +class NumbaEdgeDiff: + interval: numba.types.UniTuple(numba.float64, 2) + edges_in_index_range: numba.types.UniTuple(numba.int32, 2) + edges_out_index_range: numba.types.UniTuple(numba.int32, 2) + + +@jitdataclass +class NumbaTreeSequence: + num_edges: numba.int64 + sequence_length: numba.float64 + edges_left: numba.float64[:] + edges_right: numba.float64[:] + indexes_edge_insertion_order: numba.int32[:] + indexes_edge_removal_order: numba.int32[:] + + def edge_diffs(self, include_terminal=False): + left = 0.0 + j = 0 + k = 0 + edges_left = self.edges_left + edges_right = self.edges_right + in_order = self.indexes_edge_insertion_order + out_order = self.indexes_edge_removal_order + + while j < self.num_edges or left < self.sequence_length: + in_start = j + out_start = k + + while k < self.num_edges and edges_right[out_order[k]] == left: + k += 1 + while j < self.num_edges and edges_left[in_order[j]] == left: + j += 1 + in_end = j + out_end = k + + right = self.sequence_length + if j < self.num_edges: + right = min(right, edges_left[in_order[j]]) + if k < self.num_edges: + right = min(right, edges_right[out_order[k]]) + + yield NumbaEdgeDiff((left, right), (in_start, in_end), 
(out_start, out_end)) + + left = right + + # Handle remaining edges that haven't been processed + if include_terminal: + yield NumbaEdgeDiff((left, right), (j, j), (k, self.num_edges)) + + +def numba_tree_sequence(ts): + return NumbaTreeSequence( + num_edges=ts.num_edges, + sequence_length=ts.sequence_length, + edges_left=ts.edges_left, + edges_right=ts.edges_right, + indexes_edge_insertion_order=ts.indexes_edge_insertion_order, + indexes_edge_removal_order=ts.indexes_edge_removal_order, + ) From 7fb9ef834c46bb1a9d49d19967b3a26af6b2c8f1 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Wed, 18 Jun 2025 12:35:01 +0100 Subject: [PATCH 2/9] Fix CI, add extra test --- .github/workflows/tests.yml | 8 --- .../requirements/CI-complete/requirements.txt | 1 + .../CI-tests-pip/requirements.txt | 3 +- python/tests/conftest.py | 19 ------- python/tests/test_jit.py | 50 ++++++++++++++++++- python/tskit/jit/numba.py | 13 ++++- 6 files changed, 63 insertions(+), 31 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e11a92bf65..30b1c1b5f3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -126,14 +126,6 @@ jobs: python -m pytest -x --cov=tskit --cov-report=xml --cov-branch -n2 --durations=20 tests fi - - name: Run numba tests - working-directory: python - run: | - source ~/.profile - conda activate anaconda-client-env - pip install numba - python -m pytest -x --only-numba-tests --cov=tskit.numba --cov-report=xml --cov-branch -n2 --durations=20 tests - - name: Upload coverage to Codecov uses: codecov/codecov-action@v5.4.0 with: diff --git a/python/requirements/CI-complete/requirements.txt b/python/requirements/CI-complete/requirements.txt index 8920fa33ca..99465a9c5c 100644 --- a/python/requirements/CI-complete/requirements.txt +++ b/python/requirements/CI-complete/requirements.txt @@ -6,6 +6,7 @@ lshmm==0.0.8 msgpack==1.1.0 msprime==1.3.3 networkx==3.2.1 +numba==0.61.2 portion==2.6.0 pytest==8.3.5 pytest-cov==6.0.0 
diff --git a/python/requirements/CI-tests-pip/requirements.txt b/python/requirements/CI-tests-pip/requirements.txt index 2e162f68d7..a72ec4227c 100644 --- a/python/requirements/CI-tests-pip/requirements.txt +++ b/python/requirements/CI-tests-pip/requirements.txt @@ -11,4 +11,5 @@ networkx==3.2.1 msgpack==1.1.0 newick==1.10.0 kastore==0.3.3 -jsonschema==4.23.0 \ No newline at end of file +jsonschema==4.23.0 +numba>=0.60.0 \ No newline at end of file diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 1a10f0f5bf..d23c019003 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -64,12 +64,6 @@ def pytest_addoption(parser): default=False, help="To help debugging, draw lines around the plotboxes in SVG output files", ) - parser.addoption( - "--only-numba-tests", - action="store_true", - default=False, - help="Only run tests marked with @pytest.mark.numba", - ) def pytest_configure(config): @@ -77,7 +71,6 @@ def pytest_configure(config): Add docs on the "slow" marker """ config.addinivalue_line("markers", "slow: mark test as slow to run") - config.addinivalue_line("markers", "numba: mark test as a Numba test") def pytest_collection_modifyitems(config, items): @@ -86,18 +79,6 @@ def pytest_collection_modifyitems(config, items): for item in items: if "slow" in item.keywords: item.add_marker(skip_slow) - if config.getoption("--only-numba-tests"): - only_numba = pytest.mark.skip(reason="--only-numba-tests specified") - for item in items: - if "numba" not in item.keywords: - item.add_marker(only_numba) - else: - numba_tests_skipped = pytest.mark.skip( - reason="--only-numba-tests not specified, skipping numba tests" - ) - for item in items: - if "numba" in item.keywords: - item.add_marker(numba_tests_skipped) @fixture diff --git a/python/tests/test_jit.py b/python/tests/test_jit.py index 3111d49e6f..5fe369eb57 100644 --- a/python/tests/test_jit.py +++ b/python/tests/test_jit.py @@ -2,6 +2,9 @@ import sys from unittest.mock import patch 
+import msprime +import numba +import numpy as np import pytest import tests.tsutil as tsutil @@ -14,7 +17,6 @@ def test_numba_import_error(): import tskit.jit.numba # noqa: F401 -@pytest.mark.numba @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_correct_trees_forward(ts): import tskit.jit.numba as jit_numba @@ -34,3 +36,49 @@ def test_correct_trees_forward(ts): range(*numba_edge_diff.edges_out_index_range), edge_diff.edges_out ): assert edge.id == out_index[edge_out_index] + + +def test_using_from_jit_function(): + """ + Test that we can use the numba jit function from the tskit.jit module. + """ + import tskit.jit.numba as jit_numba + + ts = msprime.sim_ancestry( + samples=10, sequence_length=100, recombination_rate=1, random_seed=42 + ) + + @numba.njit + def _coalescent_nodes_numba(numba_ts, num_nodes, edges_parent): + is_coalescent = np.zeros(num_nodes, dtype=np.int8) + num_children = np.zeros(num_nodes, dtype=np.int64) + for tree_pos in numba_ts.edge_diffs(): + for j in range(*tree_pos.edges_out_index_range): + e = numba_ts.indexes_edge_removal_order[j] + num_children[edges_parent[e]] -= 1 + for j in range(*tree_pos.edges_in_index_range): + e = numba_ts.indexes_edge_insertion_order[j] + p = edges_parent[e] + num_children[p] += 1 + if num_children[p] == 2: + is_coalescent[p] = True + return is_coalescent + + def coalescent_nodes_python(ts): + is_coalescent = np.zeros(ts.num_nodes, dtype=bool) + num_children = np.zeros(ts.num_nodes, dtype=int) + for _, edges_out, edges_in in ts.edge_diffs(): + for e in edges_out: + num_children[e.parent] -= 1 + for e in edges_in: + num_children[e.parent] += 1 + if num_children[e.parent] == 2: + # Num_children will always be exactly two once, even arity is greater + is_coalescent[e.parent] = True + return is_coalescent + + numba_ts = jit_numba.numba_tree_sequence(ts) + C1 = coalescent_nodes_python(ts) + C2 = _coalescent_nodes_numba(numba_ts, ts.num_nodes, ts.edges_parent) + + 
np.testing.assert_array_equal(C1, C2) diff --git a/python/tskit/jit/numba.py b/python/tskit/jit/numba.py index 98290bc591..bd9680337f 100644 --- a/python/tskit/jit/numba.py +++ b/python/tskit/jit/numba.py @@ -13,11 +13,20 @@ # Decorator that makes a jited dataclass by removing certain methods # that are not compatible with Numba's JIT compilation. def jitdataclass(cls): - dc_cls = dataclass(cls, eq=False, match_args=False) + dc_cls = dataclass(cls, eq=False) del dc_cls.__dataclass_params__ del dc_cls.__dataclass_fields__ del dc_cls.__repr__ - del dc_cls.__replace__ + try: + del dc_cls.__replace__ + except AttributeError: + # __replace__ is not available in Python < 3.10 + pass + try: + del dc_cls.__match_args__ + except AttributeError: + # __match_args__ is not available in Python < 3.10 + pass return numba.experimental.jitclass(dc_cls) From ad3d3d42f37c38f40892d674e7882f54db0dfa5f Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Mon, 23 Jun 2025 07:17:49 +0100 Subject: [PATCH 3/9] Refactor to remove iteration --- python/tests/test_jit.py | 16 ++++---- python/tskit/jit/numba.py | 78 +++++++++++++++++++-------------------- 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/python/tests/test_jit.py b/python/tests/test_jit.py index 5fe369eb57..7759c8949c 100644 --- a/python/tests/test_jit.py +++ b/python/tests/test_jit.py @@ -24,16 +24,17 @@ def test_correct_trees_forward(ts): numba_ts = jit_numba.numba_tree_sequence(ts) in_index = ts.indexes_edge_insertion_order out_index = ts.indexes_edge_removal_order - for numba_edge_diff, edge_diff in itertools.zip_longest( - numba_ts.edge_diffs(), ts.edge_diffs() - ): - assert edge_diff.interval == numba_edge_diff.interval + tree_pos = numba_ts.tree_position() + ts_edge_diffs = ts.edge_diffs() + while tree_pos.next(): + edge_diff = next(ts_edge_diffs) + assert edge_diff.interval == tree_pos.interval for edge_in_index, edge in itertools.zip_longest( - range(*numba_edge_diff.edges_in_index_range), edge_diff.edges_in + 
range(*tree_pos.edges_in_index_range), edge_diff.edges_in ): assert edge.id == in_index[edge_in_index] for edge_out_index, edge in itertools.zip_longest( - range(*numba_edge_diff.edges_out_index_range), edge_diff.edges_out + range(*tree_pos.edges_out_index_range), edge_diff.edges_out ): assert edge.id == out_index[edge_out_index] @@ -52,7 +53,8 @@ def test_using_from_jit_function(): def _coalescent_nodes_numba(numba_ts, num_nodes, edges_parent): is_coalescent = np.zeros(num_nodes, dtype=np.int8) num_children = np.zeros(num_nodes, dtype=np.int64) - for tree_pos in numba_ts.edge_diffs(): + tree_pos = numba_ts.tree_position() + while tree_pos.next(): for j in range(*tree_pos.edges_out_index_range): e = numba_ts.indexes_edge_removal_order[j] num_children[edges_parent[e]] -= 1 diff --git a/python/tskit/jit/numba.py b/python/tskit/jit/numba.py index bd9680337f..5ffd23f790 100644 --- a/python/tskit/jit/numba.py +++ b/python/tskit/jit/numba.py @@ -30,13 +30,6 @@ def jitdataclass(cls): return numba.experimental.jitclass(dc_cls) -@jitdataclass -class NumbaEdgeDiff: - interval: numba.types.UniTuple(numba.float64, 2) - edges_in_index_range: numba.types.UniTuple(numba.int32, 2) - edges_out_index_range: numba.types.UniTuple(numba.int32, 2) - - @jitdataclass class NumbaTreeSequence: num_edges: numba.int64 @@ -46,39 +39,44 @@ class NumbaTreeSequence: indexes_edge_insertion_order: numba.int32[:] indexes_edge_removal_order: numba.int32[:] - def edge_diffs(self, include_terminal=False): - left = 0.0 - j = 0 - k = 0 - edges_left = self.edges_left - edges_right = self.edges_right - in_order = self.indexes_edge_insertion_order - out_order = self.indexes_edge_removal_order - - while j < self.num_edges or left < self.sequence_length: - in_start = j - out_start = k - - while k < self.num_edges and edges_right[out_order[k]] == left: - k += 1 - while j < self.num_edges and edges_left[in_order[j]] == left: - j += 1 - in_end = j - out_end = k - - right = self.sequence_length - if j < 
self.num_edges: - right = min(right, edges_left[in_order[j]]) - if k < self.num_edges: - right = min(right, edges_right[out_order[k]]) - - yield NumbaEdgeDiff((left, right), (in_start, in_end), (out_start, out_end)) - - left = right - - # Handle remaining edges that haven't been processed - if include_terminal: - yield NumbaEdgeDiff((left, right), (j, j), (k, self.num_edges)) + def tree_position(self): + return NumbaTreePosition(self, (0, 0), (0, 0), (0, 0)) + + +@jitdataclass +class NumbaTreePosition: + ts: NumbaTreeSequence + interval: numba.types.UniTuple(numba.float64, 2) + edges_in_index_range: numba.types.UniTuple(numba.int32, 2) + edges_out_index_range: numba.types.UniTuple(numba.int32, 2) + + def next(self): # noqa: A003 + M = self.ts.num_edges + edges_left = self.ts.edges_left + edges_right = self.ts.edges_right + in_order = self.ts.indexes_edge_insertion_order + out_order = self.ts.indexes_edge_removal_order + + left = self.interval[1] + j = self.edges_in_index_range[1] + k = self.edges_out_index_range[1] + + while k < M and edges_right[out_order[k]] == left: + k += 1 + while j < M and edges_left[in_order[j]] == left: + j += 1 + + self.edges_in_index_range = (self.edges_in_index_range[1], j) + self.edges_out_index_range = (self.edges_out_index_range[1], k) + + right = self.ts.sequence_length + if j < M: + right = min(right, edges_left[in_order[j]]) + if k < M: + right = min(right, edges_right[out_order[k]]) + + self.interval = (left, right) + return j < M or left < self.ts.sequence_length def numba_tree_sequence(ts): From b19a2a8b150ab2e3878896fb97cbdd1bae34fa26 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 26 Jun 2025 13:44:32 +0100 Subject: [PATCH 4/9] Remove dataclass, add ts properties --- python/tests/test_jit.py | 51 +++++++++++++++ python/tskit/jit/numba.py | 127 +++++++++++++++++++++++++++----------- 2 files changed, 143 insertions(+), 35 deletions(-) diff --git a/python/tests/test_jit.py b/python/tests/test_jit.py index 
7759c8949c..ed1a0289ae 100644 --- a/python/tests/test_jit.py +++ b/python/tests/test_jit.py @@ -84,3 +84,54 @@ def coalescent_nodes_python(ts): C2 = _coalescent_nodes_numba(numba_ts, ts.num_nodes, ts.edges_parent) np.testing.assert_array_equal(C1, C2) + + +def test_numba_tree_sequence_properties(ts_fixture): + """ + Test that NumbaTreeSequence properties have correct contents and dtypes. + """ + ts = ts_fixture + import tskit.jit.numba as jit_numba + + numba_ts = jit_numba.numba_tree_sequence(ts) + + assert numba_ts.num_edges == ts.num_edges + assert numba_ts.sequence_length == ts.sequence_length + np.testing.assert_array_equal(numba_ts.edges_left, ts.edges_left) + np.testing.assert_array_equal(numba_ts.edges_right, ts.edges_right) + np.testing.assert_array_equal(numba_ts.edges_parent, ts.edges_parent) + np.testing.assert_array_equal(numba_ts.edges_child, ts.edges_child) + assert numba_ts.edges_left.dtype == np.float64 + assert numba_ts.edges_right.dtype == np.float64 + assert numba_ts.edges_parent.dtype == np.int32 + assert numba_ts.edges_child.dtype == np.int32 + np.testing.assert_array_equal(numba_ts.nodes_time, ts.nodes_time) + np.testing.assert_array_equal(numba_ts.nodes_flags, ts.nodes_flags) + np.testing.assert_array_equal(numba_ts.nodes_population, ts.nodes_population) + np.testing.assert_array_equal(numba_ts.nodes_individual, ts.nodes_individual) + assert numba_ts.nodes_time.dtype == np.float64 + assert numba_ts.nodes_flags.dtype == np.uint32 + assert numba_ts.nodes_population.dtype == np.int32 + assert numba_ts.nodes_individual.dtype == np.int32 + np.testing.assert_array_equal(numba_ts.individuals_flags, ts.individuals_flags) + assert numba_ts.individuals_flags.dtype == np.uint32 + np.testing.assert_array_equal(numba_ts.sites_position, ts.sites_position) + assert numba_ts.sites_position.dtype == np.float64 + np.testing.assert_array_equal(numba_ts.mutations_site, ts.mutations_site) + np.testing.assert_array_equal(numba_ts.mutations_node, ts.mutations_node) 
+ np.testing.assert_array_equal(numba_ts.mutations_parent, ts.mutations_parent) + np.testing.assert_array_equal(numba_ts.mutations_time, ts.mutations_time) + assert numba_ts.mutations_site.dtype == np.int32 + assert numba_ts.mutations_node.dtype == np.int32 + assert numba_ts.mutations_parent.dtype == np.int32 + assert numba_ts.mutations_time.dtype == np.float64 + np.testing.assert_array_equal( + numba_ts.indexes_edge_insertion_order, ts.indexes_edge_insertion_order + ) + np.testing.assert_array_equal( + numba_ts.indexes_edge_removal_order, ts.indexes_edge_removal_order + ) + assert numba_ts.indexes_edge_insertion_order.dtype == np.int32 + assert numba_ts.indexes_edge_removal_order.dtype == np.int32 + assert numba_ts.breakpoints.dtype == np.float64 + np.testing.assert_array_equal(numba_ts.breakpoints, ts.breakpoints(as_array=True)) diff --git a/python/tskit/jit/numba.py b/python/tskit/jit/numba.py index 5ffd23f790..f0a19a622d 100644 --- a/python/tskit/jit/numba.py +++ b/python/tskit/jit/numba.py @@ -1,6 +1,3 @@ -from dataclasses import dataclass - - try: import numba except ImportError: @@ -10,45 +7,92 @@ ) -# Decorator that makes a jited dataclass by removing certain methods -# that are not compatible with Numba's JIT compilation. 
-def jitdataclass(cls): - dc_cls = dataclass(cls, eq=False) - del dc_cls.__dataclass_params__ - del dc_cls.__dataclass_fields__ - del dc_cls.__repr__ - try: - del dc_cls.__replace__ - except AttributeError: - # __replace__ is not available in Python < 3.10 - pass - try: - del dc_cls.__match_args__ - except AttributeError: - # __match_args__ is not available in Python < 3.10 - pass - return numba.experimental.jitclass(dc_cls) - - -@jitdataclass +tree_sequence_spec = [ + ("num_edges", numba.int64), + ("sequence_length", numba.float64), + ("edges_left", numba.float64[:]), + ("edges_right", numba.float64[:]), + ("indexes_edge_insertion_order", numba.int32[:]), + ("indexes_edge_removal_order", numba.int32[:]), + ("individuals_flags", numba.uint32[:]), + ("nodes_time", numba.float64[:]), + ("nodes_flags", numba.uint32[:]), + ("nodes_population", numba.int32[:]), + ("nodes_individual", numba.int32[:]), + ("edges_parent", numba.int32[:]), + ("edges_child", numba.int32[:]), + ("sites_position", numba.float64[:]), + ("mutations_site", numba.int32[:]), + ("mutations_node", numba.int32[:]), + ("mutations_parent", numba.int32[:]), + ("mutations_time", numba.float64[:]), + ("breakpoints", numba.float64[:]), +] + + +@numba.experimental.jitclass(tree_sequence_spec) class NumbaTreeSequence: - num_edges: numba.int64 - sequence_length: numba.float64 - edges_left: numba.float64[:] - edges_right: numba.float64[:] - indexes_edge_insertion_order: numba.int32[:] - indexes_edge_removal_order: numba.int32[:] + def __init__( + self, + num_edges, + sequence_length, + edges_left, + edges_right, + indexes_edge_insertion_order, + indexes_edge_removal_order, + individuals_flags, + nodes_time, + nodes_flags, + nodes_population, + nodes_individual, + edges_parent, + edges_child, + sites_position, + mutations_site, + mutations_node, + mutations_parent, + mutations_time, + breakpoints, + ): + self.num_edges = num_edges + self.sequence_length = sequence_length + self.edges_left = edges_left + 
self.edges_right = edges_right + self.indexes_edge_insertion_order = indexes_edge_insertion_order + self.indexes_edge_removal_order = indexes_edge_removal_order + self.individuals_flags = individuals_flags + self.nodes_time = nodes_time + self.nodes_flags = nodes_flags + self.nodes_population = nodes_population + self.nodes_individual = nodes_individual + self.edges_parent = edges_parent + self.edges_child = edges_child + self.sites_position = sites_position + self.mutations_site = mutations_site + self.mutations_node = mutations_node + self.mutations_parent = mutations_parent + self.mutations_time = mutations_time + self.breakpoints = breakpoints def tree_position(self): return NumbaTreePosition(self, (0, 0), (0, 0), (0, 0)) -@jitdataclass +tree_position_spec = [ + ("ts", NumbaTreeSequence.class_type.instance_type), + ("interval", numba.types.UniTuple(numba.float64, 2)), + ("edges_in_index_range", numba.types.UniTuple(numba.int32, 2)), + ("edges_out_index_range", numba.types.UniTuple(numba.int32, 2)), +] + + +@numba.experimental.jitclass(tree_position_spec) class NumbaTreePosition: - ts: NumbaTreeSequence - interval: numba.types.UniTuple(numba.float64, 2) - edges_in_index_range: numba.types.UniTuple(numba.int32, 2) - edges_out_index_range: numba.types.UniTuple(numba.int32, 2) + def __init__(self, ts, interval, edges_in_index_range, edges_out_index_range): + self.ts = ts + self.interval = interval + self.edges_in_index_range = edges_in_index_range + self.edges_out_index_range = edges_out_index_range def next(self): # noqa: A003 M = self.ts.num_edges @@ -87,4 +131,17 @@ def numba_tree_sequence(ts): edges_right=ts.edges_right, indexes_edge_insertion_order=ts.indexes_edge_insertion_order, indexes_edge_removal_order=ts.indexes_edge_removal_order, + individuals_flags=ts.individuals_flags, + nodes_time=ts.nodes_time, + nodes_flags=ts.nodes_flags, + nodes_population=ts.nodes_population, + nodes_individual=ts.nodes_individual, + edges_parent=ts.edges_parent, + 
edges_child=ts.edges_child, + sites_position=ts.sites_position, + mutations_site=ts.mutations_site, + mutations_node=ts.mutations_node, + mutations_parent=ts.mutations_parent, + mutations_time=ts.mutations_time, + breakpoints=ts.breakpoints(as_array=True), ) From e50307ddd9ed28ab3469b4261e9907593d8fc35b Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 26 Jun 2025 14:51:52 +0100 Subject: [PATCH 5/9] Use tsutil style next and prev --- python/tests/test_jit.py | 44 ++++++++--- python/tskit/jit/numba.py | 152 +++++++++++++++++++++++++++++++------- 2 files changed, 158 insertions(+), 38 deletions(-) diff --git a/python/tests/test_jit.py b/python/tests/test_jit.py index ed1a0289ae..b422377a69 100644 --- a/python/tests/test_jit.py +++ b/python/tests/test_jit.py @@ -8,6 +8,7 @@ import pytest import tests.tsutil as tsutil +import tskit def test_numba_import_error(): @@ -22,21 +23,43 @@ def test_correct_trees_forward(ts): import tskit.jit.numba as jit_numba numba_ts = jit_numba.numba_tree_sequence(ts) - in_index = ts.indexes_edge_insertion_order - out_index = ts.indexes_edge_removal_order tree_pos = numba_ts.tree_position() ts_edge_diffs = ts.edge_diffs() while tree_pos.next(): edge_diff = next(ts_edge_diffs) assert edge_diff.interval == tree_pos.interval for edge_in_index, edge in itertools.zip_longest( - range(*tree_pos.edges_in_index_range), edge_diff.edges_in + range(tree_pos.in_range.start, tree_pos.in_range.stop), edge_diff.edges_in ): - assert edge.id == in_index[edge_in_index] + assert edge.id == tree_pos.in_range.order[edge_in_index] for edge_out_index, edge in itertools.zip_longest( - range(*tree_pos.edges_out_index_range), edge_diff.edges_out + range(tree_pos.out_range.start, tree_pos.out_range.stop), + edge_diff.edges_out, ): - assert edge.id == out_index[edge_out_index] + assert edge.id == tree_pos.out_range.order[edge_out_index] + + +@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) +def test_correct_trees_backwards(ts): + import 
tskit.jit.numba as jit_numba + + numba_ts = jit_numba.numba_tree_sequence(ts) + tree_pos = numba_ts.tree_position() + ts_edge_diffs = ts.edge_diffs(direction=tskit.REVERSE) + while tree_pos.prev(): + edge_diff = next(ts_edge_diffs) + assert edge_diff.interval == tree_pos.interval + for edge_in_index, edge in itertools.zip_longest( + range(tree_pos.in_range.start, tree_pos.in_range.stop, -1), + edge_diff.edges_in, + ): + + assert edge.id == tree_pos.in_range.order[edge_in_index] + for edge_out_index, edge in itertools.zip_longest( + range(tree_pos.out_range.start, tree_pos.out_range.stop, -1), + edge_diff.edges_out, + ): + assert edge.id == tree_pos.out_range.order[edge_out_index] def test_using_from_jit_function(): @@ -55,11 +78,11 @@ def _coalescent_nodes_numba(numba_ts, num_nodes, edges_parent): num_children = np.zeros(num_nodes, dtype=np.int64) tree_pos = numba_ts.tree_position() while tree_pos.next(): - for j in range(*tree_pos.edges_out_index_range): - e = numba_ts.indexes_edge_removal_order[j] + for j in range(tree_pos.out_range.start, tree_pos.out_range.stop): + e = tree_pos.out_range.order[j] num_children[edges_parent[e]] -= 1 - for j in range(*tree_pos.edges_in_index_range): - e = numba_ts.indexes_edge_insertion_order[j] + for j in range(tree_pos.in_range.start, tree_pos.in_range.stop): + e = tree_pos.in_range.order[j] p = edges_parent[e] num_children[p] += 1 if num_children[p] == 2: @@ -95,6 +118,7 @@ def test_numba_tree_sequence_properties(ts_fixture): numba_ts = jit_numba.numba_tree_sequence(ts) + assert numba_ts.num_trees == ts.num_trees assert numba_ts.num_edges == ts.num_edges assert numba_ts.sequence_length == ts.sequence_length np.testing.assert_array_equal(numba_ts.edges_left, ts.edges_left) diff --git a/python/tskit/jit/numba.py b/python/tskit/jit/numba.py index f0a19a622d..bc6b238126 100644 --- a/python/tskit/jit/numba.py +++ b/python/tskit/jit/numba.py @@ -1,3 +1,5 @@ +import numpy as np + try: import numba except ImportError: @@ -7,8 +9,13 @@ 
) +FORWARD = 1 +REVERSE = -1 + + tree_sequence_spec = [ - ("num_edges", numba.int64), + ("num_trees", numba.int32), + ("num_edges", numba.int32), ("sequence_length", numba.float64), ("edges_left", numba.float64[:]), ("edges_right", numba.float64[:]), @@ -34,6 +41,7 @@ class NumbaTreeSequence: def __init__( self, + num_trees, num_edges, sequence_length, edges_left, @@ -54,6 +62,7 @@ def __init__( mutations_time, breakpoints, ): + self.num_trees = num_trees self.num_edges = num_edges self.sequence_length = sequence_length self.edges_left = edges_left @@ -75,56 +84,143 @@ def __init__( self.breakpoints = breakpoints def tree_position(self): - return NumbaTreePosition(self, (0, 0), (0, 0), (0, 0)) + return NumbaTreePosition(self) + + +edge_range_spec = [ + ("start", numba.int32), + ("stop", numba.int32), + ("order", numba.int32[:]), +] + + +@numba.experimental.jitclass(edge_range_spec) +class NumbaEdgeRange: + def __init__(self, start, stop, order): + self.start = start + self.stop = stop + self.order = order tree_position_spec = [ ("ts", NumbaTreeSequence.class_type.instance_type), + ("index", numba.int32), + ("direction", numba.int32), ("interval", numba.types.UniTuple(numba.float64, 2)), - ("edges_in_index_range", numba.types.UniTuple(numba.int32, 2)), - ("edges_out_index_range", numba.types.UniTuple(numba.int32, 2)), + ("in_range", NumbaEdgeRange.class_type.instance_type), + ("out_range", NumbaEdgeRange.class_type.instance_type), ] @numba.experimental.jitclass(tree_position_spec) class NumbaTreePosition: - def __init__(self, ts, interval, edges_in_index_range, edges_out_index_range): + def __init__(self, ts): self.ts = ts - self.interval = interval - self.edges_in_index_range = edges_in_index_range - self.edges_out_index_range = edges_out_index_range + self.index = -1 + self.direction = 0 + self.interval = (0, 0) + self.in_range = NumbaEdgeRange(0, 0, np.zeros(0, dtype=numba.int32)) + self.out_range = NumbaEdgeRange(0, 0, np.zeros(0, dtype=numba.int32)) + + def 
set_null(self): + self.index = -1 + self.interval = (0, 0) def next(self): # noqa: A003 M = self.ts.num_edges - edges_left = self.ts.edges_left - edges_right = self.ts.edges_right - in_order = self.ts.indexes_edge_insertion_order - out_order = self.ts.indexes_edge_removal_order + breakpoints = self.ts.breakpoints + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + + if self.index == -1: + self.interval = (self.interval[0], 0) + self.out_range.stop = 0 + self.in_range.stop = 0 + self.direction = FORWARD + + if self.direction == FORWARD: + left_current_index = self.in_range.stop + right_current_index = self.out_range.stop + else: + left_current_index = self.out_range.stop + 1 + right_current_index = self.in_range.stop + 1 left = self.interval[1] - j = self.edges_in_index_range[1] - k = self.edges_out_index_range[1] - while k < M and edges_right[out_order[k]] == left: - k += 1 - while j < M and edges_left[in_order[j]] == left: + j = right_current_index + self.out_range.start = j + while j < M and right_coords[right_order[j]] == left: j += 1 + self.out_range.stop = j + self.out_range.order = right_order - self.edges_in_index_range = (self.edges_in_index_range[1], j) - self.edges_out_index_range = (self.edges_out_index_range[1], k) - - right = self.ts.sequence_length - if j < M: - right = min(right, edges_left[in_order[j]]) - if k < M: - right = min(right, edges_right[out_order[k]]) - - self.interval = (left, right) - return j < M or left < self.ts.sequence_length + j = left_current_index + self.in_range.start = j + while j < M and left_coords[left_order[j]] == left: + j += 1 + self.in_range.stop = j + self.in_range.order = left_order + + self.direction = FORWARD + self.index += 1 + if self.index == self.ts.num_trees: + self.set_null() + else: + self.interval = (left, breakpoints[self.index + 1]) + return self.index != -1 + + def prev(self): + M = 
self.ts.num_edges + breakpoints = self.ts.breakpoints + right_coords = self.ts.edges_right + right_order = self.ts.indexes_edge_removal_order + left_coords = self.ts.edges_left + left_order = self.ts.indexes_edge_insertion_order + + if self.index == -1: + self.index = self.ts.num_trees + self.interval = (self.ts.sequence_length, self.interval[1]) + self.in_range.stop = M - 1 + self.out_range.stop = M - 1 + self.direction = REVERSE + + if self.direction == REVERSE: + left_current_index = self.out_range.stop + right_current_index = self.in_range.stop + else: + left_current_index = self.in_range.stop - 1 + right_current_index = self.out_range.stop - 1 + + right = self.interval[0] + + j = left_current_index + self.out_range.start = j + while j >= 0 and left_coords[left_order[j]] == right: + j -= 1 + self.out_range.stop = j + self.out_range.order = left_order + + j = right_current_index + self.in_range.start = j + while j >= 0 and right_coords[right_order[j]] == right: + j -= 1 + self.in_range.stop = j + self.in_range.order = right_order + + self.direction = REVERSE + self.index -= 1 + if self.index == -1: + self.set_null() + else: + self.interval = (breakpoints[self.index], right) + return self.index != -1 def numba_tree_sequence(ts): return NumbaTreeSequence( + num_trees=ts.num_trees, num_edges=ts.num_edges, sequence_length=ts.sequence_length, edges_left=ts.edges_left, From 89d45132c84a8ee61d36558ac148fcf9f536b690 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Tue, 1 Jul 2025 14:13:34 +0100 Subject: [PATCH 6/9] More tests --- python/tests/test_jit.py | 90 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 6 deletions(-) diff --git a/python/tests/test_jit.py b/python/tests/test_jit.py index b422377a69..ea9629fba3 100644 --- a/python/tests/test_jit.py +++ b/python/tests/test_jit.py @@ -63,9 +63,7 @@ def test_correct_trees_backwards(ts): def test_using_from_jit_function(): - """ - Test that we can use the numba jit function from the tskit.jit 
module. - """ + # Test we can use from a numba jitted function import tskit.jit.numba as jit_numba ts = msprime.sim_ancestry( @@ -110,9 +108,6 @@ def coalescent_nodes_python(ts): def test_numba_tree_sequence_properties(ts_fixture): - """ - Test that NumbaTreeSequence properties have correct contents and dtypes. - """ ts = ts_fixture import tskit.jit.numba as jit_numba @@ -159,3 +154,86 @@ def test_numba_tree_sequence_properties(ts_fixture): assert numba_ts.indexes_edge_removal_order.dtype == np.int32 assert numba_ts.breakpoints.dtype == np.float64 np.testing.assert_array_equal(numba_ts.breakpoints, ts.breakpoints(as_array=True)) + + +def test_numba_edge_range(): + import tskit.jit.numba as jit_numba + + order = np.array([1, 3, 2, 0], dtype=np.int32) + edge_range = jit_numba.NumbaEdgeRange(start=1, stop=3, order=order) + + assert edge_range.start == 1 + assert edge_range.stop == 3 + np.testing.assert_array_equal(edge_range.order, order) + + +def test_numba_tree_position_set_null(ts_fixture): + import tskit.jit.numba as jit_numba + + ts = msprime.sim_ancestry( + samples=5, sequence_length=10, recombination_rate=0.1, random_seed=42 + ) + numba_ts = jit_numba.numba_tree_sequence(ts_fixture) + tree_pos = numba_ts.tree_position() + + # Move to a valid position first + tree_pos.next() + initial_interval = tree_pos.interval + assert tree_pos.index != -1 + assert initial_interval != (0, 0) + + # Test set_null + tree_pos.set_null() + assert tree_pos.index == -1 + assert tree_pos.interval == (0, 0) + + +def test_numba_tree_position_constants(ts_fixture): + import tskit.jit.numba as jit_numba + + ts = msprime.sim_ancestry( + samples=5, sequence_length=10, recombination_rate=0.1, random_seed=42 + ) + numba_ts = jit_numba.numba_tree_sequence(ts_fixture) + tree_pos = numba_ts.tree_position() + + # Initial direction should be 0 + assert tree_pos.direction == 0 + + # After next(), direction should be FORWARD + tree_pos.next() + assert tree_pos.direction == jit_numba.FORWARD + 
assert tree_pos.direction == 1 + + # After prev(), direction should be REVERSE + tree_pos.prev() + assert tree_pos.direction == jit_numba.REVERSE + assert tree_pos.direction == -1 + + +def test_numba_tree_position_edge_cases(): + import tskit.jit.numba as jit_numba + + # Test with empty tree sequence + tables = tskit.TableCollection(sequence_length=1.0) + empty_ts = tables.tree_sequence() + numba_ts = jit_numba.numba_tree_sequence(empty_ts) + tree_pos = numba_ts.tree_position() + + # Should have exactly one tree + assert tree_pos.next() + assert tree_pos.index == 0 + assert tree_pos.interval == (0.0, 1.0) + assert not tree_pos.next() # No more trees + assert tree_pos.index == -1 + + # Test with single tree (with edges) + ts = msprime.sim_ancestry(samples=2, random_seed=42) # No recombination + numba_ts = jit_numba.numba_tree_sequence(ts) + tree_pos = numba_ts.tree_position() + + # Should have exactly one tree + assert tree_pos.next() + assert tree_pos.index == 0 + assert not tree_pos.next() # No more trees + assert tree_pos.index == -1 From e5e152a5bf0e6e4a25bd31ba997401839320b8bd Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Wed, 2 Jul 2025 11:59:31 +0100 Subject: [PATCH 7/9] Attempt to get coverage --- .github/workflows/tests.yml | 4 ++-- python/tests/test_jit.py | 26 ++++++++++---------------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 30b1c1b5f3..391edd900a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -121,7 +121,7 @@ jobs: source ~/.profile conda activate anaconda-client-env if [[ "${{ matrix.os }}" == "windows-latest" ]]; then - python -m pytest -x --skip-slow --cov=tskit --cov-report=xml --cov-branch -n2 --durations=20 tests + python -m pytest -x --skip-slow --cov=tskit --cov=tskit.jit --cov-report=xml --cov-branch -n2 --durations=20 tests else python -m pytest -x --cov=tskit --cov-report=xml --cov-branch -n2 --durations=20 tests fi @@ 
-170,7 +170,7 @@ jobs: - name: Run tests with numpy 1.x working-directory: python run: | - python -m pytest -x --cov=tskit --cov-report=xml --cov-branch -n2 tests/test_lowlevel.py tests/test_highlevel.py + python -m pytest -x --cov=tskit --cov=tskit.jit --cov-report=xml --cov-branch -n2 tests/test_lowlevel.py tests/test_highlevel.py - name: Upload coverage to Codecov uses: codecov/codecov-action@v5.4.0 diff --git a/python/tests/test_jit.py b/python/tests/test_jit.py index ea9629fba3..52c49b26ef 100644 --- a/python/tests/test_jit.py +++ b/python/tests/test_jit.py @@ -161,7 +161,7 @@ def test_numba_edge_range(): order = np.array([1, 3, 2, 0], dtype=np.int32) edge_range = jit_numba.NumbaEdgeRange(start=1, stop=3, order=order) - + assert edge_range.start == 1 assert edge_range.stop == 3 np.testing.assert_array_equal(edge_range.order, order) @@ -170,18 +170,15 @@ def test_numba_edge_range(): def test_numba_tree_position_set_null(ts_fixture): import tskit.jit.numba as jit_numba - ts = msprime.sim_ancestry( - samples=5, sequence_length=10, recombination_rate=0.1, random_seed=42 - ) numba_ts = jit_numba.numba_tree_sequence(ts_fixture) tree_pos = numba_ts.tree_position() - + # Move to a valid position first tree_pos.next() initial_interval = tree_pos.interval assert tree_pos.index != -1 assert initial_interval != (0, 0) - + # Test set_null tree_pos.set_null() assert tree_pos.index == -1 @@ -191,21 +188,18 @@ def test_numba_tree_position_set_null(ts_fixture): def test_numba_tree_position_constants(ts_fixture): import tskit.jit.numba as jit_numba - ts = msprime.sim_ancestry( - samples=5, sequence_length=10, recombination_rate=0.1, random_seed=42 - ) numba_ts = jit_numba.numba_tree_sequence(ts_fixture) tree_pos = numba_ts.tree_position() - + # Initial direction should be 0 assert tree_pos.direction == 0 - + # After next(), direction should be FORWARD tree_pos.next() assert tree_pos.direction == jit_numba.FORWARD assert tree_pos.direction == 1 - - # After prev(), direction 
should be REVERSE + + # After prev(), direction should be REVERSE tree_pos.prev() assert tree_pos.direction == jit_numba.REVERSE assert tree_pos.direction == -1 @@ -219,19 +213,19 @@ def test_numba_tree_position_edge_cases(): empty_ts = tables.tree_sequence() numba_ts = jit_numba.numba_tree_sequence(empty_ts) tree_pos = numba_ts.tree_position() - + # Should have exactly one tree assert tree_pos.next() assert tree_pos.index == 0 assert tree_pos.interval == (0.0, 1.0) assert not tree_pos.next() # No more trees assert tree_pos.index == -1 - + # Test with single tree (with edges) ts = msprime.sim_ancestry(samples=2, random_seed=42) # No recombination numba_ts = jit_numba.numba_tree_sequence(ts) tree_pos = numba_ts.tree_position() - + # Should have exactly one tree assert tree_pos.next() assert tree_pos.index == 0 From 57522649d0d27ce37ece1b4e339aec62fe0a4017 Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Wed, 2 Jul 2025 16:36:03 +0100 Subject: [PATCH 8/9] Docs first pass --- docs/_toc.yml | 1 + docs/numba.md | 150 ++++++++++++++++++++++++++++++ python/tskit/jit/numba.py | 189 +++++++++++++++++++++++++++++++++++++- 3 files changed, 338 insertions(+), 2 deletions(-) create mode 100644 docs/numba.md diff --git a/docs/_toc.yml b/docs/_toc.yml index 168e7e8dd8..fb1a70735c 100644 --- a/docs/_toc.yml +++ b/docs/_toc.yml @@ -20,6 +20,7 @@ parts: - caption: Interfaces chapters: - file: python-api + - file: numba - file: c-api - file: cli - file: file-formats diff --git a/docs/numba.md b/docs/numba.md new file mode 100644 index 0000000000..8bd7dde473 --- /dev/null +++ b/docs/numba.md @@ -0,0 +1,150 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.12 + jupytext_version: 1.9.1 +kernelspec: + display_name: Python 3 + language: python + name: python3 +--- + +```{currentmodule} tskit.jit.numba +``` + +(sec_numba)= + +# Numba Integration + +The `tskit.jit.numba` module provides classes for working with tree sequences +from 
[Numba](https://numba.pydata.org/) jit-compiled Python code. Such code can run +up to hundreds of times faster than normal Python, yet avoids the difficulties of writing +C or other low-level code. + +:::{note} +Numba is not a direct dependency of tskit, so will not be available unless installed: + +```bash +pip install numba +``` + +or + +```bash +conda install numba +``` +::: + +## Overview + +The numba integration provides: + +- **{class}`NumbaTreeSequence`**: A Numba-compatible representation of tree sequence data +- **{class}`NumbaTreePosition`**: An class for efficient tree traversal +- **{class}`NumbaEdgeRange`**: Container class for edge ranges during traversal + +These classes are designed to work within Numba's `@njit` decorated functions, +allowing you to write high-performance tree sequence analysis code. + +## Basic Usage + +The ``tskit.jit.numba`` module is not imported with normal `tskit` so must be imported explicitly: +```{code-cell} python +import tskit +import tskit.jit.numba as tskit_numba +``` + +Normal third-party classes such as {class}`tskit.TreeSequence` can't be used in `numba.njit` compiled +functions so the {class}`tskit.TreeSequence` must be wrapped in a {class}`NumbaTreeSequence` by +{meth}`numba_tree_sequence`. This must be done outside `njit` code: + +```{code-cell} python +import msprime + +ts = msprime.sim_ancestry( + samples=50000, + sequence_length=100000, + recombination_rate=0.1, + random_seed=42 +) +numba_ts = tskit_numba.numba_tree_sequence(ts) +print(type(numba_ts)) +``` + +## Tree Traversal + +Tree traversal can be performed using the {class}`NumbaTreePosition` class. +This class provides `next()` and `prev()` methods for forward and backward iteration through the trees in a tree sequence. It's `in_range` and `out_range` attributes provide the edges that must be added or removed to form the current +tree from the previous tree.
+ +A `NumbaTreePosition` instance can be obtained from a `NumbaTreeSequence` using the `tree_position()` method. Its initial state is a "null" tree outside the range of the tree sequence; the first call to `next()` or `prev()` will move to the first or last tree of the tree sequence, respectively. After that, the `in_range` and `out_range` attributes will provide the edges that must be added or removed to form the current tree from the previous tree. For example, after a call to `next()`, +`in_range.order[in_range.start:in_range.stop]` will give the edge ids that are new in the current tree, and `out_range.order[out_range.start:out_range.stop]` will give the edge ids that are no longer present in the current tree. + +As a simple example we can calculate the number of edges in each tree in a tree sequence: + +```{code-cell} python +import numba + +@numba.njit +def edges_per_tree(numba_ts): + tree_pos = numba_ts.tree_position() + current_num_edges = 0 + num_edges = [] + + # Traverse trees forward + while tree_pos.next(): + # Access current tree information + in_range = tree_pos.in_range + out_range = tree_pos.out_range + + current_num_edges -= (out_range.stop - out_range.start) + current_num_edges += (in_range.stop - in_range.start) + num_edges.append(current_num_edges) + return num_edges +``` + +```{code-cell} python +:tags: [hide-cell] +# Warm up the JIT compiler +edges = edges_per_tree(numba_ts) +``` + + +```{code-cell} python +import time + +t = time.time() +jit_num_edges = edges_per_tree(numba_ts) +print(f"JIT Time taken: {time.time() - t:.4f} seconds") +``` + +Doing the same thing with the normal `tskit` API would be much slower: + +```{code-cell} python +t = time.time() +python_num_edges = [] +for tree in ts.trees(): + python_num_edges.append(tree.num_edges) +print(f"Normal Time taken: {time.time() - t:.4f} seconds") + +assert jit_num_edges == python_num_edges, "JIT and normal results do not match!" +``` + +## API Reference + +```{eval-rst} +.. currentmodule:: tskit.jit.numba + +..
autofunction:: numba_tree_sequence + +.. autoclass:: NumbaTreeSequence + :members: + +.. autoclass:: NumbaTreePosition + :members: + +.. autoclass:: NumbaEdgeRange + :members: +``` \ No newline at end of file diff --git a/python/tskit/jit/numba.py b/python/tskit/jit/numba.py index bc6b238126..8ee0a7c91c 100644 --- a/python/tskit/jit/numba.py +++ b/python/tskit/jit/numba.py @@ -9,8 +9,8 @@ ) -FORWARD = 1 -REVERSE = -1 +FORWARD = 1 #: Direction constant for forward tree traversal +REVERSE = -1 #: Direction constant for reverse tree traversal tree_sequence_spec = [ @@ -39,6 +39,60 @@ @numba.experimental.jitclass(tree_sequence_spec) class NumbaTreeSequence: + """ + A Numba-compatible representation of a tree sequence. + + This class provides access a tree sequence class that can be used + from within Numba "njit" compiled functions, as it is a Numba + "jitclass". :meth:`numba_tree_sequence` should be used to + create this class from a :class:`tskit.TreeSequence` object, + before it is passed to a Numba function. + + Attributes + ---------- + num_trees : int32 + Number of trees in the tree sequence. + num_edges : int32 + Number of edges in the tree sequence. + sequence_length : float64 + Total sequence length of the tree sequence. + edges_left : float64[] + Left coordinates of edges. + edges_right : float64[] + Right coordinates of edges. + edges_parent : int32[] + Parent node IDs for each edge. + edges_child : int32[] + Child node IDs for each edge. + nodes_time : float64[] + Time values for each node. + nodes_flags : uint32[] + Flag values for each node. + nodes_population : int32[] + Population IDs for each node. + nodes_individual : int32[] + Individual IDs for each node. + individuals_flags : uint32[] + Flag values for each individual. + sites_position : float64[] + Positions of sites along the sequence. + mutations_site : int32[] + Site IDs for each mutation. + mutations_node : int32[] + Node IDs for each mutation. 
+ mutations_parent : int32[] + Parent mutation IDs. + mutations_time : float64[] + Time values for each mutation. + breakpoints : float64[] + Genomic positions where trees change. + indexes_edge_insertion_order : int32[] + Order in which edges are inserted during tree building. + indexes_edge_removal_order : int32[] + Order in which edges are removed during tree building. + + """ + def __init__( self, num_trees, @@ -84,6 +138,22 @@ def __init__( self.breakpoints = breakpoints def tree_position(self): + """ + Create a :class:`NumbaTreePosition` for traversing this tree sequence. + + Returns + ------- + NumbaTreePosition + A new tree position initialized to the null tree. + Use next() or prev() to move to actual tree positions. + + Examples + -------- + >>> tree_pos = numba_ts.tree_position() + >>> while tree_pos.next(): + ... # Process current tree at tree_pos.index + ... print(f"Tree {tree_pos.index}: {tree_pos.interval}") + """ return NumbaTreePosition(self) @@ -96,6 +166,25 @@ def tree_position(self): @numba.experimental.jitclass(edge_range_spec) class NumbaEdgeRange: + """ + Represents a range of edges during tree traversal. + + This class encapsulates information about a contiguous range of edges + that are either being removed or added to step from one tree to another + The ``start`` and ``stop`` indices, when applied to the order array, + define the ids of edges to process. + + Attributes + ---------- + start : int32 + Starting index of the edge range (inclusive). + stop : int32 + Stopping index of the edge range (exclusive). + order : int32[] + Array containing edge IDs in the order they should be processed. + The edge ids in this range are order[start:stop]. + """ + def __init__(self, start, stop, order): self.start = start self.stop = stop @@ -114,6 +203,39 @@ def __init__(self, start, stop, order): @numba.experimental.jitclass(tree_position_spec) class NumbaTreePosition: + """ + Traverse trees in a numba compatible tree sequence. 
+ + This class provides efficient forward and backward iteration through + the trees in a tree sequence. It tracks the current position and interval, + providing edge changes between trees. + + + Attributes + ---------- + ts : NumbaTreeSequence + Reference to the tree sequence being traversed. + index : int32 + Current tree index. -1 indicates no current tree (null state). + direction : int32 + Traversal direction: tskit.FORWARD or tskit.REVERSE. tskit.NULL if uninitialised. + interval : tuple of float64 + Genomic interval (left, right) covered by the current tree. + in_range : NumbaEdgeRange + Edges being added to form this current tree, relative to the last state + out_range : NumbaEdgeRange + Edges being removed to form this current tree, relative to the last state + + Example + -------- + >>> tree_pos = numba_ts.tree_position() + >>> num_edges + >>> while tree_pos.next(): + num_edges += (tree_pos.in_range.stop - tree_pos.in_range.start) + num_edges -= (tree_pos.out_range.stop - tree_pos.out_range.start) + print(f"Tree {tree_pos.index}: {num_edges} edges") + """ + def __init__(self, ts): self.ts = ts self.index = -1 @@ -123,10 +245,34 @@ def __init__(self, ts): self.out_range = NumbaEdgeRange(0, 0, np.zeros(0, dtype=numba.int32)) def set_null(self): + """ + Reset the tree position to null state. + """ self.index = -1 self.interval = (0, 0) def next(self): # noqa: A003 + """ + Move to the next tree in forward direction. + + Updates the tree position to the next tree in the sequence, + computing the edges that need to be added and removed to + transform from the previous tree to the current tree, storing + them in self.in_range and self.out_range. + + Returns + ------- + bool + True if successfully moved to next tree, False if the end + of the tree sequence is reached. + When False is returned, the iterator is in null state (index=-1). + + Notes + ----- + On the first call, this initializes the iterator and moves to tree 0. 
+ The in_range and out_range attributes are updated to reflect the + edge changes needed for the current tree. + """ M = self.ts.num_edges breakpoints = self.ts.breakpoints left_coords = self.ts.edges_left @@ -172,6 +318,28 @@ def next(self): # noqa: A003 return self.index != -1 def prev(self): + """ + Move to the previous tree in reverse direction. + + Updates the tree position to the previous tree in the sequence, + computing the edges that need to be added and removed to + transform from the next tree to the current tree, storing them + in self.in_range and self.out_range + + Returns + ------- + bool + True if successfully moved to previous tree, False if the beginning + of the tree sequence is reached. + When False is returned, the iterator is in null state (index=-1). + + Notes + ----- + On the first call, this initializes the iterator and moves to the most + rightward tree. + The in_range and out_range attributes are updated to reflect the + edge changes needed for the current tree when traversing backward. + """ M = self.ts.num_edges breakpoints = self.ts.breakpoints right_coords = self.ts.edges_right @@ -219,6 +387,23 @@ def prev(self): def numba_tree_sequence(ts): + """ + Convert a TreeSequence to a Numba-compatible format. + + Creates a NumbaTreeSequence object that can be used within + Numba-compiled functions. + + Parameters + ---------- + ts : tskit.TreeSequence + The tree sequence to convert. + + Returns + ------- + NumbaTreeSequence + A Numba-compatible representation of the input tree sequence. + Contains all necessary data arrays and metadata for tree traversal. 
+ """ return NumbaTreeSequence( num_trees=ts.num_trees, num_edges=ts.num_edges, From 3fbac2707b7b8d330b9d3cded29da8ac9a7d786d Mon Sep 17 00:00:00 2001 From: Ben Jeffery Date: Thu, 3 Jul 2025 10:45:14 +0100 Subject: [PATCH 9/9] Remove mentions of traversal --- docs/numba.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/numba.md b/docs/numba.md index 8bd7dde473..d3813f3cf3 100644 --- a/docs/numba.md +++ b/docs/numba.md @@ -42,8 +42,8 @@ conda install numba The numba integration provides: - **{class}`NumbaTreeSequence`**: A Numba-compatible representation of tree sequence data -- **{class}`NumbaTreePosition`**: An class for efficient tree traversal -- **{class}`NumbaEdgeRange`**: Container class for edge ranges during traversal +- **{class}`NumbaTreePosition`**: A class for efficient tree iteration +- **{class}`NumbaEdgeRange`**: Container class for edge ranges during iteration These classes are designed to work within Numba's `@njit` decorated functions, allowing you to write high-performance tree sequence analysis code. @@ -73,9 +73,9 @@ numba_ts = tskit_numba.numba_tree_sequence(ts) print(type(numba_ts)) ``` -## Tree Traversal +## Tree Iteration -Tree traversal can be performed using the {class}`NumbaTreePosition` class. +Tree iteration can be performed using the {class}`NumbaTreePosition` class. This class provides `next()` and `prev()` methods for forward and backward iteration through the trees in a tree sequence. It's `in_range` and `out_range` attributes provide the edges that must be added or removed to form the current tree from the previous tree. @@ -93,7 +93,7 @@ def edges_per_tree(numba_ts): current_num_edges = 0 num_edges = [] - # Traverse trees forward + # Move forward through the trees while tree_pos.next(): # Access current tree information in_range = tree_pos.in_range