diff --git a/python/requirements/CI-complete/requirements.txt b/python/requirements/CI-complete/requirements.txt index 8920fa33ca..99465a9c5c 100644 --- a/python/requirements/CI-complete/requirements.txt +++ b/python/requirements/CI-complete/requirements.txt @@ -6,6 +6,7 @@ lshmm==0.0.8 msgpack==1.1.0 msprime==1.3.3 networkx==3.2.1 +numba==0.61.2 portion==2.6.0 pytest==8.3.5 pytest-cov==6.0.0 diff --git a/python/requirements/CI-tests-pip/requirements.txt b/python/requirements/CI-tests-pip/requirements.txt index 2e162f68d7..a72ec4227c 100644 --- a/python/requirements/CI-tests-pip/requirements.txt +++ b/python/requirements/CI-tests-pip/requirements.txt @@ -11,4 +11,5 @@ networkx==3.2.1 msgpack==1.1.0 newick==1.10.0 kastore==0.3.3 -jsonschema==4.23.0 \ No newline at end of file +jsonschema==4.23.0 +numba>=0.60.0 \ No newline at end of file diff --git a/python/tests/test_balance_metrics.py b/python/tests/test_balance_metrics.py index eed20477b2..f554ec745f 100644 --- a/python/tests/test_balance_metrics.py +++ b/python/tests/test_balance_metrics.py @@ -29,7 +29,7 @@ import tests import tskit -from tests.test_highlevel import get_example_tree_sequences +from tests.tsutil import get_example_tree_sequences # ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when # we can remove this. diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index ac66f43f1e..892c48c347 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -32,7 +32,7 @@ import tskit from tests import tsutil -from tests.test_highlevel import get_example_tree_sequences +from tests.tsutil import get_example_tree_sequences # ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when # we can remove this. 
diff --git a/python/tests/test_extend_haplotypes.py b/python/tests/test_extend_haplotypes.py index 0d4e40a7f8..51ef7094e7 100644 --- a/python/tests/test_extend_haplotypes.py +++ b/python/tests/test_extend_haplotypes.py @@ -6,7 +6,7 @@ import tests.test_wright_fisher as wf import tskit from tests import tsutil -from tests.test_highlevel import get_example_tree_sequences +from tests.tsutil import get_example_tree_sequences # ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when # we can remove this. diff --git a/python/tests/test_genotypes.py b/python/tests/test_genotypes.py index e2a0920376..036ae8db04 100644 --- a/python/tests/test_genotypes.py +++ b/python/tests/test_genotypes.py @@ -37,7 +37,7 @@ import tests.test_wright_fisher as wf import tests.tsutil as tsutil import tskit -from tests.test_highlevel import get_example_tree_sequences +from tests.tsutil import get_example_tree_sequences from tskit import exceptions from tskit.genotypes import allele_remap diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 7d98d2184a..47477865bc 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -26,7 +26,6 @@ import collections import dataclasses import decimal -import functools import inspect import io import itertools @@ -142,270 +141,6 @@ def traversal_timedesc(tree, root=None): } -def insert_uniform_mutations(tables, num_mutations, nodes): - """ - Returns n evenly mutations over the specified list of nodes. - """ - for j in range(num_mutations): - tables.sites.add_row( - position=j * (tables.sequence_length / num_mutations), - ancestral_state="0", - metadata=json.dumps({"index": j}).encode(), - ) - tables.mutations.add_row( - site=j, - derived_state="1", - node=nodes[j % len(nodes)], - metadata=json.dumps({"index": j}).encode(), - ) - - -def get_table_collection_copy(tables, sequence_length): - """ - Returns a copy of the specified table collection with the specified - sequence length. 
- """ - table_dict = tables.asdict() - table_dict["sequence_length"] = sequence_length - return tskit.TableCollection.fromdict(table_dict) - - -def insert_gap(ts, position, length): - """ - Inserts a gap of the specified size into the specified tree sequence. - This involves: (1) breaking all edges that intersect with this point; - and (2) shifting all coordinates greater than this value up by the - gap length. - """ - new_edges = [] - for e in ts.edges(): - if e.left < position < e.right: - new_edges.append([e.left, position, e.parent, e.child]) - new_edges.append([position, e.right, e.parent, e.child]) - else: - new_edges.append([e.left, e.right, e.parent, e.child]) - - # Now shift up all coordinates. - for e in new_edges: - # Left coordinates == position get shifted - if e[0] >= position: - e[0] += length - # Right coordinates == position do not get shifted - if e[1] > position: - e[1] += length - tables = ts.dump_tables() - L = ts.sequence_length + length - tables = get_table_collection_copy(tables, L) - tables.edges.clear() - tables.sites.clear() - tables.mutations.clear() - for left, right, parent, child in new_edges: - tables.edges.add_row(left, right, parent, child) - tables.sort() - # Throw in a bunch of mutations over the whole sequence on the samples. - insert_uniform_mutations(tables, 100, list(ts.samples())) - return tables.tree_sequence() - - -@functools.lru_cache -def get_gap_examples(custom_max=None): - """ - Returns example tree sequences that contain gaps within the list of - edges. 
- """ - ret = [] - if custom_max is None: - n_list = [20, 10] - else: - n_list = [custom_max, custom_max // 2] - - ts = msprime.simulate(n_list[0], random_seed=56, recombination_rate=1) - - assert ts.num_trees > 1 - - gap = 0.0125 - for x in [0, 0.1, 0.5, 0.75]: - ts = insert_gap(ts, x, gap) - found = False - for t in ts.trees(): - if t.interval.left == x: - assert t.interval.right == x + gap - assert len(t.parent_dict) == 0 - found = True - assert found - ret.append((f"gap_{x}", ts)) - # Give an example with a gap at the end. - ts = msprime.simulate(n_list[1], random_seed=5, recombination_rate=1) - tables = get_table_collection_copy(ts.dump_tables(), 2) - tables.sites.clear() - tables.mutations.clear() - insert_uniform_mutations(tables, 100, list(ts.samples())) - ret.append(("gap_at_end", tables.tree_sequence())) - return ret - - -@functools.lru_cache -def get_internal_samples_examples(): - """ - Returns example tree sequences with internal samples. - """ - ret = [] - n = 5 - ts = msprime.simulate(n, random_seed=10, mutation_rate=5) - assert ts.num_mutations > 0 - tables = ts.dump_tables() - nodes = tables.nodes - flags = nodes.flags - # Set all nodes to be samples. - flags[:] = tskit.NODE_IS_SAMPLE - nodes.flags = flags - ret.append(("all_nodes_samples", tables.tree_sequence())) - - # Set just internal nodes to be samples. - flags[:] = 0 - flags[n:] = tskit.NODE_IS_SAMPLE - nodes.flags = flags - ret.append(("internal_nodes_samples", tables.tree_sequence())) - - # Set a mixture of internal and leaf samples. - flags[:] = 0 - flags[n // 2 : n + n // 2] = tskit.NODE_IS_SAMPLE - nodes.flags = flags - ret.append(("mixed_internal_leaf_samples", tables.tree_sequence())) - return ret - - -@functools.lru_cache -def get_decapitated_examples(custom_max=None): - """ - Returns example tree sequences in which the oldest edges have been removed. 
- """ - ret = [] - if custom_max is None: - n_list = [10, 20] - else: - n_list = [custom_max // 2, custom_max] - ts = msprime.simulate(n_list[0], random_seed=1234) - # yield ts.decapitate(ts.tables.nodes.time[-1] / 2) - ts = msprime.simulate(n_list[1], recombination_rate=1, random_seed=1234) - assert ts.num_trees > 2 - ret.append(("decapitate_recomb", ts.decapitate(ts.tables.nodes.time[-1] / 4))) - return ret - - -def get_bottleneck_examples(custom_max=None): - """ - Returns an iterator of example tree sequences with nonbinary trees. - """ - bottlenecks = [ - msprime.SimpleBottleneck(0.01, 0, proportion=0.05), - msprime.SimpleBottleneck(0.02, 0, proportion=0.25), - msprime.SimpleBottleneck(0.03, 0, proportion=1), - ] - if custom_max is None: - n_list = [3, 10, 100] - else: - n_list = [i * custom_max // 3 for i in range(1, 4)] - for n in n_list: - ts = msprime.simulate( - n, - length=100, - recombination_rate=1, - demographic_events=bottlenecks, - random_seed=n, - ) - yield (f"bottleneck_n={n}", ts) - - -def get_back_mutation_examples(): - """ - Returns an iterator of example tree sequences with nonbinary trees. 
- """ - ts = msprime.simulate(10, random_seed=1) - for j in [1, 2, 3]: - yield tsutil.insert_branch_mutations(ts, mutations_per_branch=j) - for ts in get_bottleneck_examples(): - yield tsutil.insert_branch_mutations(ts) - - -def make_example_tree_sequences(custom_max=None): - yield from get_decapitated_examples(custom_max=custom_max) - yield from get_gap_examples(custom_max=custom_max) - yield from get_internal_samples_examples() - seed = 1 - if custom_max is None: - n_list = [2, 3, 10, 100] - else: - n_list = [i * custom_max // 4 for i in range(1, 5)] - for n in n_list: - for m in [1, 2, 32]: - for rho in [0, 0.1, 0.5]: - recomb_map = msprime.RecombinationMap.uniform_map(m, rho, num_loci=m) - ts = msprime.simulate( - recombination_map=recomb_map, - mutation_rate=0.1, - random_seed=seed, - population_configurations=[ - msprime.PopulationConfiguration(n), - msprime.PopulationConfiguration(0), - ], - migration_matrix=[[0, 1], [1, 0]], - ) - ts = tsutil.insert_random_ploidy_individuals(ts, 4, seed=seed) - yield ( - f"n={n}_m={m}_rho={rho}", - tsutil.add_random_metadata(ts, seed=seed), - ) - seed += 1 - for name, ts in get_bottleneck_examples(custom_max=custom_max): - yield ( - f"{name}_mutated", - msprime.mutate( - ts, - rate=0.1, - random_seed=seed, - model=msprime.InfiniteSites(msprime.NUCLEOTIDES), - ), - ) - ts = tskit.Tree.generate_balanced(8).tree_sequence - yield ("rev_node_order", ts.subset(np.arange(ts.num_nodes - 1, -1, -1))) - ts = msprime.sim_ancestry( - 8, sequence_length=40, recombination_rate=0.1, random_seed=seed - ) - tables = ts.dump_tables() - tables.populations.metadata_schema = tskit.MetadataSchema(None) - ts = tables.tree_sequence() - assert ts.num_trees > 1 - yield ( - "back_mutations", - tsutil.insert_branch_mutations(ts, mutations_per_branch=2), - ) - ts = tsutil.insert_multichar_mutations(ts) - yield ("multichar", ts) - yield ("multichar_no_metadata", tsutil.add_random_metadata(ts)) - tables = ts.dump_tables() - tables.nodes.flags = 
np.zeros_like(tables.nodes.flags) - yield ("no_samples", tables.tree_sequence()) # no samples - tables = ts.dump_tables() - tables.edges.clear() - yield ("empty_tree", tables.tree_sequence()) # empty tree - yield ( - "empty_ts", - tskit.TableCollection(sequence_length=1).tree_sequence(), - ) # empty tree seq - yield ("all_fields", tsutil.all_fields_ts()) - - -_examples = tuple(make_example_tree_sequences(custom_max=None)) - - -def get_example_tree_sequences(pytest_params=True, custom_max=None): - if pytest_params: - return [pytest.param(ts, id=name) for name, ts in _examples] - else: - return [ts for _, ts in _examples] - - def simple_get_pairwise_diversity(haplotypes): """ Returns the value of pi for the specified haplotypes. @@ -517,7 +252,7 @@ def test_returned_types(self, order): for u in lst: assert isinstance(u, int) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) @pytest.mark.parametrize("order", list(traversal_map.keys())) def test_traversals_virtual_root(self, ts, order): tree = ts.first() @@ -526,7 +261,7 @@ def test_traversals_virtual_root(self, ts, order): assert tree.virtual_root in node_list1 assert node_list1 == node_list2 - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) @pytest.mark.parametrize("order", list(traversal_map.keys())) def test_traversals(self, ts, order): tree = next(ts.trees()) @@ -1419,7 +1154,7 @@ class TestTreeSequence(HighLevelTestCase): Tests for the tree sequence object. 
""" - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_row_getter(self, ts): for table_name, table in ts.tables_dict.items(): sequence = getattr(ts, table_name)() @@ -1446,7 +1181,7 @@ def test_bad_row_getter(self, index, simple_degree2_ts_fixture): with pytest.raises(TypeError, match=match): element_accessor(index) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_discrete_genome(self, ts): def is_discrete(a): return np.all(np.floor(a) == a) @@ -1462,7 +1197,7 @@ def is_discrete(a): ) assert ts.discrete_genome == discrete_genome - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_discrete_time(self, ts): def is_discrete(a): return np.all(np.logical_or(np.floor(a) == a, tskit.is_unknown_time(a))) @@ -1475,11 +1210,11 @@ def is_discrete(a): ) assert ts.discrete_time == discrete_time - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_trees(self, ts): self.verify_trees(ts) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_mutations(self, ts): self.verify_mutations(ts) @@ -1504,7 +1239,7 @@ def verify_pairwise_diversity(self, ts): assert not math.isnan(pi1) @pytest.mark.slow - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_pairwise_diversity(self, ts): self.verify_pairwise_diversity(ts) @@ -1514,7 +1249,7 @@ def test_bad_node_iteration_order(self, order): with pytest.raises(ValueError, match="order"): ts.nodes(order=order) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", 
tsutil.get_example_tree_sequences()) def test_node_iteration_order(self, ts): order = [n.id for n in ts.nodes()] assert order == list(range(ts.num_nodes)) @@ -1566,17 +1301,17 @@ def verify_edgesets(self, ts): assert len(squashed) == len(edges) assert edges == squashed - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_edge_ids(self, ts): for index, edge in enumerate(ts.edges()): assert edge.id == index - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_edge_span_property(self, ts): for edge in ts.edges(): assert edge.span == edge.right - edge.left - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_edge_interval_property(self, ts): for edge in ts.edges(): assert edge.interval == (edge.left, edge.right) @@ -1587,14 +1322,14 @@ def test_edge_interval_property(self, ts): def test_edgesets(self): tested = False # We manual loop in this test to test the example tree sequences are working - for ts in get_example_tree_sequences(pytest_params=False): + for ts in tsutil.get_example_tree_sequences(pytest_params=False): # Can't get edgesets with metadata if ts.tables.edges.metadata_schema == tskit.MetadataSchema(None): self.verify_edgesets(ts) tested = True assert tested - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_breakpoints(self, ts): breakpoints = ts.breakpoints(as_array=True) assert breakpoints.shape == (ts.num_trees + 1,) @@ -1622,11 +1357,11 @@ def verify_coalescence_records(self, ts): assert parent.time == record.time assert parent.population == record.population - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def 
test_coalescence_records(self, ts): self.verify_coalescence_records(ts) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_compute_mutation_parent(self, ts): tables = ts.dump_tables() before = tables.mutations.parent[:] @@ -1634,7 +1369,7 @@ def test_compute_mutation_parent(self, ts): parent = ts.tables.mutations.parent assert np.array_equal(parent, before) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_compute_mutation_time(self, ts): tables = ts.dump_tables() python_time = tsutil.compute_mutation_times(ts) @@ -1643,7 +1378,7 @@ def test_compute_mutation_time(self, ts): # Check we have valid times tables.tree_sequence() - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_tracked_samples(self, ts): # Should be empty list by default. for tree in ts.trees(): @@ -1670,7 +1405,7 @@ def test_tracked_samples_is_first_arg(self): tree = next(ts.trees(samples)) assert tree.num_tracked_samples() == 3 - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_deprecated_sample_aliases(self, ts): # Ensure that we get the same results from the various combinations # of leaf_lists, sample_lists etc. 
@@ -1706,7 +1441,7 @@ def verify_samples(self, ts): samples2.append(list(t.samples())) assert samples1 == samples2 - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_samples(self, ts): self.verify_samples(ts) pops = {node.population for node in ts.nodes()} @@ -1722,7 +1457,7 @@ def test_samples(self, ts): with pytest.raises(ValueError): ts.samples(population=0, population_id=0) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_first_last(self, ts): for kwargs in [{}, {"tracked_samples": ts.samples()}]: t1 = ts.first(**kwargs) @@ -1764,7 +1499,7 @@ def test_trees_interface(self): assert t.get_num_tracked_samples(0) == 0 assert list(t.samples(0)) == [0] - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_get_pairwise_diversity(self, ts): with pytest.raises(ValueError, match="at least one element"): ts.get_pairwise_diversity([]) @@ -1782,7 +1517,7 @@ def test_get_pairwise_diversity(self, ts): def test_populations(self): more_than_zero = False - for ts in get_example_tree_sequences(pytest_params=False): + for ts in tsutil.get_example_tree_sequences(pytest_params=False): N = ts.num_populations if N > 0: more_than_zero = True @@ -1796,7 +1531,7 @@ def test_populations(self): def test_individuals(self): more_than_zero = False mapped_to_nodes = False - for ts in get_example_tree_sequences(pytest_params=False): + for ts in tsutil.get_example_tree_sequences(pytest_params=False): ind_node_map = collections.defaultdict(list) for node in ts.nodes(): if node.individual != tskit.NULL: @@ -1819,7 +1554,7 @@ def test_individuals(self): assert more_than_zero assert mapped_to_nodes - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def 
test_get_population(self, ts): # Deprecated interface for ts.node(id).population N = ts.get_num_nodes() @@ -1832,7 +1567,7 @@ def test_get_population(self, ts): for node in range(N): assert ts.get_population(node) == ts.node(node).population - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_get_time(self, ts): # Deprecated interface for ts.node(id).time N = ts.get_num_nodes() @@ -1845,7 +1580,7 @@ def test_get_time(self, ts): for u in range(N): assert ts.get_time(u) == ts.node(u).time - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_max_root_time(self, ts): oldest = None for tree in ts.trees(): @@ -1907,7 +1642,7 @@ def test_deprecated_apis(self): def test_sites(self): some_sites = False - for ts in get_example_tree_sequences(pytest_params=False): + for ts in tsutil.get_example_tree_sequences(pytest_params=False): tables = ts.dump_tables() sites = tables.sites mutations = tables.mutations @@ -1962,7 +1697,7 @@ def verify_mutations(self, ts): for mut, other_mut in zip(mutations, other_mutations): assert mut == other_mut - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_sites_mutations(self, ts): # Check that the mutations iterator returns the correct values. 
self.verify_mutations(ts) @@ -2051,12 +1786,12 @@ def test_migrations(self): assert migration.right == 1 assert 0 <= migration.node < ts.num_nodes - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_len_trees(self, ts): tree_iter = ts.trees() assert len(tree_iter) == ts.num_trees - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_list(self, ts): for kwargs in [{}, {"tracked_samples": ts.samples()}]: tree_list = ts.aslist(**kwargs) @@ -2074,7 +1809,7 @@ def test_list(self, ts): assert t1.num_tracked_samples() == 0 assert t2.num_tracked_samples() == 0 - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_reversed_trees(self, ts): index = ts.num_trees - 1 tree_list = ts.aslist() @@ -2085,7 +1820,7 @@ def test_reversed_trees(self, ts): assert tree.parent_dict == t2.parent_dict index -= 1 - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_at_index(self, ts): for kwargs in [{}, {"tracked_samples": ts.samples()}]: tree_list = ts.aslist(**kwargs) @@ -2100,7 +1835,7 @@ def test_at_index(self, ts): else: assert t2.num_tracked_samples() == 0 - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_at(self, ts): for kwargs in [{}, {"tracked_samples": ts.samples()}]: tree_list = ts.aslist(**kwargs) @@ -2123,7 +1858,7 @@ def test_at(self, ts): else: assert t2.num_tracked_samples() == 0 - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_sequence_iteration(self, ts): for table_name in ts.tables_dict.keys(): sequence = getattr(ts, table_name)() @@ -2157,7 
+1892,7 @@ def test_sequence_iteration(self, ts): if i is not None: assert n.id == 0 - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_load_tables(self, ts): tables = ts.dump_tables() tables.drop_index() @@ -2185,7 +1920,7 @@ def test_load_tables(self, ts): # Tables in tc, and rebuilt assert tskit.TreeSequence.load_tables(tables).dump_tables().has_index() - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_html_repr(self, ts): html = ts._repr_html_() # Parse to check valid @@ -2223,7 +1958,7 @@ def test_html_repr_limit(self, ts_fixture): assert "... and 20 more" in ts._repr_html_() assert "NN..." in ts._repr_html_() - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_str(self, ts): s = str(ts) assert len(s) > 999 @@ -2358,7 +2093,7 @@ def modify(ts, func): assert t1.equals(t2) assert t2.equals(t1) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_tree_node_edges(self, ts): edge_visited = np.zeros(ts.num_edges, dtype=bool) for tree in ts.trees(): @@ -2663,7 +2398,7 @@ def test_arrays_equal_to_tables(self, ts_fixture): ts.indexes_edge_removal_order, tables.indexes.edge_removal_order ) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_impute_unknown_mutations_time(self, ts): # Tests for method='min' imputed_time = ts.impute_unknown_mutations_time(method="min") @@ -2811,13 +2546,13 @@ def verify_tables_api_equality(self, ts): ts.simplify(samples=samples).tables, ignore_timestamps=True ) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def 
test_simplify_tables_equality(self, ts): # Can't simplify edges with metadata if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): self.verify_tables_api_equality(ts) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_simplify_provenance(self, ts): # Can't simplify edges with metadata if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): @@ -2827,7 +2562,7 @@ def test_simplify_provenance(self, ts): # test them independently. A way of getting a random-ish subset of samples # from the pytest param would be useful. @pytest.mark.slow - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_simplify(self, ts): # Can't simplify edges with metadata if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): @@ -2883,7 +2618,7 @@ def test_simplify_migrations_fails(self): with pytest.raises(_tskit.LibraryError): ts.simplify() - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_no_update_sample_flags_no_filter_nodes(self, ts): # Can't simplify edges with metadata if ts.tables.edges.metadata_schema == tskit.MetadataSchema(schema=None): @@ -2981,7 +2716,7 @@ def test_k_mutations(self, k): class TestEdgeDiffs: - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_correct_trees_forward(self, ts): parent = np.full(ts.num_nodes + 1, tskit.NULL, dtype=np.int32) for edge_diff, tree in itertools.zip_longest(ts.edge_diffs(), ts.trees()): @@ -2992,7 +2727,7 @@ def test_correct_trees_forward(self, ts): parent[edge.child] = edge.parent assert_array_equal(parent, tree.parent_array) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", 
tsutil.get_example_tree_sequences()) def test_correct_trees_reverse(self, ts): parent = np.full(ts.num_nodes + 1, tskit.NULL, dtype=np.int32) iterator = itertools.zip_longest( @@ -3032,7 +2767,7 @@ def test_edge_properties(self, direction, simple_degree2_ts_fixture): assert ts.edge(edge.id) == edge assert edge_ids == set(range(ts.num_edges)) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) @pytest.mark.parametrize("direction", [tskit.FORWARD, tskit.REVERSE]) def test_include_terminal(self, ts, direction): edges = set() @@ -3514,7 +3249,7 @@ def convert(v): assert repr(migration.metadata) == splits[6] @pytest.mark.parametrize(("precision", "base64_metadata"), [(2, True), (7, False)]) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_output_format(self, precision, base64_metadata, ts): nodes_file = io.StringIO() edges_file = io.StringIO() @@ -3622,7 +3357,7 @@ def verify_approximate_equality(self, ts1, ts2): check += 1 assert check == ts1.get_num_trees() - @pytest.mark.parametrize("ts1", get_example_tree_sequences()) + @pytest.mark.parametrize("ts1", tsutil.get_example_tree_sequences()) def test_text_record_round_trip(self, ts1): # Can't round trip without the schema if ts1.tables.nodes.metadata_schema == tskit.MetadataSchema(None): @@ -3840,7 +3575,7 @@ def test_ancestors_empty(self): for u in ts.samples(): assert len(list(tree.ancestors(u))) == 0 - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_virtual_root_semantics(self, ts): for tree in ts.trees(): assert math.isinf(tree.time(tree.virtual_root)) @@ -3852,7 +3587,7 @@ def test_virtual_root_semantics(self, ts): def test_root_properties(self): tested = set() - for ts in get_example_tree_sequences(pytest_params=False): + for ts in 
tsutil.get_example_tree_sequences(pytest_params=False): for tree in ts.trees(): if tree.has_single_root: tested.add("single") @@ -3871,7 +3606,7 @@ def test_root_properties(self): assert len(tested) == 3 def test_as_dict_of_dicts(self): - for ts in get_example_tree_sequences(pytest_params=False): + for ts in tsutil.get_example_tree_sequences(pytest_params=False): tree = next(ts.trees()) adj_dod = tree.as_dict_of_dicts() g = nx.DiGraph(adj_dod) @@ -4794,7 +4529,7 @@ def test_seek_0_from_3(self): t2.seek(0) assert_trees_identical(t1, t2) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_seek_mid_null_and_middle(self, ts): breakpoints = ts.breakpoints(as_array=True) mid = breakpoints[:-1] + np.diff(breakpoints) / 2 @@ -4813,7 +4548,7 @@ def test_seek_mid_null_and_middle(self, ts): assert t1.index == t2.index assert np.all(t1.parent_array == t2.parent_array) - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_seek_last_then_prev(self, ts): t1 = tskit.Tree(ts) t1.seek(ts.sequence_length - 0.00001) @@ -4827,7 +4562,7 @@ def test_seek_last_then_prev(self, ts): class TestSeek: - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_new_seek_breakpoints(self, ts): breakpoints = ts.breakpoints(as_array=True) for index, left in enumerate(breakpoints[:-1]): @@ -4835,7 +4570,7 @@ def test_new_seek_breakpoints(self, ts): tree.seek(left) assert tree.index == index - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_new_seek_mid(self, ts): breakpoints = ts.breakpoints(as_array=True) mid = breakpoints[:-1] + np.diff(breakpoints) / 2 @@ -4844,7 +4579,7 @@ def test_new_seek_mid(self, ts): tree.seek(left) assert tree.index == index - 
@pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_same_seek_breakpoints(self, ts): breakpoints = ts.breakpoints(as_array=True) tree = tskit.Tree(ts) @@ -4852,7 +4587,7 @@ def test_same_seek_breakpoints(self, ts): tree.seek(left) assert tree.index == index - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_new_seek_breakpoints_reversed(self, ts): breakpoints = ts.breakpoints(as_array=True) for index, left in reversed(list(enumerate(breakpoints[:-1]))): @@ -4860,7 +4595,7 @@ def test_new_seek_breakpoints_reversed(self, ts): tree.seek(left) assert tree.index == index - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_same_seek_breakpoints_reversed(self, ts): breakpoints = ts.breakpoints(as_array=True) tree = tskit.Tree(ts) @@ -5288,7 +5023,7 @@ def num_lineages_definition(tree, t): class TestNumLineages: - @pytest.mark.parametrize("ts", get_example_tree_sequences()) + @pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences()) def test_tree_midpoint_definition(self, ts): t = 0 if ts.num_nodes > 0: diff --git a/python/tests/test_ibd.py b/python/tests/test_ibd.py index 39210051d2..4b13ad6db9 100644 --- a/python/tests/test_ibd.py +++ b/python/tests/test_ibd.py @@ -10,7 +10,7 @@ import tests.ibd as ibd import tests.test_wright_fisher as wf import tskit -from tests.test_highlevel import get_example_tree_sequences +from tests.tsutil import get_example_tree_sequences """ Tests of IBD finding algorithms. 
diff --git a/python/tests/test_jit.py b/python/tests/test_jit.py
new file mode 100644
index 0000000000..7759c8949c
--- /dev/null
+++ b/python/tests/test_jit.py
@@ -0,0 +1,99 @@
+import itertools
+import sys
+from unittest.mock import patch
+
+import msprime
+import numba
+import numpy as np
+import pytest
+
+import tests.tsutil as tsutil
+
+
+def test_numba_import_error():
+    """
+    Test that importing tskit.jit.numba without numba gives a helpful error.
+    """
+    # Mock numba as not available
+    with patch.dict(sys.modules, {"numba": None}):
+        # Drop any cached copy so that the import below really executes the
+        # module; patch.dict restores sys.modules when the block exits.
+        sys.modules.pop("tskit.jit.numba", None)
+        with pytest.raises(ImportError, match="pip install numba"):
+            import tskit.jit.numba  # noqa: F401
+
+
+@pytest.mark.parametrize("ts", tsutil.get_example_tree_sequences())
+def test_correct_trees_forward(ts):
+    """
+    Test that the numba tree position visits the same intervals and edge
+    insertions/removals as TreeSequence.edge_diffs.
+    """
+    import tskit.jit.numba as jit_numba
+
+    numba_ts = jit_numba.numba_tree_sequence(ts)
+    in_index = ts.indexes_edge_insertion_order
+    out_index = ts.indexes_edge_removal_order
+    tree_pos = numba_ts.tree_position()
+    ts_edge_diffs = ts.edge_diffs()
+    while tree_pos.next():
+        edge_diff = next(ts_edge_diffs)
+        assert edge_diff.interval == tree_pos.interval
+        for edge_in_index, edge in itertools.zip_longest(
+            range(*tree_pos.edges_in_index_range), edge_diff.edges_in
+        ):
+            assert edge.id == in_index[edge_in_index]
+        for edge_out_index, edge in itertools.zip_longest(
+            range(*tree_pos.edges_out_index_range), edge_diff.edges_out
+        ):
+            assert edge.id == out_index[edge_out_index]
+    # Both iterators must finish together, otherwise trailing trees would
+    # go unchecked.
+    assert next(ts_edge_diffs, None) is None
+
+
+def test_using_from_jit_function():
+    """
+    Test that we can use the numba jit function from the tskit.jit module.
+    """
+    import tskit.jit.numba as jit_numba
+
+    ts = msprime.sim_ancestry(
+        samples=10, sequence_length=100, recombination_rate=1, random_seed=42
+    )
+
+    @numba.njit
+    def _coalescent_nodes_numba(numba_ts, num_nodes, edges_parent):
+        is_coalescent = np.zeros(num_nodes, dtype=np.int8)
+        num_children = np.zeros(num_nodes, dtype=np.int64)
+        tree_pos = numba_ts.tree_position()
+        while tree_pos.next():
+            for j in range(*tree_pos.edges_out_index_range):
+                e = numba_ts.indexes_edge_removal_order[j]
+                num_children[edges_parent[e]] -= 1
+            for j in range(*tree_pos.edges_in_index_range):
+                e = numba_ts.indexes_edge_insertion_order[j]
+                p = edges_parent[e]
+                num_children[p] += 1
+                if num_children[p] == 2:
+                    is_coalescent[p] = True
+        return is_coalescent
+
+    def coalescent_nodes_python(ts):
+        is_coalescent = np.zeros(ts.num_nodes, dtype=bool)
+        num_children = np.zeros(ts.num_nodes, dtype=int)
+        for _, edges_out, edges_in in ts.edge_diffs():
+            for e in edges_out:
+                num_children[e.parent] -= 1
+            for e in edges_in:
+                num_children[e.parent] += 1
+                if num_children[e.parent] == 2:
+                    # num_children passes through exactly two, even if arity is greater
+                    is_coalescent[e.parent] = True
+        return is_coalescent
+
+    numba_ts = jit_numba.numba_tree_sequence(ts)
+    C1 = coalescent_nodes_python(ts)
+    C2 = _coalescent_nodes_numba(numba_ts, ts.num_nodes, ts.edges_parent)
+
+    np.testing.assert_array_equal(C1, C2)
diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py
index 8b474192dd..137afbd7fb 100644
--- a/python/tests/test_ld_matrix.py
+++ b/python/tests/test_ld_matrix.py
@@ -41,7 +41,7 @@
 
 import tskit
 from tests import tsutil
