Skip to content

Commit 49bf87b

Browse files
Add is_reference option to AncestralState config
When is_reference is true, the REF allele (variant_allele[:, 0]) from the VCZ store is used as the ancestral allele instead of reading a named field. Useful for simulated data where REF equals the ancestral allele. Exactly one of field or is_reference must be set; both default to None. The genotype roundtrip tests now use is_reference=True throughout, since all test inputs originate from tree sequences where REF is ancestral.
1 parent 76289b0 commit 49bf87b

File tree

7 files changed

+79
-21
lines changed

7 files changed

+79
-21
lines changed

docs/config.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ Specifies where to read the ancestral allele for each variant position.
3333
| Field | Type | Default | Description |
3434
|-------|------|---------|-------------|
3535
| `path` | string | (required) | Path to VCZ containing ancestral alleles |
36-
| `field` | string | (required) | Array name in the store (e.g. `"variant_AA"`) |
36+
| `field` | string | — | Array name in the store (e.g. `"variant_AA"`). Required unless `is_reference` is set. |
37+
| `is_reference` | bool | `false` | Use the REF allele (`variant_allele[:, 0]`) as the ancestral state. Useful for simulations. `field` must not be set when this is `true`. |
3738

3839

3940
## `[[ancestors]]`

docs/quickstart.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ vcf2zarr convert mydata.vcf.gz mydata.vcz
2828
Each site used for inference requires a known **ancestral allele**. If your VCF
2929
has an `AA` INFO field, `vcf2zarr` stores it as `variant_AA` in the `.vcz`
3030
store and you can reference it directly in the config. Alternatively, ancestral
31-
alleles can come from a separate VCZ store. See the
32-
{ref}`config reference <sec_config_reference>` for details.
31+
alleles can come from a separate VCZ store. For simulated data where the REF
32+
allele is the ancestral allele, set `is_reference = true` instead of specifying
33+
a field. See the {ref}`config reference <sec_config_reference>` for details.
3334

3435

3536
## Writing the config

example_config.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ path = "data/1kgp_chr20.vcz"
5959
path = "data/homo_sapiens-chr20.vcz"
6060
field = "variant_AA"
6161

62+
# For simulated data where REF is the ancestral allele, use:
63+
# [ancestral_state]
64+
# path = "data/simulated.vcz"
65+
# is_reference = true
66+
6267

6368
# ============================================================================
6469
# Ancestors

tests/test_genotype_roundtrip.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@
3939
# ---------------------------------------------------------------------------
4040

4141

42-
def _anc_state(store):
43-
return config.AncestralState(path=store, field="variant_ancestral_allele")
44-
45-
4642
def _run_pipeline(sample_store):
47-
"""Build config, run full pipeline, return output tree sequence."""
43+
"""Build config, run full pipeline, return output tree sequence.
44+
45+
Uses ``is_reference=True`` so the REF allele (variant_allele[:, 0])
46+
is treated as the ancestral allele — the natural choice when input
47+
data originates from a tree sequence.
48+
"""
4849
src = config.Source(path=sample_store, name="test")
4950
anc_src = config.Source(path=None, name="ancestors", sample_time="sample_time")
5051
cfg = config.Config(
@@ -62,7 +63,7 @@ def _run_pipeline(sample_store):
6263
output="output.trees",
6364
),
6465
post_process=config.PostProcessConfig(),
65-
ancestral_state=_anc_state(sample_store),
66+
ancestral_state=config.AncestralState(path=sample_store, is_reference=True),
6667
)
6768
return pipeline.run(cfg)
6869

@@ -103,7 +104,7 @@ def _run_pipeline_with_augment(sample_store, augment_store=None, ann_store=None)
103104
),
104105
post_process=config.PostProcessConfig(),
105106
augment_sites=config.AugmentSitesConfig(sources=["augment"]),
106-
ancestral_state=_anc_state(ann_store),
107+
ancestral_state=config.AncestralState(path=ann_store, is_reference=True),
107108
)
108109
return pipeline.run(cfg)
109110

tests/test_pipeline.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,35 @@ def test_hand_constructed(self):
355355
assert out_ts.num_sites == 2
356356

357357

