diff --git a/tests/test_formats.py b/tests/test_formats.py index 6a16d4c4..42ede54c 100644 --- a/tests/test_formats.py +++ b/tests/test_formats.py @@ -191,7 +191,10 @@ def encode_metadata(metadata, schema): assert input_file.num_populations == ts.num_populations for pop in ts.populations(): if pop.metadata: - assert pop_metadata[pop.id] == pop.metadata + for key in pop.metadata: + # Everything in pop.metadata should also be in pop_metadata[pop.id] + # (pop_metadata[pop.id] might have an extra name + description) + assert pop_metadata[pop.id][key] == pop.metadata[key] if ts.num_individuals == 0: assert input_file.num_individuals == ts.num_samples else: @@ -206,7 +209,7 @@ def encode_metadata(metadata, schema): IS_WINDOWS, reason="windows simultaneous file permissions issue" ) def test_defaults_with_path(self): - ts = tsutil.get_example_ts(10, 10) + ts = tsutil.get_example_ts(10) with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir: filename = os.path.join(tempdir, "samples.tmp") input_file = formats.SampleData( @@ -284,19 +287,19 @@ def test_inference_not_supported(self): sample_data.add_site(0.1, [1, 1], inference=False) def test_defaults_no_path(self): - ts = tsutil.get_example_ts(10, 10) + ts = tsutil.get_example_ts(10) with formats.SampleData(sequence_length=ts.sequence_length) as sample_data: self.verify_data_round_trip(ts, sample_data) for _, array in sample_data.arrays(): assert array.compressor == formats.DEFAULT_COMPRESSOR def test_with_metadata_and_individuals(self): - ts = tsutil.get_example_individuals_ts_with_metadata(5, 2, 10, 1) + ts = tsutil.get_example_individuals_ts_with_metadata(5, ploidy=2) with formats.SampleData(sequence_length=ts.sequence_length) as sample_data: self.verify_data_round_trip(ts, sample_data) def test_access_individuals(self): - ts = tsutil.get_example_individuals_ts_with_metadata(5, 2, 10, 1) + ts = tsutil.get_example_individuals_ts_with_metadata(5, ploidy=2) sd = tsinfer.SampleData.from_tree_sequence(ts) assert sd.num_individuals > 0 has_some_metadata = False @@ -312,7 +315,7 @@ def test_access_individuals(self): assert i == sd.num_individuals - 1 def test_access_populations(self): - ts = tsutil.get_example_individuals_ts_with_metadata(5, 2, 10, 1) + ts = tsutil.get_example_individuals_ts_with_metadata(5, ploidy=2) sd = tsinfer.SampleData.from_tree_sequence(ts) assert sd.num_individuals > 0 has_some_metadata = False @@ -327,7 +330,7 @@ def test_from_tree_sequence_bad_times(self): n_individuals = 4 ploidy = 2 individual_times = np.arange(n_individuals) # Diploids - ts = tsutil.get_example_historical_sampled_ts(individual_times, ploidy, 10) + ts = tsutil.get_example_historical_sampled_ts(individual_times, ploidy) tables = ts.dump_tables() # Associate nodes at different times with a single individual nodes_time = tables.nodes.time @@ -342,7 +345,7 @@ def test_from_tree_sequence_bad_times(self): def test_from_tree_sequence_bad_populations(self): n_individuals = 4 - ts = tsutil.get_example_ts(n_individuals * 2, 10, 1) # Diploids + ts = tsutil.get_example_ts(n_individuals * 2) # Diploids tables = ts.dump_tables() # Associate each sample node with a new population for _ in range(n_individuals * 2): @@ -363,14 +366,14 @@ def test_from_tree_sequence_bad_populations(self): formats.SampleData.from_tree_sequence(bad_ts) def test_from_tree_sequence_simple(self): - ts = tsutil.get_example_ts(10, 10, 1) + ts = tsutil.get_example_ts(10) sd1 = formats.SampleData(sequence_length=ts.sequence_length) self.verify_data_round_trip(ts, sd1) sd2 = 
formats.SampleData.from_tree_sequence(ts, use_sites_time=True) assert sd1.data_equal(sd2) def test_from_tree_sequence_variable_allele_number(self): - ts = tsutil.get_example_ts(10, 10) + ts = tsutil.get_example_ts(10) # Create > 2 alleles by scattering mutations on the tree nodes at the first site tables = ts.dump_tables() # We can't have mixed know and unknown times. @@ -402,7 +405,7 @@ def test_from_tree_sequence_variable_allele_number(self): assert sd1.data_equal(sd2) def test_from_tree_sequence_with_metadata(self): - ts = tsutil.get_example_individuals_ts_with_metadata(5, 2, 10) + ts = tsutil.get_example_individuals_ts_with_metadata(5, 2) # Remove individuals tables = ts.dump_tables() tables.individuals.clear() @@ -418,7 +421,7 @@ def test_from_tree_sequence_with_metadata(self): assert sd1.data_equal(sd2) def test_from_tree_sequence_with_metadata_and_individuals(self): - ts = tsutil.get_example_individuals_ts_with_metadata(5, 3, 10) + ts = tsutil.get_example_individuals_ts_with_metadata(5, ploidy=3) sd1 = formats.SampleData(sequence_length=ts.sequence_length) self.verify_data_round_trip(ts, sd1) sd2 = formats.SampleData.from_tree_sequence(ts, use_sites_time=True) @@ -428,7 +431,7 @@ def test_from_historical_tree_sequence_with_times(self): n_indiv = 5 ploidy = 2 individual_times = np.arange(n_indiv) - ts = tsutil.get_example_historical_sampled_ts(individual_times, ploidy, 10) + ts = tsutil.get_example_historical_sampled_ts(individual_times, ploidy) # Test on a tree seq containing an individual with no nodes keep_samples = [u for i in ts.individuals() for u in i.nodes if i.id < n_indiv] ts = ts.simplify(samples=keep_samples, filter_individuals=False) @@ -446,7 +449,7 @@ def test_from_tree_sequence_no_times(self): n_indiv = 5 ploidy = 2 individual_times = np.arange(n_indiv + 1) - ts = tsutil.get_example_historical_sampled_ts(individual_times, ploidy, 10) + ts = tsutil.get_example_historical_sampled_ts(individual_times, ploidy) # Test on a tree seq containing an individual with no nodes keep_samples = [u for i in ts.individuals() for u in i.nodes if i.id < n_indiv] ts = ts.simplify(samples=keep_samples, filter_individuals=False) @@ -457,7 +460,7 @@ def test_from_tree_sequence_no_times(self): def test_from_tree_sequence_time_incompatibilities(self): ploidy = 2 individual_times = np.arange(5) - ts = tsutil.get_example_historical_sampled_ts(individual_times, ploidy, 10) + ts = tsutil.get_example_historical_sampled_ts(individual_times, ploidy) with pytest.raises(ValueError, match="Incompatible timescales"): _ = formats.SampleData.from_tree_sequence(ts, use_individuals_time=True) # Similar error if no individuals in the TS @@ -471,7 +474,7 @@ def test_from_tree_sequence_time_incompatibilities(self): _ = formats.SampleData.from_tree_sequence(ts, use_individuals_time=True) def test_chunk_size(self): - ts = tsutil.get_example_ts(4, 2) + ts = tsutil.get_example_ts(4, mutation_rate=0.005) assert ts.num_sites > 50 for chunk_size in [1, 2, 3, ts.num_sites - 1, ts.num_sites, ts.num_sites + 1]: input_file = formats.SampleData( @@ -484,7 +487,7 @@ def test_chunk_size(self): assert array.chunks[1] == chunk_size def test_filename(self): - ts = tsutil.get_example_ts(14, 15) + ts = tsutil.get_example_ts(14) with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir: filename = os.path.join(tempdir, "samples.tmp") input_file = formats.SampleData( @@ -504,7 +507,7 @@ def test_filename(self): other_input_file.close() def test_chunk_size_file_equal(self): - ts = tsutil.get_example_ts(13, 15) + ts = 
tsutil.get_example_ts(13) with tempfile.TemporaryDirectory(prefix="tsinf_format_test") as tempdir: files = [] for chunk_size in [5, 7]: @@ -525,7 +528,7 @@ def test_chunk_size_file_equal(self): @pytest.mark.slow def test_compressor(self): - ts = tsutil.get_example_ts(11, 17) + ts = tsutil.get_example_ts(11, random_seed=123) compressors = [ None, formats.DEFAULT_COMPRESSOR, @@ -545,7 +548,7 @@ def test_compressor(self): assert array.compressor == compressor def test_multichar_alleles(self): - ts = tsutil.get_example_ts(5, 17) + ts = tsutil.get_example_ts(5) t = ts.dump_tables() t.sites.clear() t.mutations.clear() @@ -562,13 +565,14 @@ def test_multichar_alleles(self): self.verify_data_round_trip(ts, input_file) def test_str(self): - ts = tsutil.get_example_ts(5, 3) + ts = tsutil.get_example_ts(5, random_seed=2) input_file = formats.SampleData(sequence_length=ts.sequence_length) self.verify_data_round_trip(ts, input_file) assert len(str(input_file)) > 0 def test_eq(self): - ts = tsutil.get_example_ts(5, 3) + ts = tsutil.get_example_ts(5, random_seed=3) + print(ts.num_sites) input_file = formats.SampleData(sequence_length=ts.sequence_length) self.verify_data_round_trip(ts, input_file) assert input_file == input_file @@ -576,7 +580,7 @@ def test_eq(self): assert not ({} == input_file) def test_provenance(self): - ts = tsutil.get_example_ts(4, 3) + ts = tsutil.get_example_ts(4, random_seed=10) input_file = formats.SampleData(sequence_length=ts.sequence_length) self.verify_data_round_trip(ts, input_file) assert input_file.num_provenances == 1 @@ -591,7 +595,7 @@ def test_provenance(self): assert a[0][1] == record def test_clear_provenance(self): - ts = tsutil.get_example_ts(4, 3) + ts = tsutil.get_example_ts(4, random_seed=6) input_file = formats.SampleData(sequence_length=ts.sequence_length) self.verify_data_round_trip(ts, input_file) assert input_file.num_provenances == 1 @@ -717,21 +721,27 @@ def test_population_metadata(self): sample_data.add_site(position=0, genotypes=[0, 1]) sample_data.finalise() - assert sample_data.populations_metadata[0] == {"a": 1} - assert sample_data.populations_metadata[1] == {"b": 2} + assert sample_data.populations_metadata[0] == { + "a": 1, + "name": tsinfer.AUTOGENERATED_POP_NAME_PREFIX + "0", + "description": tsinfer.AUTOGENERATED_POP_DESCRIPTION, + } + assert sample_data.populations_metadata[1] == { + "b": 2, + "name": tsinfer.AUTOGENERATED_POP_NAME_PREFIX + "1", + "description": tsinfer.AUTOGENERATED_POP_DESCRIPTION, + } assert sample_data.individuals_population[0] == 0 assert sample_data.individuals_population[1] == 1 def test_individual_metadata(self): sample_data = formats.SampleData(sequence_length=10) - sample_data.add_population({"a": 1}) - sample_data.add_population({"b": 2}) - sample_data.add_individual(population=0) - sample_data.add_individual(population=1) + sample_data.add_individual(metadata={"a": 1}) + sample_data.add_individual(metadata={"b": 2}) sample_data.add_site(0, [0, 0]) sample_data.finalise() - assert sample_data.populations_metadata[0] == {"a": 1} - assert sample_data.populations_metadata[1] == {"b": 2} + assert sample_data.individuals_metadata[0] == {"a": 1} + assert sample_data.individuals_metadata[1] == {"b": 2} def test_add_individual_time(self): sample_data = formats.SampleData(sequence_length=10) @@ -815,7 +825,7 @@ def test_add_site_return(self): assert sid == 1 def test_sites(self): - ts = tsutil.get_example_ts(11, 15) + ts = tsutil.get_example_ts(11) assert ts.num_sites > 1 input_file = 
formats.SampleData.from_tree_sequence(ts) @@ -829,7 +839,7 @@ def test_sites(self): assert next(all_sites, None) is None, None def test_sites_subset(self): - ts = tsutil.get_example_ts(11, 15) + ts = tsutil.get_example_ts(11) assert ts.num_sites > 1 input_file = formats.SampleData.from_tree_sequence(ts, use_sites_time=True) assert list(input_file.sites([])) == [] @@ -847,7 +857,7 @@ def test_sites_subset(self): list(input_file.sites([10000])) def test_variants(self): - ts = tsutil.get_example_ts(11, 15) + ts = tsutil.get_example_ts(11) assert ts.num_sites > 1 input_file = formats.SampleData.from_tree_sequence(ts) @@ -862,7 +872,7 @@ def test_variants(self): assert next(all_variants, None) is None, None def test_variants_subset_sites(self): - ts = tsutil.get_example_ts(4, 2) + ts = tsutil.get_example_ts(4, mutation_rate=0.004) assert ts.num_sites > 50 for chunk_size in [1, 2, 3, ts.num_sites - 1, ts.num_sites, ts.num_sites + 1]: input_file = formats.SampleData( @@ -888,7 +898,7 @@ def test_variants_subset_sites(self): assert every_variant == next(v) def test_all_haplotypes(self): - ts = tsutil.get_example_ts(13, 12) + ts = tsutil.get_example_ts(13, random_seed=111) assert ts.num_sites > 1 input_file = formats.SampleData.from_tree_sequence(ts) @@ -915,7 +925,7 @@ def test_all_haplotypes(self): assert j == ts.num_samples def test_haplotypes_index_errors(self): - ts = tsutil.get_example_ts(13, 12) + ts = tsutil.get_example_ts(13, random_seed=19) assert ts.num_sites > 1 input_file = formats.SampleData.from_tree_sequence(ts) with pytest.raises(ValueError): @@ -934,7 +944,7 @@ def test_haplotypes_index_errors(self): list(input_file.haplotypes([3, 14])) def test_haplotypes_subsets(self): - ts = tsutil.get_example_ts(25, 12) + ts = tsutil.get_example_ts(25) assert ts.num_sites > 1 input_file = formats.SampleData.from_tree_sequence(ts) @@ -970,7 +980,7 @@ def test_haplotypes_subsets(self): assert j == len(subset) def test_ts_with_invariant_sites(self): - ts = tsutil.get_example_ts(5, 3) + ts = tsutil.get_example_ts(5) t = ts.dump_tables() positions = {site.position for site in ts.sites()} for j in range(10): @@ -987,7 +997,7 @@ def test_ts_with_invariant_sites(self): assert len(str(input_file)) > 0 def test_ts_with_root_mutations(self): - ts = tsutil.get_example_ts(5, 3) + ts = tsutil.get_example_ts(5) t = ts.dump_tables() positions = {site.position for site in ts.sites()} for tree in ts.trees(): @@ -1189,10 +1199,11 @@ def test_num_alleles_with_missing(self): assert np.all(sd.num_alleles() == np.array([0, 1, 2, 2])) def test_append_sites(self): - ts = tsutil.get_example_individuals_ts_with_metadata(4, 2, 10) - sd1 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([[0, 2]])) - sd2 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([[2, 5]])) - sd3 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([[5, 10]])) + pos = [0, 2000, 5000, 10000] + ts = tsutil.get_example_individuals_ts_with_metadata(4, sequence_length=pos[-1]) + sd1 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([pos[0:2]])) + sd2 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([pos[1:3]])) + sd3 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([pos[2:]])) sd = sd1.copy() # put into write mode sd.append_sites(sd2, sd3) sd.finalise() @@ -1204,20 +1215,22 @@ def test_append_sites(self): sd_full.assert_data_equal(tsinfer.SampleData.from_tree_sequence(ts)) def test_append_sites_bad_order(self): - ts = tsutil.get_example_individuals_ts_with_metadata(4, 2, 10) - sd1 = 
tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([[0, 2]])) - sd2 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([[2, 5]])) - sd3 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([[5, 10]])) + pos = [0, 2000, 5000, 10000] + ts = tsutil.get_example_individuals_ts_with_metadata(4, sequence_length=pos[-1]) + sd1 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([pos[0:2]])) + sd2 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([pos[1:3]])) + sd3 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([pos[2:]])) sd = sd1.copy() # put into write mode with pytest.raises(ValueError, match="ascending"): sd.append_sites(sd3, sd2) def test_append_sites_incompatible_files(self): - ts = tsutil.get_example_individuals_ts_with_metadata(4, 2, 10) - sd1 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([[0, 2]])) - mid_ts = ts.keep_intervals([[2, 5]]) + pos = [0, 2000, 5000, 10000] + ts = tsutil.get_example_individuals_ts_with_metadata(4, sequence_length=pos[-1]) + sd1 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([pos[0:2]])) + mid_ts = ts.keep_intervals([pos[1:3]]) sd2 = tsinfer.SampleData.from_tree_sequence(mid_ts) - sd3 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([[5, 10]])) + sd3 = tsinfer.SampleData.from_tree_sequence(ts.keep_intervals([pos[2:]])) # Fails if altered SD is not in write mode with pytest.raises(ValueError, match="build"): sd1.append_sites(sd2, sd3) @@ -1276,8 +1289,10 @@ def test_metadata_schemas_default(self): with formats.SampleData() as sample_data: sample_data.add_site(0, [0, 0]) assert sample_data.metadata_schema == tsinfer.permissive_json_schema() - # Tables default to None for backward compatibility. - assert sample_data.populations_metadata_schema is None + assert ( + sample_data.populations_metadata_schema == tsinfer.permissive_json_schema() + ) + # other tables default to None for backward compatibility. 
assert sample_data.individuals_metadata_schema is None assert sample_data.sites_metadata_schema is None @@ -1313,16 +1328,24 @@ def test_set_top_level_metadata_schema(self): assert sample_data.metadata_schema == tsinfer.permissive_json_schema() def test_set_population_metadata_schema(self): + # By default we already use tsinfer.permissive_json_schema() for populations example_schema = tsinfer.permissive_json_schema() + example_schema["properties"]["xyz"] = {"type": "string"} with formats.SampleData() as sample_data: - assert sample_data.populations_metadata_schema is None + assert ( + sample_data.populations_metadata_schema + == tsinfer.permissive_json_schema() + ) sample_data.populations_metadata_schema = example_schema assert sample_data.populations_metadata_schema == example_schema sample_data.add_site(0, [0, 0]) assert sample_data.populations_metadata_schema == example_schema with formats.SampleData() as sample_data: - assert sample_data.populations_metadata_schema is None + assert ( + sample_data.populations_metadata_schema + == tsinfer.permissive_json_schema() + ) sample_data.populations_metadata_schema = example_schema assert sample_data.populations_metadata_schema == example_schema sample_data.add_site(0, [0, 0]) @@ -1370,7 +1393,7 @@ class TestSampleDataSubset: """ def test_no_arguments(self): - ts = tsutil.get_example_ts(10, 10, 1) + ts = tsutil.get_example_ts(10) sd1 = formats.SampleData.from_tree_sequence(ts) # No arguments gives the same data subset = sd1.subset() @@ -1419,7 +1442,7 @@ def verify_subset_data(self, source, individuals, sites): assert j == len(sites) def test_simple_case(self): - ts = tsutil.get_example_ts(10, 10, 1) + ts = tsutil.get_example_ts(10) sd1 = formats.SampleData.from_tree_sequence(ts) G1 = ts.genotype_matrix() # Because this is a haploid tree sequence we can use the @@ -1432,7 +1455,7 @@ def test_simple_case(self): self.verify_subset_data(sd1, cols, rows) def test_reordering_individuals(self): - ts = tsutil.get_example_ts(10, 10, 1) + ts = tsutil.get_example_ts(10) sd = formats.SampleData.from_tree_sequence(ts) ind = np.arange(sd.num_individuals)[::-1] subset = sd.subset(individuals=ind) @@ -1440,7 +1463,7 @@ def test_reordering_individuals(self): assert np.array_equal(sd.sites_genotypes[:][:, ind], subset.sites_genotypes[:]) def test_mixed_diploid_metadata(self): - ts = tsutil.get_example_individuals_ts_with_metadata(10, 2, 10) + ts = tsutil.get_example_individuals_ts_with_metadata(10, ploidy=2) sd = formats.SampleData.from_tree_sequence(ts) N = sd.num_individuals M = sd.num_sites @@ -1453,7 +1476,7 @@ def test_mixed_diploid_metadata(self): self.verify_subset_data(sd, [0, N - 1], range(M)) def test_mixed_triploid_metadata(self): - ts = tsutil.get_example_individuals_ts_with_metadata(18, 3, 10) + ts = tsutil.get_example_individuals_ts_with_metadata(18, ploidy=3) sd = formats.SampleData.from_tree_sequence(ts) N = sd.num_individuals M = sd.num_sites @@ -1466,7 +1489,7 @@ def test_mixed_triploid_metadata(self): self.verify_subset_data(sd, [0, N - 1], range(M)) def test_errors(self): - ts = tsutil.get_example_ts(10, 10, 1) + ts = tsutil.get_example_ts(10) sd1 = formats.SampleData.from_tree_sequence(ts) with pytest.raises(ValueError): sd1.subset(sites=[]) @@ -1495,7 +1518,7 @@ def test_errors(self): def test_file_kwargs(self): # Make sure we pass kwards on to the SampleData constructor as # required. 
- ts = tsutil.get_example_ts(10, 10, 1) + ts = tsutil.get_example_ts(10) sd1 = formats.SampleData.from_tree_sequence(ts) with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, "sample-data") @@ -1511,7 +1534,7 @@ class TestSampleDataMerge: """ def test_finalised(self): - ts1 = tsutil.get_example_ts(2, 2, 1) + ts1 = tsutil.get_example_ts(2) sd1 = formats.SampleData.from_tree_sequence(ts1) sd1_copy = sd1.copy() with pytest.raises(ValueError, match="not finalised"): @@ -1520,16 +1543,16 @@ def test_finalised(self): sd1.merge(sd1_copy) def test_different_sequence_lengths(self): - ts1 = tsutil.get_example_ts(2, 2, 1) + ts1 = tsutil.get_example_ts(2, sequence_length=10000) sd1 = formats.SampleData.from_tree_sequence(ts1) - ts2 = tsutil.get_example_ts(2, 3, 1) + ts2 = tsutil.get_example_ts(2, sequence_length=10001) sd2 = formats.SampleData.from_tree_sequence(ts2) with pytest.raises(ValueError): sd1.merge(sd2) def test_mismatch_ancestral_state(self): # Difference ancestral states - ts = tsutil.get_example_ts(2, 2, 1) + ts = tsutil.get_example_ts(2) sd1 = formats.SampleData.from_tree_sequence(ts) tables = ts.dump_tables() tables.sites.ancestral_state += 2 @@ -1537,8 +1560,8 @@ def test_mismatch_ancestral_state(self): with pytest.raises(ValueError): sd1.merge(sd2) - def verify(self, sd1, sd2): - sd3 = sd1.merge(sd2) + def verify(self, sd1, sd2, use_population_names=False): + sd3 = sd1.merge(sd2, use_population_names=use_population_names) n1 = sd1.num_samples n2 = sd2.num_samples assert sd3.num_samples == n1 + n2 @@ -1576,7 +1599,13 @@ def verify(self, sd1, sd2): assert len(new_pops) == len(old_pops) for new_pop, old_pop in zip(new_pops, old_pops): assert new_pop.id == old_pop.id + sd1.num_populations - assert new_pop.metadata == old_pop.metadata + if use_population_names: + assert new_pop.metadata == old_pop.metadata + else: + # Name could have been changed: just check other metadata + del old_pop.metadata["name"] + del new_pop.metadata["name"] + assert new_pop.metadata == old_pop.metadata sd1_sites = set(sd1.sites_position) sd2_sites = set(sd2.sites_position) @@ -1615,7 +1644,7 @@ def verify(self, sd1, sd2): def test_merge_identical(self): n = 10 - ts = tsutil.get_example_ts(n, 10, 1) + ts = tsutil.get_example_ts(n) sd1 = formats.SampleData.from_tree_sequence(ts, use_sites_time=True) sd2 = sd1.merge(sd1) assert sd2.num_sites == sd1.num_sites @@ -1629,9 +1658,9 @@ def test_merge_identical(self): def test_merge_distinct(self): n = 10 - ts = tsutil.get_example_ts(n, 10, 1, random_seed=1) + ts = tsutil.get_example_ts(n, random_seed=1) sd1 = formats.SampleData.from_tree_sequence(ts) - ts = tsutil.get_example_ts(n, 10, 1, random_seed=2) + ts = tsutil.get_example_ts(n, random_seed=2) sd2 = formats.SampleData.from_tree_sequence(ts) assert len(set(sd1.sites_position) & set(sd2.sites_position)) == 0 @@ -1661,7 +1690,7 @@ def test_merge_distinct(self): self.verify(sd2, sd1) def test_merge_overlapping_sites(self): - ts = tsutil.get_example_ts(4, 10, 1, random_seed=1) + ts = tsutil.get_example_ts(4, random_seed=1) sd1 = formats.SampleData.from_tree_sequence(ts) tables = ts.dump_tables() # Change the position of the first and last sites to we have @@ -1679,24 +1708,28 @@ def test_merge_overlapping_sites(self): self.verify(sd2, sd1) def test_individuals_metadata_identical(self): - ts = tsutil.get_example_individuals_ts_with_metadata(5, 2, 10, 1) + ts = tsutil.get_example_individuals_ts_with_metadata(5) sd1 = formats.SampleData.from_tree_sequence(ts) self.verify(sd1, sd1) def 
test_individuals_metadata_distinct(self): - ts = tsutil.get_example_individuals_ts_with_metadata(5, 2, 10, 1) + ts = tsutil.get_example_individuals_ts_with_metadata( + 5, ploidy=2, discrete_genome=False + ) # Use discrete_genome to put sites at different positions sd1 = formats.SampleData.from_tree_sequence(ts) - ts = tsutil.get_example_individuals_ts_with_metadata(3, 3, 10, 1) + ts = tsutil.get_example_individuals_ts_with_metadata( + 3, ploidy=3, discrete_genome=False + ) sd2 = formats.SampleData.from_tree_sequence(ts) assert len(set(sd1.sites_position) & set(sd2.sites_position)) == 0 self.verify(sd1, sd2) self.verify(sd2, sd1) def test_different_alleles_same_sites(self): - ts = tsutil.get_example_individuals_ts_with_metadata(5, 2, 10, 1) + ts = tsutil.get_example_ts(5, mutation_model=msprime.BinaryMutationModel()) sd1 = formats.SampleData.from_tree_sequence(ts) tables = ts.dump_tables() - tables.mutations.derived_state += 1 + tables.mutations.derived_state += 1 # "0" -> "1", "1"-> "2", etc sd2 = formats.SampleData.from_tree_sequence(tables.tree_sequence()) self.verify(sd1, sd2) self.verify(sd2, sd1) @@ -1733,7 +1766,7 @@ def test_missing_data(self): def test_file_kwargs(self): # Make sure we pass kwards on to the SampleData constructor as # required. - ts = tsutil.get_example_ts(10, 10, 1) + ts = tsutil.get_example_ts(10) sd1 = formats.SampleData.from_tree_sequence(ts) with tempfile.TemporaryDirectory() as tmpdir: path = os.path.join(tmpdir, "sample-data") @@ -1749,7 +1782,7 @@ class TestMinSiteTimes: """ def test_no_historical(self): - ts = tsutil.get_example_ts(10, 10, 1) + ts = tsutil.get_example_ts(10) sd1 = formats.SampleData.from_tree_sequence(ts, use_sites_time=True) # No arguments and individuals_only=True should give array of zeros bounds_individuals_only = sd1.min_site_times(individuals_only=True) @@ -1782,7 +1815,7 @@ def test_simple_case(self): assert np.array_equal(time_bound, sd1.sites_time[:]) def test_errors(self): - ts = tsutil.get_example_ts(10, 10, 1) + ts = tsutil.get_example_ts(10) sd1 = formats.SampleData.from_tree_sequence(ts) individuals_time = sd1.individuals_time[:] neg_times_sd1 = sd1.copy() @@ -1804,8 +1837,8 @@ def get_example_data( ): ts = msprime.simulate( sample_size, - recombination_rate=1, mutation_rate=10, + recombination_rate=1, length=sequence_length, random_seed=100, ) @@ -1934,7 +1967,7 @@ def test_provenance(self): def test_chunk_size(self): N = 20 for chunk_size in [1, 2, 3, N - 1, N, N + 1]: - sample_data, ancestors = self.get_example_data(6, 1, N) + sample_data, ancestors = self.get_example_data(6, 1, num_ancestors=N) ancestor_data = tsinfer.AncestorData(sample_data, chunk_size=chunk_size) self.verify_data_round_trip(sample_data, ancestor_data, ancestors) assert ancestor_data.ancestors_haplotype.chunks == (chunk_size,) diff --git a/tests/test_inference.py b/tests/test_inference.py index f986fe6f..07f70329 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -889,13 +889,13 @@ def test_site_metadata(self, use_schema): assert value == tsinfer.INFERENCE_FULL assert decoded_metadata == all_metadata[site.id] - @pytest.mark.parametrize("use_schema", [True, False]) - def test_population_metadata(self, use_schema): + @pytest.mark.parametrize("no_schema", [True, False]) + def test_population_metadata(self, no_schema): ts = msprime.simulate(12, mutation_rate=5, random_seed=16) assert ts.num_sites > 2 sample_data = tsinfer.SampleData(sequence_length=1) - if use_schema: - sample_data.populations_metadata_schema = 
tsinfer.permissive_json_schema() + if no_schema: + sample_data.populations_metadata_schema = None rng = random.Random(32) all_metadata = [] for j in range(ts.num_samples): @@ -914,7 +914,7 @@ def test_population_metadata(self, use_schema): assert all_metadata[j] == metadata output_ts = tsinfer.infer(sample_data) output_metadata = [ - population.metadata if use_schema else json.loads(population.metadata) + json.loads(population.metadata) if no_schema else population.metadata for population in output_ts.populations() ] assert all_metadata == output_metadata @@ -1079,15 +1079,20 @@ def test_from_standard_tree_sequence(self): """ n_indiv = 5 ploidy = 2 # Diploids - seq_len = 10 ts = tsutil.get_example_individuals_ts_with_metadata( - n_indiv, ploidy, seq_len, 1, skip_last=False + n_indiv, ploidy, skip_last=False ) ts_inferred = tsinfer.infer(tsinfer.SampleData.from_tree_sequence(ts)) assert ts.sequence_length == ts_inferred.sequence_length assert ts.metadata_schema.schema == ts_inferred.metadata_schema.schema assert ts.metadata == ts_inferred.metadata - assert ts.tables.populations == ts_inferred.tables.populations + assert ts.tables.populations.equals( + ts_inferred.tables.populations, ignore_metadata=True + ) + # Check all the metadata in pops is in inferred pops (in addition to names etc) + for pop, pop_inferred in zip(ts.populations(), ts_inferred.populations()): + for key in pop.metadata.keys(): + assert pop.metadata[key] == pop_inferred.metadata[key] assert ts.num_individuals == ts_inferred.num_individuals for i1, i2 in zip(ts.individuals(), ts_inferred.individuals()): assert list(i1.location) == list(i2.location) @@ -1113,9 +1118,8 @@ def test_from_historical_tree_sequence(self): """ n_indiv = 5 ploidy = 2 # Diploids - seq_len = 10 individual_times = np.arange(n_indiv) - ts = tsutil.get_example_historical_sampled_ts(individual_times, ploidy, seq_len) + ts = tsutil.get_example_historical_sampled_ts(individual_times, ploidy) ts_inferred = tsinfer.infer( tsinfer.SampleData.from_tree_sequence( ts, use_sites_time=True, use_individuals_time=True @@ -2977,10 +2981,11 @@ def verify(self, samples): t2, node_id_map = tsinfer.extract_ancestors(samples, ts) assert len(t2.provenances) == len(t1.provenances) + 2 - # Population data isn't carried through in ancestors tree sequences + # Population isn't carried through in ancestors tree sequences # for now. t2.populations.clear() - assert t1.equals(t2, ignore_provenance=True, ignore_ts_metadata=True) + + t1.assert_equals(t2, ignore_provenance=True, ignore_ts_metadata=True) for node in ts.nodes(): if node_id_map[node.id] != -1: diff --git a/tests/tsutil.py b/tests/tsutil.py index 7738c54e..2423a43b 100644 --- a/tests/tsutil.py +++ b/tests/tsutil.py @@ -48,14 +48,13 @@ def add_default_schemas(ts): """ tables = ts.dump_tables() schema = tskit.MetadataSchema(tsinfer.permissive_json_schema()) - # Make sure we're not overwriting existing metadata. This will probably - # fail when msprime 1.0 comes along, but we can fix it then. 
assert len(tables.metadata) == 0 tables.metadata_schema = schema tables.metadata = {} tables.populations.metadata_schema = schema - assert len(tables.populations.metadata) == 0 - tables.populations.packset_metadata([b"{}"] * ts.num_populations) + # msprime 1.0 fills the population metadata, so put it back in here + for pop in ts.populations(): + tables.populations[pop.id] = pop tables.individuals.metadata_schema = schema assert len(tables.individuals.metadata) == 0 tables.individuals.packset_metadata([b"{}"] * ts.num_individuals) @@ -65,19 +64,36 @@ def add_default_schemas(ts): return tables.tree_sequence() -def get_example_ts(sample_size, sequence_length, mutation_rate=10, random_seed=100): - ts = msprime.simulate( +def get_example_ts( + sample_size, + sequence_length=10000, + mutation_rate=0.0005, + mutation_model=None, + discrete_genome=True, + random_seed=100, +): + ts = msprime.sim_ancestry( sample_size, - recombination_rate=1, - mutation_rate=mutation_rate, - length=sequence_length, + ploidy=1, + sequence_length=sequence_length, + recombination_rate=mutation_rate * 0.1, + discrete_genome=discrete_genome, random_seed=random_seed, ) + ts = msprime.sim_mutations( + ts, rate=mutation_rate, model=mutation_model, random_seed=random_seed + ) return add_default_schemas(ts) def get_example_individuals_ts_with_metadata( - n, ploidy, length, mutation_rate=1, *, skip_last=True + n, + ploidy=2, + sequence_length=10000, + mutation_rate=0.0002, + *, + discrete_genome=True, + skip_last=True, ): """ For testing only, create a ts with lots of arbitrary metadata attached to sites, @@ -88,18 +104,22 @@ def get_example_individuals_ts_with_metadata( For testing purposes, we can set ``skip_last`` to check what happens if we have some samples that are not associated with an individual in the tree sequence. """ - ts = msprime.simulate( - n * ploidy, - recombination_rate=1, - mutation_rate=mutation_rate, - length=length, + ts = msprime.sim_ancestry( + n, + ploidy=ploidy, + recombination_rate=mutation_rate * 0.1, + sequence_length=sequence_length, random_seed=100, + discrete_genome=discrete_genome, + ) + ts = msprime.sim_mutations( + ts, rate=mutation_rate, discrete_genome=discrete_genome, random_seed=100 ) ts = add_default_schemas(ts) tables = ts.dump_tables() tables.metadata = {f"a_{j}": j for j in range(n)} - tables.populations.clear() + tables.individuals.clear() for i in range(n): location = [i, i] individual_meta = {} @@ -107,9 +127,7 @@ def get_example_individuals_ts_with_metadata( if i % 2 == 0: # Add unicode metadata to every other individual: 8544+i = Roman numerals individual_meta = {"unicode id": chr(8544 + i)} - # TODO: flags should use np.iinfo(np.uint32).max. 
Change after solving issue - # https://github.com/tskit-dev/tskit/issues/1027 - individual_flags = np.random.randint(0, np.iinfo(np.int32).max) + individual_flags = np.random.randint(0, np.iinfo(np.uint32).max) # Also for populations: chr(127462) + chr(127462+i) give emoji flags pop_meta = {"utf": chr(127462) + chr(127462 + i)} tables.populations.add_row(metadata=pop_meta) # One pop for each individual @@ -145,22 +163,28 @@ def get_example_individuals_ts_with_metadata( return tables.tree_sequence() -def get_example_historical_sampled_ts(individual_times, ploidy=2, sequence_length=1): +def get_example_historical_sampled_ts( + individual_times, + ploidy=2, + sequence_length=10000, + mutation_rate=0.0002, +): samples = [ - msprime.Sample(population=0, time=t) + msprime.SampleSet(1, population=0, time=t, ploidy=ploidy) for t in individual_times - for _ in range(ploidy) ] - ts = msprime.simulate( + ts = msprime.sim_ancestry( samples=samples, - recombination_rate=1, - mutation_rate=10, - length=sequence_length, + ploidy=ploidy, + recombination_rate=mutation_rate * 0.1, + sequence_length=sequence_length, random_seed=100, ) + ts = msprime.sim_mutations(ts, rate=mutation_rate, random_seed=100) ts = add_default_schemas(ts) tables = ts.dump_tables() # Add individuals + tables.individuals.clear() nodes_individual = tables.nodes.individual individual_ids = [] for _ in individual_times: diff --git a/tsinfer/formats.py b/tsinfer/formats.py index 3df416ed..251d51b1 100644 --- a/tsinfer/formats.py +++ b/tsinfer/formats.py @@ -49,6 +49,8 @@ FORMAT_NAME_KEY = "format_name" FORMAT_VERSION_KEY = "format_version" FINALISED_KEY = "finalised" +AUTOGENERATED_POP_NAME_PREFIX = "pop_" +AUTOGENERATED_POP_DESCRIPTION = "name autogenerated by tsinfer" # We use the zstd compressor because it allows for compression of buffers # bigger than 2GB, which can occur in a larger instances. @@ -950,10 +952,11 @@ def __init__(self, sequence_length=0, **kwargs): dtype=object, object_codec=self._metadata_codec, ) - populations_group.attrs["metadata_schema"] = None + populations_group.attrs["metadata_schema"] = permissive_json_schema() self._populations_writer = BufferedItemWriter( {"metadata": metadata}, num_threads=self._num_flush_threads ) + self._population_names = set() individuals_group = self.data.create_group("individuals") individuals_group.attrs["metadata_schema"] = None @@ -1584,22 +1587,60 @@ def _alloc_site_writer(self): arrays, num_threads=self._num_flush_threads ) - def add_population(self, metadata=None): + def add_population(self, metadata=None, name=None, description=None): """ Adds a new :ref:`sec_inference_data_model_population` to this - :class:`.SampleData` and returns its ID. + :class:`.SampleData` and returns its ID. Metadata including (at a minimum) + a name and a description, will be associated with each population. All calls to this method must be made **before** individuals or sites are defined. :param dict metadata: A JSON encodable dict-like object containing - metadata that is to be associated with this population. + metadata that is to be associated with this population. Keys called + "name" or "description" may have their values overwritten using the + `name` and `description` parameters below. + :param str name: A unique name for this population (of length > 0). If given, + this will override any name defined in `metadata`. If not given, and + no "name" key exists in `metadata`, a suitable name will be autogenerated: + population 0 will be given the name `pop_0`, population 1 `pop_1`, etc. 
+        :param str description: A description for this population (default: None). If not
+            None, this will override any description provided in `metadata`. Otherwise
+            (i.e. no "description" key exists in `metadata` and `description` is None)
+            if the population name was autogenerated the description will be
+            set to a default string (given by `tsinfer.AUTOGENERATED_POP_DESCRIPTION`),
+            whereas if the name was not autogenerated the description will default
+            to the empty string.
+
         :return: The ID of the newly added population.
         :rtype: int
+        :raises: ValueError if the name has already been used for another population.
         """
         self._check_build_mode()
         if self._build_state != self.ADDING_POPULATIONS:
             raise ValueError("Cannot add populations after adding samples or sites")
+        if metadata is None:
+            metadata = {}
+
+        if name == "":
+            raise ValueError("Name cannot be blank")
+        if name is not None:
+            metadata["name"] = name
+        autogenerated_name = "name" not in metadata
+        if autogenerated_name:
+            metadata["name"] = AUTOGENERATED_POP_NAME_PREFIX + str(
+                len(self._population_names)
+            )
+        if metadata["name"] in self._population_names:
+            raise ValueError(f"Another population is already named {metadata['name']}")
+        self._population_names.add(metadata["name"])
+
+        if description is not None:
+            metadata["description"] = description
+        if "description" not in metadata:
+            metadata["description"] = (
+                AUTOGENERATED_POP_DESCRIPTION if autogenerated_name else ""
+            )
+
         return self._populations_writer.add(metadata=self._check_metadata(metadata))
 
     def add_individual(
@@ -1883,19 +1924,24 @@ def __insert_individuals(self, other, pop_id_map=None):
     # Read mode
     ####################################
 
-    def merge(self, other, **kwargs):
+    def merge(self, other, *, use_population_names=False, **kwargs):
         """
-        Returns a copy of this SampleData file merged with the specified
-        other SampleData file. Subsequent keyword arguments are passed
+        Returns a copy of this SampleData instance merged with the specified
+        other SampleData instance. Subsequent keyword arguments are passed
         to the SampleData constructor for the returned merged dataset.
         The datasets are merged by following process:
 
-        1. We add the populations from this dataset to the result, followed
-           by the populations from other. Population references from the two
-           datasets are updated accordingly.
+        1. We add the populations from this dataset to the result. If
+           ``use_population_names`` is False, all populations in ``other`` will
+           then be added as new, distinct populations (and named accordingly if
+           necessary, see note below). Otherwise populations with the same name
+           in the two datasets will be merged and only populations with novel
+           names will be added to the result. In either case, references to the
+           population IDs are adjusted as necessary.
         2. We add individual data from this dataset to the result, followed
-           by the individuals from the other dataset.
+           by the individuals from the other dataset, and adjust references to
+           individual IDs as necessary.
         3. We merge the variant data from the two datasets by comparing sites
            by their position. If two sites in the datasets have the same
           position we combine the genotype data. The alleles from this dataset
@@ -1910,9 +1956,24 @@ def merge(self, other, **kwargs):
         :param SampleData other: The other :class:`SampleData` instance to
            to merge.
+        :param bool use_population_names: If ``False`` (default) always treat
+            populations in ``other`` as distinct populations.
If ``True``, + populations in ``other`` that have the same metadata name are + treated as the same population. :return: A new SampleData instance which contains the merged data from the two datasets. :rtype: .SampleData + :raises: ValueError if ``other`` has a different sequence length. + :raises: ValueError if during auto-generation of names, a population in + `other` is allocated a new name which is the same as the name of a + population in the newly created dataset (see note). + + .. note:: + In order to ensure names are unique, if `use_population_names` is False, + and a population in `other` has the same name as one in the dataset being + generated, a new name will be auto-generated for the population, + in the same format as described in :meth:`add_population`. + """ self._check_finalised() other._check_finalised() @@ -1920,19 +1981,30 @@ def merge(self, other, **kwargs): raise ValueError("Sample data files must have the same sequence length") with SampleData(sequence_length=self.sequence_length, **kwargs) as result: # Keep the same population IDs from self. - for population in self.populations(): - result.add_population(population.metadata) - # TODO we could avoid duplicate populations here by keying on the - # metadata. It's slightly complicated by the case where the - # metadata is all empty, but we could fall through to just - # adding in all the populations as is, then. - other_pop_map = {-1: -1} - for population in other.populations(): - pid = result.add_population(population.metadata) - other_pop_map[population.id] = pid + pop_name_map = {} + num_populations = 0 + for p in self.populations(): + result.add_population(p.metadata) + num_populations += 1 + pop_name_map[p.metadata["name"]] = p.id + pop_id_map = {-1: -1} + for p in other.populations(): + if p.metadata["name"] in pop_name_map: + if use_population_names: + # Merge the populations; don't create a new one + pop_id_map[p.id] = pop_name_map[p.metadata["name"]] + continue + else: + # Name clash: keep the population but allocate a new name + p.metadata["name"] = AUTOGENERATED_POP_NAME_PREFIX + str( + num_populations + ) + pop_id_map[p.id] = result.add_population(p.metadata) + pop_name_map[p.metadata["name"]] = None # Add name to avoid later dups + num_populations += 1 result.__insert_individuals(self) - result.__insert_individuals(other, other_pop_map) + result.__insert_individuals(other, pop_id_map) for variant in merge_variants(self, other): result.add_site( diff --git a/tsinfer/inference.py b/tsinfer/inference.py index 7da92646..f652d18a 100644 --- a/tsinfer/inference.py +++ b/tsinfer/inference.py @@ -1436,6 +1436,13 @@ def get_ancestors_tree_sequence(self): child=child, ) + # Add the schema to the ancestors TS: this does not persist through to the + # final TS, but could be useful for adding populations associated with + # historical individuals + schema = self.sample_data.populations_metadata_schema + if schema is not None: + tables.populations.metadata_schema = tskit.MetadataSchema(schema) + self.convert_inference_mutations(tables) logger.debug("Sorting ancestors tree sequence")
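Usage sketch (not part of the patch) for the add_population() changes above, based on the new docstring and on test_population_metadata. It assumes this branch of tsinfer is installed; the metadata values are invented for illustration.

    import tsinfer

    with tsinfer.SampleData(sequence_length=10) as sd:
        # Explicit name/description arguments override any "name"/"description"
        # keys in the metadata dict.
        sd.add_population(
            metadata={"super_pop": "EUR"}, name="CEU", description="Utah residents"
        )
        # No name argument and no "name" key: a pop_<N> name is autogenerated and
        # the description falls back to tsinfer.AUTOGENERATED_POP_DESCRIPTION.
        sd.add_population(metadata={"super_pop": "AFR"})
        sd.add_individual(population=0)
        sd.add_individual(population=1)
        sd.add_site(0, [0, 1])

    meta1 = sd.populations_metadata[1]
    assert sd.populations_metadata[0]["name"] == "CEU"
    assert meta1["name"] == tsinfer.AUTOGENERATED_POP_NAME_PREFIX + "1"
    assert meta1["description"] == tsinfer.AUTOGENERATED_POP_DESCRIPTION
    # Re-using "CEU" for a third population would raise ValueError.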
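A similar sketch for the new merge(use_population_names=...) flag, showing the two population-handling modes described in the docstring. tiny_sample_data() and its arguments are invented for illustration; the expected counts follow the docstring and tests above.

    import tsinfer

    def tiny_sample_data(pop_name, genotype):
        with tsinfer.SampleData(sequence_length=10) as sd:
            sd.add_population(name=pop_name)
            sd.add_individual(population=0)
            sd.add_site(position=5, genotypes=[genotype], alleles=["A", "T"])
        return sd

    sd1 = tiny_sample_data("popA", 0)
    sd2 = tiny_sample_data("popA", 1)

    # Default: sd2's population is kept as a distinct population; its clashing
    # name is replaced with an autogenerated pop_<N> name in the merged file.
    assert sd1.merge(sd2).num_populations == 2
    # With use_population_names=True, populations sharing a name are combined.
    assert sd1.merge(sd2, use_population_names=True).num_populations == 1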
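The tests/tsutil.py helpers above replace msprime.simulate() with the msprime 1.x two-step API (sim_ancestry followed by sim_mutations). A minimal standalone version of that pattern, with arbitrary parameter values rather than the ones used in the helpers:

    import msprime

    anc = msprime.sim_ancestry(
        10,  # 10 individuals; ploidy defaults to 2
        sequence_length=10_000,
        recombination_rate=5e-5,
        random_seed=100,
    )
    ts = msprime.sim_mutations(anc, rate=5e-4, random_seed=100)
    print(ts.num_samples, ts.num_sites)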