-from tests.test_highlevel import get_example_tree_sequences
+from tests.tsutil import get_example_tree_sequences
 
 
 @contextlib.contextmanager
diff --git a/python/tests/test_phylo_formats.py b/python/tests/test_phylo_formats.py
index 2125f506e9..5a24dce0b7 100644
--- a/python/tests/test_phylo_formats.py
+++ b/python/tests/test_phylo_formats.py
@@ -36,7 +36,7 @@
 
 import tests
 import tskit
-from tests.test_highlevel import get_example_tree_sequences
+from tests.tsutil import get_example_tree_sequences
 
 # ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when
 # we can remove this. The example_ts here is intended to be the
diff --git a/python/tests/test_relatedness_vector.py b/python/tests/test_relatedness_vector.py
index 03e88a643d..6982ce8290 100644
--- a/python/tests/test_relatedness_vector.py
+++ b/python/tests/test_relatedness_vector.py
@@ -28,7 +28,7 @@
 
 import tskit
 from tests import tsutil
-from tests.test_highlevel import get_example_tree_sequences
+from tests.tsutil import get_example_tree_sequences
 
 # ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when
 # we can remove this.
diff --git a/python/tests/test_table_transforms.py b/python/tests/test_table_transforms.py
index 35285fc4dc..672de341cb 100644
--- a/python/tests/test_table_transforms.py
+++ b/python/tests/test_table_transforms.py
@@ -33,7 +33,7 @@
 import tests
 import tskit
 import tskit.util as util
