Skip to content

Commit 60c22a9

Browse files
committed
Raise default threads for BAM parsing 8->32, BLAS 8->16
It was originally capped at 8 under the belief that reading 8 BAM files in parallel would saturate the disk, so there would be no benefit to going higher. However, my laptop can read at 4 GB/s, and decompress BAM files perhaps 40 times slower, so the process is CPU-bottlenecked even with 32 threads. This change is significant, because users have reported slow BAM file parsing. However, it will potentially quadruple the memory usage of the BAM parsing step. Will be benchmarked before merging. The BLAS change is simply because I think 8 CPUs is too conservative.
1 parent f37b5fd commit 60c22a9

6 files changed

Lines changed: 27 additions & 25 deletions

File tree

vamb/__main__.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
import pandas as pd
2323

2424
_ncpu = os.cpu_count()
25-
DEFAULT_THREADS = 8 if _ncpu is None else min(_ncpu, 8)
25+
DEFAULT_BLAS_THREADS = 16 if _ncpu is None else min(_ncpu, 16)
2626

2727
# These MUST be set before importing numpy
2828
# I know this is a shitty hack, see https://github.com/numpy/numpy/issues/11826
29-
os.environ["MKL_NUM_THREADS"] = str(DEFAULT_THREADS)
30-
os.environ["NUMEXPR_NUM_THREADS"] = str(DEFAULT_THREADS)
31-
os.environ["OMP_NUM_THREADS"] = str(DEFAULT_THREADS)
29+
os.environ["MKL_NUM_THREADS"] = str(DEFAULT_BLAS_THREADS)
30+
os.environ["NUMEXPR_NUM_THREADS"] = str(DEFAULT_BLAS_THREADS)
31+
os.environ["OMP_NUM_THREADS"] = str(DEFAULT_BLAS_THREADS)
3232

3333
# Append vamb to sys.path to allow vamb import even if vamb was not installed
3434
# using pip
@@ -771,9 +771,11 @@ def cluster_and_write_files(
771771
print(
772772
str(i + 1),
773773
None if cluster.radius is None else round(cluster.radius, 3),
774-
None
775-
if cluster.observed_pvr is None
776-
else round(cluster.observed_pvr, 2),
774+
(
775+
None
776+
if cluster.observed_pvr is None
777+
else round(cluster.observed_pvr, 2)
778+
),
777779
cluster.kind_str,
778780
sum(sequence_lens[i] for i in cluster.members),
779781
len(cluster.members),
@@ -1686,9 +1688,11 @@ def add_input_output_arguments(subparser):
16861688
dest="nthreads",
16871689
metavar="",
16881690
type=int,
1689-
default=DEFAULT_THREADS,
1691+
default=vamb.parsebam.DEFAULT_BAM_THREADS,
16901692
help=(
1691-
"number of threads to use " "[min(" + str(DEFAULT_THREADS) + ", nbamfiles)]"
1693+
"number of threads to read BAM files [min("
1694+
+ str(vamb.parsebam.DEFAULT_BAM_THREADS)
1695+
+ ", nbamfiles)]"
16921696
),
16931697
)
16941698
inputos.add_argument(

vamb/aamb_encode.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Adversarial autoencoders (AAE) for metagenomics binning, this files contains the implementation of the AAE"""
22

3-
43
import numpy as np
54
from math import log, isfinite
65
import time

vamb/parsebam.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
from typing import Optional, TypeVar, Union, IO, Sequence, Iterable
1515
from pathlib import Path
1616
import shutil
17-
18-
_ncpu = _os.cpu_count()
19-
DEFAULT_THREADS = 8 if _ncpu is None else _ncpu
17+
import os
2018

2119
A = TypeVar("A", bound="Abundance")
2220

21+
_ncpu = os.cpu_count()
22+
DEFAULT_BAM_THREADS = 32 if _ncpu is None else min(_ncpu, 32)
23+
2324

2425
class Abundance:
2526
"Object representing contig abundance. Contains a matrix and refhash."
@@ -115,10 +116,10 @@ def from_files(
115116

116117
chunksize = min(nthreads, len(paths))
117118

118-
# We cap it to 16 threads, max. This will prevent pycoverm from consuming a huge amount
119+
# We cap it to DEFAULT_BAM_THREADS threads, max. This will prevent pycoverm from consuming a huge amount
119120
# of memory if given a crapload of threads, and most programs will probably be IO bound
120-
# when reading 16 files at a time.
121-
chunksize = min(chunksize, 16)
121+
# when reading DEFAULT_BAM_THREADS files at a time.
122+
chunksize = min(chunksize, DEFAULT_BAM_THREADS)
122123

123124
# If it can be done in memory, do so
124125
if chunksize >= len(paths):
@@ -134,7 +135,7 @@ def from_files(
134135
else:
135136
if cache_directory is None:
136137
raise ValueError(
137-
"If min(16, nthreads) < len(paths), cache_directory must not be None"
138+
"If min(DEFAULT_BAM_THREADS, nthreads) < len(paths), cache_directory must not be None"
138139
)
139140
return cls.chunkwise_loading(
140141
paths,

vamb/semisupervised_encode.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Semisupervised multimodal VAEs for metagenomics binning, this files contains the implementation of the VAEVAE for MMSEQ predictions"""
22

3-
43
__cmd_doc__ = """Encode depths and TNF using a VAE to latent representation"""
54

65
import numpy as _np

vamb/taxvamb_encode.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Hierarchical loss for the labels suggested in https://arxiv.org/abs/2210.10929"""
22

3-
43
__cmd_doc__ = """Hierarchical loss for the labels"""
54

65

workflow_avamb/src/rip_bins.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ def remove_meaningless_edges_from_pairs(
183183
contig_length,
184184
)
185185
print("Cluster ripped because of a meaningless edge ", cluster_updated)
186-
clusters_changed_but_not_intersecting_contigs[
187-
cluster_updated
188-
] = cluster_contigs[cluster_updated]
186+
clusters_changed_but_not_intersecting_contigs[cluster_updated] = (
187+
cluster_contigs[cluster_updated]
188+
)
189189

190190
components: list[set[str]] = list()
191191
for component in nx.connected_components(graph_clusters):
@@ -295,9 +295,9 @@ def make_all_components_pair(
295295
contig_length,
296296
)
297297
print("Cluster ripped because of a pairing component ", cluster_updated)
298-
clusters_changed_but_not_intersecting_contigs[
299-
cluster_updated
300-
] = cluster_contigs[cluster_updated]
298+
clusters_changed_but_not_intersecting_contigs[cluster_updated] = (
299+
cluster_contigs[cluster_updated]
300+
)
301301
component_len = max(
302302
[
303303
len(nx.node_connected_component(graph_clusters, node_i))

0 commit comments

Comments
 (0)