Skip to content

Commit

Permalink
hyperparameter estimation implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
kcajj committed Nov 25, 2024
1 parent ecfa585 commit f1879f9
Show file tree
Hide file tree
Showing 4 changed files with 221 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
data/
__pycache__/
results/
results*
log/
.snakemake/
tmp/
81 changes: 59 additions & 22 deletions rules/plots.smk
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os

plot_config = config["plots"]
input_replicates = ""
for replicate in HMM["replicates"]:
input_replicates += replicate + ","
input_timesteps = ""
for timestep in HMM["timesteps"]:
input_timesteps += timestep + ","
Expand All @@ -14,15 +17,15 @@ for replicate in HMM["replicates"]:

rule plot_coverage_dynamics:
input:
hybrid_ref=rules.hybrid_ref.output.hybrid_ref,
coverage_folder=directory(out_fld + "/coverage_arrays/{replicate}/"),
wait=rules.HMM_all.output.finish,
hybrid_ref = rules.hybrid_ref.output.hybrid_ref,
coverage_folder = directory(out_fld + "/coverage_arrays/{replicate}/"),
wait = rules.HMM_all.output.finish,
output:
plots=out_fld + "/plots/coverage_dynamics/coverage_{replicate}.pdf",
plots = out_fld + "/plots/coverage_dynamics/coverage_{replicate}.pdf",
params:
timesteps=input_timesteps,
references=input_references,
coverage_threshold=plot_config["coverage_threshold"],
timesteps = input_timesteps,
references = input_references,
coverage_threshold = plot_config["coverage_threshold"],
conda:
"../conda_envs/sci_py.yml"
shell:
Expand All @@ -39,15 +42,15 @@ rule plot_coverage_dynamics:

rule plot_recombination_dynamics:
input:
recombination_folder=directory(out_fld + "/genomewide_recombination/{replicate}/"),
coverage_folder=directory(out_fld + "/coverage_arrays/{replicate}/"),
wait=rules.HMM_all.output.finish,
recombination_folder = directory(out_fld + "/genomewide_recombination/{replicate}/"),
coverage_folder = directory(out_fld + "/coverage_arrays/{replicate}/"),
wait = rules.HMM_all.output.finish,
output:
plots=out_fld + "/plots/recombination_dynamics/recombination_{replicate}.pdf",
plots = out_fld + "/plots/recombination_dynamics/recombination_{replicate}.pdf",
params:
timesteps=input_timesteps,
references=input_references,
coverage_threshold=plot_config["coverage_threshold"],
timesteps = input_timesteps,
references = input_references,
coverage_threshold = plot_config["coverage_threshold"],
conda:
"../conda_envs/sci_py.yml"
shell:
Expand All @@ -64,16 +67,16 @@ rule plot_recombination_dynamics:

rule unique_plot:
input:
hybrid_ref=rules.hybrid_ref.output.hybrid_ref,
recombination_folder=directory(out_fld + "/genomewide_recombination/{replicate}/"),
coverage_folder=directory(out_fld + "/coverage_arrays/{replicate}/"),
wait=rules.HMM_all.output.finish,
hybrid_ref = rules.hybrid_ref.output.hybrid_ref,
recombination_folder = directory(out_fld + "/genomewide_recombination/{replicate}/"),
coverage_folder = directory(out_fld + "/coverage_arrays/{replicate}/"),
wait = rules.HMM_all.output.finish,
output:
plots=out_fld + "/plots/unique_plots/unique_{replicate}.pdf",
plots = out_fld + "/plots/unique_plots/unique_{replicate}.pdf",
params:
timesteps=input_timesteps,
references=input_references,
coverage_threshold=plot_config["coverage_threshold"],
timesteps = input_timesteps,
references = input_references,
coverage_threshold = plot_config["coverage_threshold"],
conda:
"../conda_envs/sci_py.yml"
shell:
Expand All @@ -88,11 +91,45 @@ rule unique_plot:
--out {output.plots}
"""

HMM_config = config["HMM_parameters"]

# Scores a grid of candidate recombination (transition) probabilities by total
# Viterbi log likelihood over a subsample of reads, and plots likelihood vs.
# probability so the best-fitting value can be read off the curve.
rule optimize_recombination_parameter:
    input:
        msa = rules.msa.output.msa,
        evidences_folder = directory(out_fld + "/evidence_arrays/"),
        # Barrier input: guarantees every evidence array has been produced
        # before the optimisation script scans the folder.
        wait=rules.HMM_all.output.finish,
    output:
        plot = out_fld + "/plots/parameter_optimization.pdf"
    conda:
        '../conda_envs/sci_py.yml'
    params:
        # replicates/timesteps are comma-joined with a trailing comma;
        # the script strips the resulting empty tail after split(",").
        replicates = input_replicates,
        timesteps = input_timesteps,
        cores = HMM_config["cores"],
        # "A,B" -> one-row matrix parsed by the script's build_matrix().
        initial_probability = HMM_config["initial_probability"]["A"]+","+HMM_config["initial_probability"]["B"],
        # Comma-separated candidate values taken from the run config.
        transition_probability = config["optimization_recombination_parameter"]["values"],
        # "a,b,c/d,e,f" -> two-row matrix (states A and B), same format.
        emission_probability = HMM_config["emission_probability"]["A"][0]+","+HMM_config["emission_probability"]["A"][1]+","+HMM_config["emission_probability"]["A"][2]+"/"+HMM_config["emission_probability"]["B"][0]+","+HMM_config["emission_probability"]["B"][1]+","+HMM_config["emission_probability"]["B"][2],
        # Cap on the number of reads scored per sample, to bound runtime.
        subsample = config["optimization_recombination_parameter"]["subsample"],
    shell:
        """
        python scripts/optimize_recombination_parameter.py \
            --replicates {params.replicates} \
            --timesteps {params.timesteps} \
            --evidences {input.evidences_folder} \
            --out {output.plot} \
            --cores {params.cores} \
            --initial_p {params.initial_probability} \
            --transition_p {params.transition_probability} \
            --emission_p {params.emission_probability} \
            --subsample {params.subsample}
        """

rule plot_all:
input:
coverage_dynamics=expand(rules.plot_coverage_dynamics.output.plots, replicate=HMM["replicates"]),
recombination_dynamics=expand(rules.plot_recombination_dynamics.output.plots, replicate=HMM["replicates"]),
unique_plots=expand(rules.unique_plot.output.plots, replicate=HMM["replicates"]),
parameter_optimization=rules.optimize_recombination_parameter.output.plot,
finish=rules.HMM_all.output.finish,
shell:
"""
Expand Down
8 changes: 8 additions & 0 deletions run_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ run_config:
- 'D6'
- 'D8'
- 'D10'