-from tests.test_highlevel import get_example_tree_sequences
+from tests.tsutil import get_example_tree_sequences
 
 # ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when
 # we can remove this.
diff --git a/python/tests/test_tree_positioning.py b/python/tests/test_tree_positioning.py
index 39f0b5ccab..a6095b573e 100644
--- a/python/tests/test_tree_positioning.py
+++ b/python/tests/test_tree_positioning.py
@@ -30,7 +30,7 @@
 import tests
 import tskit
 from tests import tsutil
-from tests.test_highlevel import get_example_tree_sequences
+from tests.tsutil import get_example_tree_sequences
 
 # ↑ See https://github.com/tskit-dev/tskit/issues/1804 for when
 # we can remove this.
diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py
index 091e1d0df5..fe60510002 100644
--- a/python/tests/tsutil.py
+++ b/python/tests/tsutil.py
@@ -34,6 +34,7 @@
 
 import msprime
 import numpy as np
+import pytest
 
 import tskit
 import tskit.provenance as provenance
@@ -2288,3 +2289,289 @@ def all_fields_ts(edge_metadata=True, migrations=True):
         tables.provenances.add_row(record="A", timestamp=str(i))
 
     return tables.tree_sequence()
+
+
+def insert_uniform_mutations(tables, num_mutations, nodes):
+    """
+    Inserts num_mutations evenly spaced sites, each with a single mutation,
+    assigning the mutations to the specified nodes in round-robin order.
+    """
+    for j in range(num_mutations):
+        tables.sites.add_row(
+            position=j * (tables.sequence_length / num_mutations),
+            ancestral_state="0",
+            metadata=json.dumps({"index": j}).encode(),
+        )
+        tables.mutations.add_row(
+            site=j,
+            derived_state="1",
+            node=nodes[j % len(nodes)],
+            metadata=json.dumps({"index": j}).encode(),
+        )
+
+
+def get_table_collection_copy(tables, sequence_length):
+    """
+    Returns a copy of the specified table collection with the specified
+    sequence length.
+    """
+    table_dict = tables.asdict()
+    table_dict["sequence_length"] = sequence_length
+    return tskit.TableCollection.fromdict(table_dict)
+
+
+def insert_gap(ts, position, length):
+    """
+    Inserts a gap of the specified size into the specified tree sequence.
+    This involves: (1) breaking all edges that intersect with this point;
+    and (2) shifting all coordinates greater than this value up by the
+    gap length.
+    """
+    new_edges = []
+    for e in ts.edges():
+        if e.left < position < e.right:
+            new_edges.append([e.left, position, e.parent, e.child])
+            new_edges.append([position, e.right, e.parent, e.child])
+        else:
+            new_edges.append([e.left, e.right, e.parent, e.child])
+
+    # Now shift up all coordinates.
+    for e in new_edges:
+        # Left coordinates == position get shifted
+        if e[0] >= position:
+            e[0] += length
+        # Right coordinates == position do not get shifted
+        if e[1] > position:
+            e[1] += length
+    tables = ts.dump_tables()
+    L = ts.sequence_length + length
+    tables = get_table_collection_copy(tables, L)
+    tables.edges.clear()
+    tables.sites.clear()
+    tables.mutations.clear()
+    for left, right, parent, child in new_edges:
+        tables.edges.add_row(left, right, parent, child)
+    tables.sort()
+    # Throw in a bunch of mutations over the whole sequence on the samples.
+    insert_uniform_mutations(tables, 100, list(ts.samples()))
+    return tables.tree_sequence()
+
+
+@functools.lru_cache
+def get_decapitated_examples(custom_max=None):
+    """
+    Returns example tree sequences in which the oldest edges have been removed.
+    """
+    ret = []
+    if custom_max is None:
+        n_list = [10, 20]
+    else:
+        n_list = [custom_max // 2, custom_max]
+    ts = msprime.simulate(n_list[0], random_seed=1234)
+    # yield ts.decapitate(ts.tables.nodes.time[-1] / 2)
+    ts = msprime.simulate(n_list[1], recombination_rate=1, random_seed=1234)
+    assert ts.num_trees > 2
+    ret.append(("decapitate_recomb", ts.decapitate(ts.tables.nodes.time[-1] / 4)))
+    return ret
+
+
+@functools.lru_cache
+def get_gap_examples(custom_max=None):
+    """
+    Returns example tree sequences that contain gaps within the list of
+    edges.
+    """
+    ret = []
+    if custom_max is None:
+        n_list = [20, 10]
+    else:
+        n_list = [custom_max, custom_max // 2]
+
+    ts = msprime.simulate(n_list[0], random_seed=56, recombination_rate=1)
+
+    assert ts.num_trees > 1
+
+    gap = 0.0125
+    for x in [0, 0.1, 0.5, 0.75]:
+        ts = insert_gap(ts, x, gap)
+        found = False
+        for t in ts.trees():
+            if t.interval.left == x:
+                assert t.interval.right == x + gap
+                assert len(t.parent_dict) == 0
+                found = True
+        assert found
+        ret.append((f"gap_{x}", ts))
+    # Give an example with a gap at the end.
+    ts = msprime.simulate(n_list[1], random_seed=5, recombination_rate=1)
+    tables = get_table_collection_copy(ts.dump_tables(), 2)
+    tables.sites.clear()
+    tables.mutations.clear()
+    insert_uniform_mutations(tables, 100, list(ts.samples()))
+    ret.append(("gap_at_end", tables.tree_sequence()))
+    return ret
+
+
+@functools.lru_cache
+def get_internal_samples_examples():
+    """
+    Returns example tree sequences with internal samples.
+    """
+    ret = []
+    n = 5
+    ts = msprime.simulate(n, random_seed=10, mutation_rate=5)
+    assert ts.num_mutations > 0
+    tables = ts.dump_tables()
+    nodes = tables.nodes
+    flags = nodes.flags
+    # Set all nodes to be samples.
+    flags[:] = tskit.NODE_IS_SAMPLE
+    nodes.flags = flags
+    ret.append(("all_nodes_samples", tables.tree_sequence()))
+
+    # Set just internal nodes to be samples.
+    flags[:] = 0
+    flags[n:] = tskit.NODE_IS_SAMPLE
+    nodes.flags = flags
+    ret.append(("internal_nodes_samples", tables.tree_sequence()))
+
+    # Set a mixture of internal and leaf samples.
+    flags[:] = 0
+    flags[n // 2 : n + n // 2] = tskit.NODE_IS_SAMPLE
+    nodes.flags = flags
+    ret.append(("mixed_internal_leaf_samples", tables.tree_sequence()))
+    return ret
+
+
+@functools.lru_cache
+def get_bottleneck_examples(custom_max=None):
+    """
+    Returns a list of (name, tree sequence) examples with nonbinary trees.
+    """
+    # Build a list rather than yielding: lru_cache would otherwise cache a
+    # single generator object that is exhausted after its first use.
+    ret = []
+    bottlenecks = [
+        msprime.SimpleBottleneck(0.01, 0, proportion=0.05),
+        msprime.SimpleBottleneck(0.02, 0, proportion=0.25),
+        msprime.SimpleBottleneck(0.03, 0, proportion=1),
+    ]
+    if custom_max is None:
+        n_list = [3, 10, 100]
+    else:
+        n_list = [i * custom_max // 3 for i in range(1, 4)]
+    for n in n_list:
+        ts = msprime.simulate(
+            n,
+            length=100,
+            recombination_rate=1,
+            demographic_events=bottlenecks,
+            random_seed=n,
+        )
+        ret.append((f"bottleneck_n={n}", ts))
+    return ret
+
+
+@functools.lru_cache
+def get_back_mutation_examples():
+    """
+    Returns a list of example tree sequences with mutations inserted on
+    branches using insert_branch_mutations.
+    """
+    ret = []
+    ts = msprime.simulate(10, random_seed=1)
+    for j in [1, 2, 3]:
+        ret.append(insert_branch_mutations(ts, mutations_per_branch=j))
+    # get_bottleneck_examples returns (name, ts) pairs; only the ts is needed.
+    for _, ts in get_bottleneck_examples():
+        ret.append(insert_branch_mutations(ts))
+    return ret
+
+
+def make_example_tree_sequences(custom_max=None):
+    """
+    Yields (name, tree sequence) pairs covering a broad range of example
+    tree sequences. If custom_max is given, sample sizes are scaled to it.
+    """
+    yield from get_decapitated_examples(custom_max=custom_max)
+    yield from get_gap_examples(custom_max=custom_max)
+    yield from get_internal_samples_examples()
+    seed = 1
+    if custom_max is None:
+        n_list = [2, 3, 10, 100]
+    else:
+        n_list = [i * custom_max // 4 for i in range(1, 5)]
+    for n in n_list:
+        for m in [1, 2, 32]:
+            for rho in [0, 0.1, 0.5]:
+                recomb_map = msprime.RecombinationMap.uniform_map(m, rho, num_loci=m)
+                ts = msprime.simulate(
+                    recombination_map=recomb_map,
+                    mutation_rate=0.1,
+                    random_seed=seed,
+                    population_configurations=[
+                        msprime.PopulationConfiguration(n),
+                        msprime.PopulationConfiguration(0),
+                    ],
+                    migration_matrix=[[0, 1], [1, 0]],
+                )
+                ts = insert_random_ploidy_individuals(ts, 4, seed=seed)
+                yield (
+                    f"n={n}_m={m}_rho={rho}",
+                    add_random_metadata(ts, seed=seed),
+                )
+                seed += 1
+    for name, ts in get_bottleneck_examples(custom_max=custom_max):
+        yield (
+            f"{name}_mutated",
+            msprime.mutate(
+                ts,
+                rate=0.1,
+                random_seed=seed,
+                model=msprime.InfiniteSites(msprime.NUCLEOTIDES),
+            ),
+        )
+    ts = tskit.Tree.generate_balanced(8).tree_sequence
+    yield ("rev_node_order", ts.subset(np.arange(ts.num_nodes - 1, -1, -1)))
+    ts = msprime.sim_ancestry(
+        8, sequence_length=40, recombination_rate=0.1, random_seed=seed
+    )
+    tables = ts.dump_tables()
+    tables.populations.metadata_schema = tskit.MetadataSchema(None)
+    ts = tables.tree_sequence()
+    assert ts.num_trees > 1
+    yield (
+        "back_mutations",
+        insert_branch_mutations(ts, mutations_per_branch=2),
+    )
+    ts = insert_multichar_mutations(ts)
+    yield ("multichar", ts)
+    yield ("multichar_no_metadata", add_random_metadata(ts))
+    tables = ts.dump_tables()
+    tables.nodes.flags = np.zeros_like(tables.nodes.flags)
+    yield ("no_samples", tables.tree_sequence())  # no samples
+    tables = ts.dump_tables()
+    tables.edges.clear()
+    yield ("empty_tree", tables.tree_sequence())  # empty tree
+    yield (
+        "empty_ts",
+        tskit.TableCollection(sequence_length=1).tree_sequence(),
+    )  # empty tree seq
+    yield ("all_fields", all_fields_ts())
+
+
+# Generated once at import time and shared by every caller below.
+_examples = tuple(make_example_tree_sequences(custom_max=None))
+
+
+def get_example_tree_sequences(pytest_params=True, custom_max=None):
+    """
+    Returns the cached example tree sequences, wrapped as pytest params
+    with the example name as id if pytest_params is True, or as bare
+    tree sequences otherwise. Note that custom_max is currently ignored:
+    the examples are always those generated with custom_max=None.
+    """
+    if pytest_params:
+        return [pytest.param(ts, id=name) for name, ts in _examples]
+    else:
+        return [ts for _, ts in _examples]
diff --git a/python/tskit/jit/__init__.py b/python/tskit/jit/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/python/tskit/jit/numba.py b/python/tskit/jit/numba.py
new file mode 100644
index 0000000000..5ffd23f790
--- /dev/null
+++ b/python/tskit/jit/numba.py
@@ -0,0 +1,90 @@
+from dataclasses import dataclass
+
+
+try:
+    import numba
+except ImportError as err:
+    raise ImportError(
+        "Numba is not installed. Please install it with `pip install numba` "
+        "or `conda install numba` to use the tskit.jit.numba module."
+    ) from err
+
+
+# Decorator that makes a jited dataclass by removing certain methods
+# that are not compatible with Numba's JIT compilation.
+def jitdataclass(cls):
+    dc_cls = dataclass(cls, eq=False)
+    del dc_cls.__dataclass_params__
+    del dc_cls.__dataclass_fields__
+    del dc_cls.__repr__
+    try:
+        del dc_cls.__replace__
+    except AttributeError:
+        # __replace__ is only added to dataclasses in Python >= 3.13
+        pass
+    try:
+        del dc_cls.__match_args__
+    except AttributeError:
+        # __match_args__ is not available in Python < 3.10
+        pass
+    return numba.experimental.jitclass(dc_cls)
+
+
+@jitdataclass
+class NumbaTreeSequence:
+    num_edges: numba.int64
+    sequence_length: numba.float64
+    edges_left: numba.float64[:]
+    edges_right: numba.float64[:]
+    indexes_edge_insertion_order: numba.int32[:]
+    indexes_edge_removal_order: numba.int32[:]
+
+    def tree_position(self):
+        return NumbaTreePosition(self, (0, 0), (0, 0), (0, 0))
+
+
+@jitdataclass
+class NumbaTreePosition:
+    ts: NumbaTreeSequence
+    interval: numba.types.UniTuple(numba.float64, 2)
+    edges_in_index_range: numba.types.UniTuple(numba.int32, 2)
+    edges_out_index_range: numba.types.UniTuple(numba.int32, 2)
+
+    def next(self):  # noqa: A003
+        M = self.ts.num_edges
+        edges_left = self.ts.edges_left
+        edges_right = self.ts.edges_right
+        in_order = self.ts.indexes_edge_insertion_order
+        out_order = self.ts.indexes_edge_removal_order
+
+        left = self.interval[1]  # new left is the previous right
+        j = self.edges_in_index_range[1]
+        k = self.edges_out_index_range[1]
+
+        while k < M and edges_right[out_order[k]] == left:  # edges ending at left
+            k += 1
+        while j < M and edges_left[in_order[j]] == left:  # edges starting at left
+            j += 1
+
+        self.edges_in_index_range = (self.edges_in_index_range[1], j)
+        self.edges_out_index_range = (self.edges_out_index_range[1], k)
+
+        right = self.ts.sequence_length
+        if j < M:
+            right = min(right, edges_left[in_order[j]])
+        if k < M:
+            right = min(right, edges_right[out_order[k]])
+
+        self.interval = (left, right)
+        return j < M or left < self.ts.sequence_length
+
+
+def numba_tree_sequence(ts):
+    return NumbaTreeSequence(
+        num_edges=ts.num_edges,
+        sequence_length=ts.sequence_length,
+        edges_left=ts.edges_left,
+        edges_right=ts.edges_right,
+        indexes_edge_insertion_order=ts.indexes_edge_insertion_order,
+        indexes_edge_removal_order=ts.indexes_edge_removal_order,
+    )