Skip to content
Open
173 changes: 157 additions & 16 deletions vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ class GeneralOptions:
@classmethod
def from_args(cls, args: argparse.Namespace):
return cls(
typeasserted(args.outdir, Path),
typeasserted(Path(args.outdir), Path),
typeasserted(args.nthreads, int),
typeasserted(args.seed, int),
typeasserted(args.cuda, bool),
Expand Down Expand Up @@ -810,6 +810,64 @@ def __init__(
self.output = output


class TrainingCommonOptions:
def __init__(
self, general: GeneralOptions, comp: CompositionPath, abundance: AbundancePath
):
self.general = general
self.comp = comp
self.abundance = abundance


class PartialTrainingOptions:
def __init__(
self,
general: GeneralOptions,
common: TrainingCommonOptions,
comp: CompositionPath,
abundance: AbundancePath,
min_contig_length: MinContigLength,
vae: VAEOptions,
outdir: Path,
):
self.general = general
self.common = common
self.comp = comp
self.abundance = abundance
self.min_contig_length = min_contig_length
self.vae = vae
self.outdir = outdir

@classmethod
def from_args(cls, args: argparse.Namespace):
general = GeneralOptions.from_args(args)
comp = CompositionPath(Path(args.composition_file))
abundance_path = Path(args.abundance_file)
min_contig_length = MinContigLength.from_args(args)
basic = BasicTrainingOptions.from_args_vae(args)
vae = VAEOptions.from_args(basic, args)
outdir = Path(args.outdir)

abundance = AbundanceOptions(
bampaths=None,
bamdir=None,
abundance_tsv=None,
abundancepath=abundance_path,
min_alignment_id=0.0,
min_contig_length=min_contig_length,
refcheck=False,
)
common = TrainingCommonOptions(general, comp, abundance)
return cls(general, common, comp, abundance, min_contig_length, vae, outdir)

def validate_comp_is_npz(self) -> Path:
if not isinstance(self.comp, CompositionPath):
raise TypeError(
"Training-only mode requires a CompositionPath (precomputed .npz)"
)
return self.comp


class BinDefaultOptions:
@classmethod
def from_args(cls, args: argparse.Namespace):
Expand Down Expand Up @@ -1088,6 +1146,7 @@ def load_composition_and_abundance(
vamb_options.out_dir.path,
binsplitter,
)

abundance = calc_abundance(
abundance_options,
vamb_options.out_dir.path,
Expand Down Expand Up @@ -1405,13 +1464,9 @@ def run_partial_abundance(opt: PartialAbundanceOptions):
)


def run_bin_default(opt: BinDefaultOptions):
composition, abundance = load_composition_and_abundance(
vamb_options=opt.common.general,
comp_options=opt.common.comp,
abundance_options=opt.common.abundance,
binsplitter=opt.common.output.binsplitter,
)
def run_train_vae(
opt: BinDefaultOptions, composition: CompositionPath, abundance: AbundancePath
):
data_loader = vamb.encode.make_dataloader(
abundance.matrix,
composition.matrix,
Expand All @@ -1430,6 +1485,14 @@ def run_bin_default(opt: BinDefaultOptions):
del composition, abundance
assert comp_metadata.nseqs == len(latent)

return latent


def run_cluster_and_write_files(
latent, opt: BinDefaultOptions, composition: CompositionPath
):
comp_metadata = composition.metadata
assert comp_metadata.nseqs == len(latent)
cluster_and_write_files(
opt.common.clustering,
opt.common.output.binsplitter,
Expand All @@ -1442,7 +1505,21 @@ def run_bin_default(opt: BinDefaultOptions):
FastaOutput.try_from_common(opt.common),
None,
)
del latent


def load_train_bin(opt: BinDefaultOptions, partial_mode: str = "default", latent=None):
composition, abundance = load_composition_and_abundance(
vamb_options=opt.common.general,
comp_options=opt.common.comp,
abundance_options=opt.common.abundance,
binsplitter=opt.common.output.binsplitter,
)
if partial_mode == "default" or partial_mode == "train":
latent = run_train_vae(opt, composition, abundance)

if partial_mode == "default" or partial_mode == "cluster":
run_cluster_and_write_files(latent, opt, composition)
del latent


def run_bin_aae(opt: BinAvambOptions):
Expand Down Expand Up @@ -2125,7 +2202,12 @@ def add_vae_arguments(subparser: argparse.ArgumentParser):
trainos = subparser.add_argument_group(title="Training options", description=None)

trainos.add_argument(
"-e", dest="nepochs", metavar="", type=int, default=300, help=argparse.SUPPRESS
"-e",
dest="nepochs",
metavar="",
type=int,
default=70,
help=argparse.SUPPRESS,
)
trainos.add_argument(
"-t",
Expand Down Expand Up @@ -2206,6 +2288,20 @@ def add_predictor_arguments(subparser: argparse.ArgumentParser):
return subparser


def add_cluster_only_args(subparser: argparse.ArgumentParser):
c_only_arg = subparser.add_argument_group(
title="Clustering options", description=None
)
c_only_arg.add_argument(
"--latent",
dest="latent_file",
required=True,
metavar="",
type=lambda p: np.load(Path(p)),
help="Path to latent.npz file",
)


def add_clustering_arguments(subparser: argparse.ArgumentParser):
# Clustering arguments
clusto = subparser.add_argument_group(title="Clustering options", description=None)
Expand Down Expand Up @@ -2383,8 +2479,9 @@ def main():
""",
add_help=False,
)
add_help_arguments(vaevae_parserbin_parser)
subparsers_model = vaevae_parserbin_parser.add_subparsers(dest="model_subcommand")
subparsers_model = vaevae_parserbin_parser.add_subparsers(
dest="model_subcommand", required=True
)

vae_parser = subparsers_model.add_parser(
VAMB,
Expand All @@ -2393,7 +2490,7 @@ def main():
default binner based on a variational autoencoder.
See the paper 'Improved metagenome binning and assembly using deep variational autoencoders'""",
add_help=False,
usage="%(prog)s [options]",
# usage="%(prog)s [options]",
description="""Bin using a VAE that merges composition and abundance information.

Required arguments: Outdir, at least one composition input and at least one abundance input""",
Expand All @@ -2413,7 +2510,7 @@ def main():
taxonomy informed binner based on a bi-modal variational autoencoder.
See the paper 'TaxVAMB: taxonomic annotations improve metagenome binning'""",
add_help=False,
usage="%(prog)s [options]",
# usage="%(prog)s [options]",
description="""Bin using a semi-supervised VAEVAE model that merges composition, abundance and taxonomic information.

Required arguments: Outdir, taxonomy, at least one composition input and at least one abundance input""",
Expand All @@ -2432,7 +2529,7 @@ def main():
AVAMB,
help=argparse.SUPPRESS,
add_help=False,
usage="%(prog)s [options]",
# usage="%(prog)s [options]",
)
general_group = add_general_arguments(vaeaae_parser)
add_minlength(general_group)
Expand Down Expand Up @@ -2511,6 +2608,36 @@ def main():
add_composition_npz_argument(abundance_parser)
add_abundance_args_nonpz(abundance_parser)

train_parser = partial_part.add_parser(
"train", help="Do training without clustering", add_help=False
)

train_parser.set_defaults(model_subcommand=VAMB)

general_group = add_general_arguments(train_parser)
add_minlength(general_group)
add_composition_arguments(train_parser)
add_abundance_arguments(train_parser)
add_taxonomy_arguments(train_parser)
add_bin_output_arguments(train_parser)
add_vae_arguments(train_parser)
add_clustering_arguments(train_parser)

cluster_parser = partial_part.add_parser(
"cluster", help="Cluster after training", add_help=False
)
cluster_parser.set_defaults(model_subcommand=VAMB)

general_group = add_general_arguments(cluster_parser)
add_minlength(general_group)
add_composition_arguments(cluster_parser)
add_abundance_arguments(cluster_parser)
add_taxonomy_arguments(cluster_parser)
add_bin_output_arguments(cluster_parser)
add_vae_arguments(cluster_parser)
add_clustering_arguments(cluster_parser)
add_cluster_only_args(cluster_parser)

args = parser.parse_args()

if args.subcommand == TAXOMETER:
Expand All @@ -2524,7 +2651,7 @@ def main():
sys.exit(1)
if model == VAMB:
opt = BinDefaultOptions.from_args(args)
runner = partial(run_bin_default, opt)
runner = partial(load_train_bin, opt)
run(runner, opt.common.general)
elif model == TAXVAMB:
opt = BinTaxVambOptions.from_args(args)
Expand All @@ -2549,6 +2676,20 @@ def main():
opt = PartialAbundanceOptions.from_args(args)
runner = partial(run_partial_abundance, opt)
run(runner, opt.general)
elif args.partial_part == "train":
opt = BinDefaultOptions.from_args(args)
runner = partial(load_train_bin, opt, partial_mode="train")
run(runner, opt.common.general)
elif args.partial_part == "cluster":
opt = BinDefaultOptions.from_args(args)
runner = partial(
load_train_bin,
opt,
partial_mode="cluster",
latent=args.latent_file["arr_0"],
)
run(runner, opt.common.general)

else:
# TODO: Add abundance
# TODO: Add encoding w. VAE
Expand Down