alignments:
length_threshold: 5000

HMM_parameters:
cores: 20
initial_probability:
Expand All @@ -39,5 +41,11 @@ HMM_parameters:
0: "0.967"
1: "0.003"
2: "0.03"

optimization_recombination_parameter:
flag: true
values: '0.00001,0.00002,0.00003,0.00004,0.00005,0.00006,0.00008,0.0001' # comma-separated values
subsample: 10000

plots:
coverage_threshold: 0
153 changes: 153 additions & 0 deletions scripts/optimize_recombination_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import numpy as np
from viterbi import viterbi_algorithm
import csv
import sys
from multiprocessing import Pool
from itertools import repeat
from array_compression import compress_array, decompress_array, retrive_compressed_array_from_str
from collections import defaultdict
from matplotlib import pyplot as plt


def build_matrix(input):
    """Parse a matrix literal such as ``"a,b,c/d,e,f"`` into a numpy array.

    Rows are separated by ``/`` and entries within a row by ``,``; every
    entry is converted to ``float``.
    """
    return np.array(
        [[float(entry) for entry in row.split(",")] for row in input.split("/")]
    )


def get_evidence_arrays(evidences_file):
    """Load evidence arrays plus their mapping metadata from a TSV file.

    Each row holds: ancestral reference name, mapping start, mapping end,
    and a compressed evidence array (decompressed here). Returns five
    parallel values: names, arrays, starts, ends, and the row count.
    """
    # Evidence arrays are serialised as very long strings; lift csv's
    # default field-size cap so they can be read back.
    csv.field_size_limit(sys.maxsize)

    ancestral_names, evidence_arrays = [], []
    mapping_starts, mapping_ends = [], []

    with open(evidences_file) as handle:
        for record in csv.reader(handle, delimiter="\t"):
            ancestral_names.append(record[0])
            evidence_arrays.append(
                decompress_array(retrive_compressed_array_from_str(record[3]))
            )
            mapping_starts.append(int(record[1]))
            mapping_ends.append(int(record[2]))

    return (
        ancestral_names,
        evidence_arrays,
        mapping_starts,
        mapping_ends,
        len(ancestral_names),
    )


def write_prediction_arrays(output_path, results, read_names, mapping_starts, mapping_ends):
    """Write one TSV row per read: name, mapping span, log likelihood, and
    the compressed HMM prediction array.

    ``results`` holds ``(hmm_prediction, log_lik)`` pairs parallel to
    ``read_names`` / ``mapping_starts`` / ``mapping_ends``.
    """
    # Hoisted out of the loop: this call mutates numpy's *global* print
    # options and is loop-invariant, so setting it once per call (instead of
    # once per row, as before) is sufficient.
    np.set_printoptions(threshold=np.inf, linewidth=np.inf)

    with open(output_path, "w", newline="") as tsvfile:
        writer = csv.writer(tsvfile, delimiter="\t", lineterminator="\n")
        for result, read_name, mapping_start, mapping_end in zip(
            results, read_names, mapping_starts, mapping_ends
        ):
            hmm_prediction = result[0]
            log_lik = result[1]
            compressed_hmm_prediction = compress_array(hmm_prediction)
            writer.writerow(
                [read_name, mapping_start, mapping_end, log_lik, compressed_hmm_prediction]
            )


