Skip to content
Open
201 changes: 175 additions & 26 deletions vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,15 @@ class GeneralOptions:
@classmethod
def from_args(cls, args: argparse.Namespace):
return cls(
typeasserted(args.outdir, Path),
typeasserted(Path(args.outdir), Path),
typeasserted(args.nthreads, int),
typeasserted(args.seed, int),
typeasserted(args.cuda, bool),
)

def training_args_assertions(cls, args: argparse.Namespace):
None

def __init__(
self,
out_dir: Path,
Expand Down Expand Up @@ -807,7 +810,64 @@ def __init__(
self.comp = comp
self.abundance = abundance
self.clustering = clustering
self.output = output


class TrainingCommonOptions:
def __init__(
self, general: GeneralOptions, comp: CompositionPath, abundance: AbundancePath
):
self.general = general
self.comp = comp
self.abundance = abundance


class PartialTrainingOptions:
def __init__(
self,
general: GeneralOptions,
common: TrainingCommonOptions,
comp: CompositionPath,
abundance: AbundancePath,
min_contig_length: MinContigLength,
vae: VAEOptions,
outdir: Path,
):
self.general = general
self.common = common
self.comp = comp
self.abundance = abundance
self.min_contig_length = min_contig_length
self.vae = vae
self.outdir = outdir

@classmethod
def from_args(cls, args: argparse.Namespace):
general = GeneralOptions.from_args(args)
comp = CompositionPath(Path(args.composition_file))
abundance_path = Path(args.abundance_file)
min_contig_length = MinContigLength.from_args(args)
basic = BasicTrainingOptions.from_args_vae(args)
vae = VAEOptions.from_args(basic, args)
outdir = Path(args.outdir)

abundance = AbundanceOptions(
bampaths=None,
bamdir=None,
abundance_tsv=None,
abundancepath=abundance_path,
min_alignment_id=0.0,
min_contig_length=min_contig_length,
refcheck=False,
)
common = TrainingCommonOptions(general, comp, abundance)
return cls(general, common, comp, abundance, min_contig_length, vae, outdir)

def validate_comp_is_npz(self) -> Path:
if not isinstance(self.comp, CompositionPath):
raise TypeError(
"Training-only mode requires a CompositionPath (precomputed .npz)"
)
return self.comp


class BinDefaultOptions:
Expand Down Expand Up @@ -950,21 +1010,23 @@ def __init__(


def calc_tnf(
options: CompositionOptions | PartialCompositionOptions,
options: CompositionOptions | PartialCompositionOptions | PartialTrainingOptions,
outdir: Path,
binsplitter: Optional[vamb.vambtools.BinSplitter],
binsplitter: Optional[vamb.vambtools.BinSplitter] = None,
train_only: bool = False,
) -> vamb.parsecontigs.Composition:
begintime = time.time()
logger.info("Loading TNF")
logger.info(f"\tMinimum sequence length: {options.min_contig_length.n}")

path = options.path

if not train_only:
logger.info(f"\tMinimum sequence length: {options.min_contig_length.n}")
path = options.path
else:
path = options.validate_comp_is_npz()
if isinstance(path, CompositionPath):
logger.info(f'\tLoading composition from npz at: "{path.path}"')
composition = vamb.parsecontigs.Composition.load(path.path)
composition.filter_min_length(options.min_contig_length.n)
else:
elif not train_only:
assert isinstance(path, FASTAPath)
logger.info(f"\tLoading data from FASTA file {path.path}")
with vamb.vambtools.Reader(path.path) as file:
Expand All @@ -973,11 +1035,16 @@ def calc_tnf(
)
assert outdir is not None
composition.save(outdir.joinpath("composition.npz"))
else:
raise TypeError(
"In training-only mode, path must be a CompositionPath with a valid .npz file"
)

# Initialize binsplitter on the identifiers. Only done if we actually need to binsplit
# later.
if binsplitter is not None:
binsplitter.initialize(composition.metadata.identifiers)
if not train_only:
# Initialize binsplitter on the identifiers. Only done if we actually need to binsplit
# later.
if binsplitter is not None:
binsplitter.initialize(composition.metadata.identifiers)

if composition.nseqs < MINIMUM_SEQS:
err = (
Expand Down Expand Up @@ -1088,6 +1155,7 @@ def load_composition_and_abundance(
vamb_options.out_dir.path,
binsplitter,
)

abundance = calc_abundance(
abundance_options,
vamb_options.out_dir.path,
Expand All @@ -1097,6 +1165,20 @@ def load_composition_and_abundance(
return (composition, abundance)


def load_composition_and_abundance_train_only(
opt: PartialTrainingOptions,
) -> Tuple[vamb.parsecontigs.Composition, vamb.parsebam.Abundance]:
composition = calc_tnf(opt, opt.general.out_dir.path, train_only=True)

abundance = calc_abundance(
opt.abundance,
opt.general.out_dir.path,
composition.metadata,
opt.general.n_threads,
)
return (composition, abundance)


def load_markers(
options: MarkerOptions,
comp_metadata: vamb.parsecontigs.CompositionMetaData,
Expand Down Expand Up @@ -1405,13 +1487,9 @@ def run_partial_abundance(opt: PartialAbundanceOptions):
)


def run_bin_default(opt: BinDefaultOptions):
composition, abundance = load_composition_and_abundance(
vamb_options=opt.common.general,
comp_options=opt.common.comp,
abundance_options=opt.common.abundance,
binsplitter=opt.common.output.binsplitter,
)
def run_train_vae(
opt: BinDefaultOptions, composition: CompositionPath, abundance: AbundancePath
):
data_loader = vamb.encode.make_dataloader(
abundance.matrix,
composition.matrix,
Expand All @@ -1430,6 +1508,16 @@ def run_bin_default(opt: BinDefaultOptions):
del composition, abundance
assert comp_metadata.nseqs == len(latent)

logger.info(f"Saved latent.npz to {opt.outdir}")

return latent


def run_cluster_and_write_files(
latent, opt: BinDefaultOptions, composition: CompositionPath
):
comp_metadata = composition.metadata
assert comp_metadata.nseqs == len(latent)
cluster_and_write_files(
opt.common.clustering,
opt.common.output.binsplitter,
Expand All @@ -1442,6 +1530,17 @@ def run_bin_default(opt: BinDefaultOptions):
FastaOutput.try_from_common(opt.common),
None,
)


def run_bin_default(opt: BinDefaultOptions):
composition, abundance = load_composition_and_abundance(
vamb_options=opt.common.general,
comp_options=opt.common.comp,
abundance_options=opt.common.abundance,
binsplitter=opt.common.output.binsplitter,
)
latent = run_train_vae(opt, composition, abundance)
run_cluster_and_write_files(latent, opt, composition)
del latent


Expand Down Expand Up @@ -2017,6 +2116,30 @@ def add_abundance_arguments(subparser: argparse.ArgumentParser):
return subparser


def add_training_arguments(subparser: argparse.ArgumentParser):
trainingos = subparser.add_argument_group(title="Training options")
add_minlength(trainingos)
trainingos.add_argument("--print_test", type=str, help="Print test output")
trainingos.add_argument(
"-p",
dest="nthreads",
metavar="",
type=int,
default=DEFAULT_THREADS,
help="number of threads to use where customizable",
)
trainingos.add_argument(
"--seed",
metavar="",
type=int,
default=int.from_bytes(os.urandom(7), "little"),
help="Random seed (determinism not guaranteed)",
)
trainingos.add_argument(
"--cuda", help="Use GPU to train & cluster [False]", action="store_true"
)


def add_taxonomy_arguments(subparser: argparse.ArgumentParser, taxonomy_only=False):
taxonomys = subparser.add_argument_group(title="Taxonomy input")
taxonomys.add_argument(
Expand Down Expand Up @@ -2125,7 +2248,12 @@ def add_vae_arguments(subparser: argparse.ArgumentParser):
trainos = subparser.add_argument_group(title="Training options", description=None)

trainos.add_argument(
"-e", dest="nepochs", metavar="", type=int, default=300, help=argparse.SUPPRESS
"-e",
dest="nepochs",
metavar="",
type=int,
default=70,
help=argparse.SUPPRESS,
)
trainos.add_argument(
"-t",
Expand Down Expand Up @@ -2383,8 +2511,9 @@ def main():
""",
add_help=False,
)
add_help_arguments(vaevae_parserbin_parser)
subparsers_model = vaevae_parserbin_parser.add_subparsers(dest="model_subcommand")
subparsers_model = vaevae_parserbin_parser.add_subparsers(
dest="model_subcommand", required=True
)

vae_parser = subparsers_model.add_parser(
VAMB,
Expand All @@ -2393,7 +2522,7 @@ def main():
default binner based on a variational autoencoder.
See the paper 'Improved metagenome binning and assembly using deep variational autoencoders'""",
add_help=False,
usage="%(prog)s [options]",
# usage="%(prog)s [options]",
description="""Bin using a VAE that merges composition and abundance information.

Required arguments: Outdir, at least one composition input and at least one abundance input""",
Expand All @@ -2413,7 +2542,7 @@ def main():
taxonomy informed binner based on a bi-modal variational autoencoder.
See the paper 'TaxVAMB: taxonomic annotations improve metagenome binning'""",
add_help=False,
usage="%(prog)s [options]",
# usage="%(prog)s [options]",
description="""Bin using a semi-supervised VAEVAE model that merges composition, abundance and taxonomic information.

Required arguments: Outdir, taxonomy, at least one composition input and at least one abundance input""",
Expand All @@ -2432,7 +2561,7 @@ def main():
AVAMB,
help=argparse.SUPPRESS,
add_help=False,
usage="%(prog)s [options]",
# usage="%(prog)s [options]",
)
general_group = add_general_arguments(vaeaae_parser)
add_minlength(general_group)
Expand Down Expand Up @@ -2511,6 +2640,14 @@ def main():
add_composition_npz_argument(abundance_parser)
add_abundance_args_nonpz(abundance_parser)

train_parser = partial_part.add_parser(
"train", help="Do training without clustering", add_help=False
)
general_group = add_training_arguments(train_parser)
add_vae_arguments(train_parser)
train_parser.add_argument("--abundance_file", type=str, help="Input filename")
train_parser.add_argument("--composition_file", type=str, help="Input filename")
train_parser.add_argument("--outdir", type=str, help="Output directory")
args = parser.parse_args()

if args.subcommand == TAXOMETER:
Expand Down Expand Up @@ -2549,6 +2686,16 @@ def main():
opt = PartialAbundanceOptions.from_args(args)
runner = partial(run_partial_abundance, opt)
run(runner, opt.general)
elif args.partial_part == "train":
starting_time = time.time()
opt = PartialTrainingOptions.from_args(args)
os.makedirs(args.outdir, exist_ok=False)
composition, abundance = load_composition_and_abundance_train_only(opt)
run_train_vae(opt, composition, abundance)
ending_time = time.time()
elapsed = ending_time - starting_time
logger.info(f"Completed training in {elapsed:.2f} seconds")

else:
# TODO: Add abundance
# TODO: Add encoding w. VAE
Expand All @@ -2562,3 +2709,5 @@ def main():

if __name__ == "__main__":
main()