358+
# ---------------------------------------------------------------------------
359+
# TestIsReferenceConfig
360+
# ---------------------------------------------------------------------------
361+
362+
363+
class TestIsReferenceConfig:
    """Construction-time validation of ``config.AncestralState``."""

    def test_is_reference_with_field_raises(self):
        """Naming a field while also requesting the REF allele is rejected."""
        kwargs = dict(path="x.vcz", field="variant_AA", is_reference=True)
        with pytest.raises(ValueError, match="field must not be set"):
            config.AncestralState(**kwargs)

    def test_neither_field_nor_is_reference_raises(self):
        """Giving only a path, with no allele source, is rejected."""
        kwargs = dict(path="x.vcz")
        with pytest.raises(ValueError, match="requires either"):
            config.AncestralState(**kwargs)

    def test_valid_field(self):
        """A named field on its own is accepted; is_reference stays unset."""
        state = config.AncestralState(path="x.vcz", field="variant_AA")
        assert state.is_reference is None
        assert state.field == "variant_AA"

    def test_valid_is_reference(self):
        """is_reference=True on its own is accepted; field stays unset."""
        state = config.AncestralState(path="x.vcz", is_reference=True)
        assert state.field is None
        assert state.is_reference is True
385+
386+
358387
# ---------------------------------------------------------------------------
359388
# TestNodeMetadata
360389
# ---------------------------------------------------------------------------

tsinfer/config.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,28 @@ def __post_init__(self):
113113

114114
@dataclasses.dataclass
115115
class AncestralState:
116-
"""Specifies where to read the ancestral allele for each variant position."""
116+
"""Specifies where to read the ancestral allele for each variant position.
117+
118+
Exactly one of *field* or *is_reference* must be set.
119+
120+
If *is_reference* is ``True``, the REF allele (``variant_allele[:, 0]``)
121+
from the store at *path* is used as the ancestral allele. Otherwise
122+
*field* names the array to read.
123+
"""
117124

118125
path: str | pathlib.Path
119-
field: str
126+
field: str | None = None
127+
is_reference: bool | None = None
128+
129+
def __post_init__(self):
130+
if self.is_reference is None and self.field is None:
131+
raise ValueError(
132+
"[ancestral_state] requires either 'field' or 'is_reference = true'"
133+
)
134+
if self.is_reference is True and self.field is not None:
135+
raise ValueError(
136+
"[ancestral_state] field must not be set when is_reference is true"
137+
)
120138

121139

122140
@dataclasses.dataclass
@@ -356,7 +374,7 @@ def from_toml(cls, path: str | pathlib.Path) -> Config:
356374
"sample_time",
357375
}
358376

359-
_KNOWN_ANCESTRAL_STATE_KEYS = {"path", "field"}
377+
_KNOWN_ANCESTRAL_STATE_KEYS = {"path", "field", "is_reference"}
360378

361379
_KNOWN_ANCESTORS_KEYS = {
362380
"name",
@@ -429,13 +447,13 @@ def _parse_ancestral_state(raw: dict) -> AncestralState:
429447
if entry is None:
430448
raise ValueError("Config must contain an [ancestral_state] section")
431449
_check_unknown_keys("ancestral_state", entry, _KNOWN_ANCESTRAL_STATE_KEYS)
432-
try:
433-
return AncestralState(
434-
path=_resolve_path(entry["path"]),
435-
field=entry["field"],
436-
)
437-
except KeyError as e:
438-
raise ValueError(f"[ancestral_state] missing required key: {e}") from e
450+
if "path" not in entry:
451+
raise ValueError("[ancestral_state] missing required key: 'path'")
452+
return AncestralState(
453+
path=_resolve_path(entry["path"]),
454+
field=entry.get("field"),
455+
is_reference=entry.get("is_reference"),
456+
)
439457

440458

441459
def _parse_one_ancestor(entry: dict) -> AncestorsConfig:

tsinfer/vcz.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2368,7 +2368,10 @@ def __init__(
23682368
# --- Ancestral state ---
23692369
ann_store = open_store(ancestral_state.path)
23702370
ann_positions = np.asarray(ann_store["variant_position"][:], dtype=np.int32)
2371-
ann_values = np.asarray(ann_store[ancestral_state.field][:])
2371+
if ancestral_state.is_reference:
2372+
ann_values = np.asarray(ann_store["variant_allele"][:, 0])
2373+
else:
2374+
ann_values = np.asarray(ann_store[ancestral_state.field][:])
23722375

23732376
# --- Unified site set (all numpy) ---
23742377
valid_per_source = []

0 commit comments

Comments
 (0)