if __name__ == "__main__":

    import argparse

    parser = argparse.ArgumentParser(
        description="Makes a prediction on each evidence array and summarises the information in a single recombination array",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--replicates", help="replicates names")
    parser.add_argument("--timesteps", help="list of timesteps")
    parser.add_argument("--evidences", help="path of the folder containing the evidence arrays")
    parser.add_argument("--out", help="output path of the plot with the optimization")
    parser.add_argument("--cores", help="number of cores to use", type=int)
    parser.add_argument("--initial_p", help="initial probabilities of the HMM states")
    parser.add_argument("--transition_p", help="transition probabilities of the HMM states")
    parser.add_argument("--emission_p", help="emission probabilities of the HMM states")
    parser.add_argument("--subsample", help="number of reads to subsample", type=int)

    args = parser.parse_args()
    # The Snakemake params append "," after every item, so split(",") yields
    # a trailing empty string -- [:-1] drops it.
    replicates = args.replicates.split(",")[:-1]
    timesteps = args.timesteps.split(",")[:-1]
    evidences_folder = args.evidences
    output_path = args.out
    cores = args.cores
    initial_probability = args.initial_p
    # Candidate recombination (state-switch) probabilities to score.
    transition_probabilities = [float(prob) for prob in args.transition_p.split(",")]
    emission_probability = args.emission_p
    subsample = args.subsample

    # "a,b/c,d"-style strings -> 2-D numpy matrices (see build_matrix above).
    initial_probability_matrix = build_matrix(initial_probability)
    emission_probability_matrix = build_matrix(emission_probability)
    # Sorted in place so the insertion order of log_liks keys matches the
    # x-axis values passed to plt.plot at the end.
    transition_probabilities.sort()

    log_liks = defaultdict(dict)  # log likelihoods saved for each clone and population
    """
    keys: probability of recombination
    values: keys: population_clone
            values: log likelihood
    """
    for replicate in replicates:
        for timestep in timesteps:

            evidences_file = f"{evidences_folder}/{replicate}/{timestep}.tsv"

            read_names, evidence_arrays, mapping_starts, mapping_ends, c_reads = get_evidence_arrays(evidences_file)

            # Subsample reads to bound runtime. NOTE(review): no RNG seed is
            # set, so the subsample (and hence the plotted curve) is not
            # reproducible across runs -- confirm whether that is intended.
            if subsample<c_reads:
                idx = np.random.choice(c_reads, subsample, replace=False)
                read_names = [read_names[i] for i in idx]
                evidence_arrays = [evidence_arrays[i] for i in idx]
                mapping_starts = [mapping_starts[i] for i in idx]
                mapping_ends = [mapping_ends[i] for i in idx]

            for prob in transition_probabilities:

                # Symmetric two-state transition matrix with switch
                # probability `prob` in both directions.
                transition_probability_matrix = np.array([[1 - prob, prob], [prob, 1 - prob]])

                # Run Viterbi on every read in parallel; each result is a
                # (state_path, log_likelihood) pair.
                with Pool(cores) as p:
                    results = p.starmap(
                        viterbi_algorithm,
                        zip(
                            evidence_arrays,
                            repeat(transition_probability_matrix),
                            repeat(emission_probability_matrix),
                            repeat(initial_probability_matrix),
                        ),
                    )

                # Total log likelihood across all (sub)sampled reads for this
                # (replicate, timestep, prob) combination.
                tot_log_lik = 0
                for i in range(len(results)):
                    hmm_prediction = results[i][0]  # state path; unused here
                    log_lik = results[i][1]
                    tot_log_lik += log_lik

                log_liks[prob][f"{replicate}_{timestep}"] = tot_log_lik

    # Mean total log likelihood across samples for each candidate probability.
    # defaultdict preserves insertion order, which follows the sorted
    # transition_probabilities list, so the two plot axes stay aligned.
    mean_log_liks = []
    for prob, samples in log_liks.items():
        prob_mean = []
        for sample, log_lik in samples.items():
            prob_mean.append(log_lik)
        mean_log_liks.append(np.mean(prob_mean))

    plt.plot(transition_probabilities, mean_log_liks)
    plt.xlabel("Recombination probability")
    plt.ylabel("log likelihood")
    plt.title("Optimization of the recombination parameter by log likelihood maximisation")
    plt.savefig(output_path, bbox_inches="tight")

0 comments on commit f1879f9

Please sign in to comment.