Skip to content
Open
210 changes: 190 additions & 20 deletions vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,13 @@ class GeneralOptions:
@classmethod
def from_args(cls, args: argparse.Namespace):
return cls(
typeasserted(args.outdir, Path),
typeasserted(Path(args.outdir), Path),
typeasserted(args.nthreads, int),
typeasserted(args.seed, int),
typeasserted(args.cuda, bool),
)

def training_args_assertions(cls, args: argparse.Namespace):
None
def __init__(
self,
out_dir: Path,
Expand Down Expand Up @@ -793,7 +794,6 @@ def from_args(cls, args: argparse.Namespace):
ClusterOptions.from_args(args),
BinOutputOptions.from_args(comp, args),
)

# We do not have BasicTrainingOptions because that is model-specific
def __init__(
self,
Expand All @@ -809,6 +809,52 @@ def __init__(
self.clustering = clustering
self.output = output

class TrainingCommonOptions:
    """Container for the options shared by the training-only workflow.

    Groups the general run settings with the composition and abundance
    inputs consumed during VAE training.
    """

    def __init__(self, general: GeneralOptions, comp: CompositionPath, abundance: AbundancePath):
        # NOTE(review): PartialTrainingOptions.from_args passes an
        # AbundanceOptions instance as `abundance`, not an AbundancePath —
        # the annotation above may be too narrow; confirm intended type.
        (self.general, self.comp, self.abundance) = (general, comp, abundance)

class PartialTrainingOptions:
    """Options for the `partial train` subcommand: train a VAE without clustering.

    Wraps the general run options, the precomputed composition (.npz) and
    abundance (.npz) inputs, the minimum contig length filter, and the VAE
    hyperparameters.
    """

    def __init__(
        self,
        general: GeneralOptions,
        common: TrainingCommonOptions,
        comp: CompositionPath,
        # NOTE(review): `from_args` actually passes an AbundanceOptions
        # instance here, not an AbundancePath — confirm intended type.
        abundance: AbundancePath,
        min_contig_length: MinContigLength,
        vae: VAEOptions,
    ):
        self.general = general
        self.common = common
        self.comp = comp
        self.abundance = abundance
        self.min_contig_length = min_contig_length
        self.vae = vae

    @classmethod
    def from_args(cls, args: argparse.Namespace):
        """Build training options from parsed CLI arguments.

        Raises:
            ValueError: if --composition_file or --abundance_file was not
                supplied, giving a clear message instead of the opaque
                TypeError that Path(None) would raise.
        """
        if args.composition_file is None:
            raise ValueError("Missing required argument: --composition_file")
        if args.abundance_file is None:
            raise ValueError("Missing required argument: --abundance_file")

        general = GeneralOptions.from_args(args)
        comp = CompositionPath(Path(args.composition_file))
        abundance_path = Path(args.abundance_file)
        min_contig_length = MinContigLength.from_args(args)
        basic = BasicTrainingOptions.from_args_vae(args)
        vae = VAEOptions.from_args(basic, args)

        # Abundance comes from a precomputed .npz, so no BAM-related inputs
        # are needed; refcheck is disabled because the composition is
        # supplied separately as its own precomputed file.
        abundance = AbundanceOptions(
            bampaths=None,
            bamdir=None,
            abundance_tsv=None,
            abundancepath=abundance_path,
            min_alignment_id=0.0,
            min_contig_length=min_contig_length,
            refcheck=False,
        )

        common = TrainingCommonOptions(general, comp, abundance)
        return cls(general, common, comp, abundance, min_contig_length, vae)



class BinDefaultOptions:
@classmethod
Expand Down Expand Up @@ -948,6 +994,51 @@ def __init__(
self.output = output
self.algorithm = algorithm

def calc_tnf_train_only(
    options: PartialTrainingOptions,
    outdir: Path,
) -> vamb.parsecontigs.Composition:
    """Load a precomputed composition (.npz) for training-only mode.

    Args:
        options: Training options; ``options.comp`` must be a CompositionPath
            pointing at a precomputed composition .npz file.
        outdir: Output directory. Currently unused; kept for signature
            parity with ``calc_tnf`` — TODO confirm it can be dropped.

    Returns:
        The loaded composition, filtered to ``options.min_contig_length.n``.

    Raises:
        TypeError: if ``options.comp`` is not a CompositionPath.
        ValueError: if fewer than MINIMUM_SEQS contigs remain after loading.
    """
    begintime = time.time()
    logger.info("Loading TNF")

    path = options.comp
    # Guard clause: training-only mode never computes TNF from FASTA,
    # it can only load a precomputed composition.
    if not isinstance(path, CompositionPath):
        raise TypeError("Training-only mode requires a CompositionPath (precomputed .npz)")
    logger.info(f'\tLoading composition from npz at: "{path.path}"')
    composition = vamb.parsecontigs.Composition.load(path.path)
    composition.filter_min_length(options.min_contig_length.n)

    if composition.nseqs < MINIMUM_SEQS:
        err = (
            f"Found only {composition.nseqs} contigs, but Vamb currently requires at least "
            f"{MINIMUM_SEQS} to work correctly. "
            "If you have this few sequences in a metagenomic assembly, "
            "it's probably an error somewhere in your workflow."
        )
        logger.error(err)
        raise ValueError(err)

    # Warn the user if any contigs were observed that are shorter than
    # the minimum-length threshold and therefore filtered out.
    if not np.all(composition.metadata.mask):
        n_removed = len(composition.metadata.mask) - np.sum(composition.metadata.mask)
        message = (
            f"The minimum sequence length has been set to {options.min_contig_length.n}, "
            f"but {n_removed} sequences fell below this threshold and were filtered away."
            "\nBetter results are obtained if the sequence file is filtered to the minimum "
            "sequence length before mapping.\n"
        )
        logger.opt(raw=True).info("\n")
        logger.warning(message)

    elapsed = round(time.time() - begintime, 2)
    logger.info(
        f"\tKept {composition.count_bases()} bases in {composition.nseqs} sequences"
    )
    logger.info(f"\tProcessed TNF in {elapsed} seconds.\n")

    return composition

def calc_tnf(
options: CompositionOptions | PartialCompositionOptions,
Expand Down Expand Up @@ -1076,7 +1167,6 @@ def calc_abundance(

return abundance


def load_composition_and_abundance(
vamb_options: GeneralOptions,
comp_options: CompositionOptions,
Expand All @@ -1088,6 +1178,7 @@ def load_composition_and_abundance(
vamb_options.out_dir.path,
binsplitter,
)

abundance = calc_abundance(
abundance_options,
vamb_options.out_dir.path,
Expand All @@ -1096,6 +1187,21 @@ def load_composition_and_abundance(
)
return (composition, abundance)

def load_composition_and_abundance_train_only(
    opt: PartialTrainingOptions,
) -> Tuple[vamb.parsecontigs.Composition, vamb.parsebam.Abundance]:
    """Load the precomputed composition and abundance inputs for training-only mode.

    Returns a ``(composition, abundance)`` pair, with the abundance filtered
    against the composition's metadata.
    """
    out_path = opt.general.out_dir.path
    composition = calc_tnf_train_only(opt, out_path)
    abundance = calc_abundance(
        opt.abundance,
        out_path,
        composition.metadata,
        opt.general.n_threads,
    )
    return (composition, abundance)

def load_markers(
options: MarkerOptions,
Expand Down Expand Up @@ -1403,15 +1509,7 @@ def run_partial_abundance(opt: PartialAbundanceOptions):
calc_abundance(
opt, opt.general.out_dir.path, composition.metadata, opt.general.n_threads
)


def run_bin_default(opt: BinDefaultOptions):
composition, abundance = load_composition_and_abundance(
vamb_options=opt.common.general,
comp_options=opt.common.comp,
abundance_options=opt.common.abundance,
binsplitter=opt.common.output.binsplitter,
)
def run_train_vae(opt: BinDefaultOptions, composition: CompositionPath, abundance: AbundancePath):
data_loader = vamb.encode.make_dataloader(
abundance.matrix,
composition.matrix,
Expand All @@ -1430,6 +1528,12 @@ def run_bin_default(opt: BinDefaultOptions):
del composition, abundance
assert comp_metadata.nseqs == len(latent)

return latent

def run_cluster_and_write_files(latent, opt: BinDefaultOptions, composition: CompositionPath):

comp_metadata = composition.metadata
assert comp_metadata.nseqs == len(latent)
cluster_and_write_files(
opt.common.clustering,
opt.common.output.binsplitter,
Expand All @@ -1442,6 +1546,15 @@ def run_bin_default(opt: BinDefaultOptions):
FastaOutput.try_from_common(opt.common),
None,
)
def run_bin_default(opt: BinDefaultOptions):
    """Run the default binning workflow: load inputs, encode with the VAE,
    then cluster the latent space and write the output files."""
    common = opt.common
    composition, abundance = load_composition_and_abundance(
        vamb_options=common.general,
        comp_options=common.comp,
        abundance_options=common.abundance,
        binsplitter=common.output.binsplitter,
    )
    encoding = run_train_vae(opt, composition, abundance)
    run_cluster_and_write_files(encoding, opt, composition)
    # Drop the (potentially large) latent matrix once outputs are written.
    del encoding


Expand Down Expand Up @@ -2016,6 +2129,32 @@ def add_abundance_arguments(subparser: argparse.ArgumentParser):
)
return subparser

def add_training_arguments(subparser: argparse.ArgumentParser):
    """Register the training-only CLI options on `subparser`.

    Adds a "Training options" argument group (minimum length, threads,
    seed, CUDA flag) and returns the subparser, mirroring the other
    `add_*_arguments` helpers — the previous version returned None even
    though its caller assigns the return value.
    """
    trainingos = subparser.add_argument_group(title="Training options")
    add_minlength(trainingos)
    # NOTE(review): `--print_test` looks like a debugging leftover —
    # confirm whether it should be removed before release.
    trainingos.add_argument(
        "--print_test",
        type=str,
        help="Print test output",
    )
    trainingos.add_argument(
        "-p",
        dest="nthreads",
        metavar="",
        type=int,
        default=DEFAULT_THREADS,
        help="number of threads to use where customizable",
    )
    trainingos.add_argument(
        "--seed",
        metavar="",
        type=int,
        # 7 random bytes keeps the seed within torch/numpy's accepted range.
        default=int.from_bytes(os.urandom(7), "little"),
        help="Random seed (determinism not guaranteed)",
    )
    trainingos.add_argument(
        "--cuda", help="Use GPU to train & cluster [False]", action="store_true"
    )
    return subparser

def add_taxonomy_arguments(subparser: argparse.ArgumentParser, taxonomy_only=False):
taxonomys = subparser.add_argument_group(title="Taxonomy input")
Expand Down Expand Up @@ -2125,7 +2264,11 @@ def add_vae_arguments(subparser: argparse.ArgumentParser):
trainos = subparser.add_argument_group(title="Training options", description=None)

trainos.add_argument(
"-e", dest="nepochs", metavar="", type=int, default=300, help=argparse.SUPPRESS
"-e", dest="nepochs",
metavar="",
type=int,
default=70,
help=argparse.SUPPRESS,
)
trainos.add_argument(
"-t",
Expand Down Expand Up @@ -2383,8 +2526,7 @@ def main():
""",
add_help=False,
)
add_help_arguments(vaevae_parserbin_parser)
subparsers_model = vaevae_parserbin_parser.add_subparsers(dest="model_subcommand")
subparsers_model = vaevae_parserbin_parser.add_subparsers(dest="model_subcommand", required=True)

vae_parser = subparsers_model.add_parser(
VAMB,
Expand All @@ -2393,7 +2535,7 @@ def main():
default binner based on a variational autoencoder.
See the paper 'Improved metagenome binning and assembly using deep variational autoencoders'""",
add_help=False,
usage="%(prog)s [options]",
#usage="%(prog)s [options]",
description="""Bin using a VAE that merges composition and abundance information.

Required arguments: Outdir, at least one composition input and at least one abundance input""",
Expand All @@ -2413,7 +2555,7 @@ def main():
taxonomy informed binner based on a bi-modal variational autoencoder.
See the paper 'TaxVAMB: taxonomic annotations improve metagenome binning'""",
add_help=False,
usage="%(prog)s [options]",
#usage="%(prog)s [options]",
description="""Bin using a semi-supervised VAEVAE model that merges composition, abundance and taxonomic information.

Required arguments: Outdir, taxonomy, at least one composition input and at least one abundance input""",
Expand All @@ -2432,7 +2574,7 @@ def main():
AVAMB,
help=argparse.SUPPRESS,
add_help=False,
usage="%(prog)s [options]",
#usage="%(prog)s [options]",
)
general_group = add_general_arguments(vaeaae_parser)
add_minlength(general_group)
Expand Down Expand Up @@ -2511,9 +2653,26 @@ def main():
add_composition_npz_argument(abundance_parser)
add_abundance_args_nonpz(abundance_parser)

train_parser = partial_part.add_parser(
"train", help="Do training without clustering", add_help=False
)
general_group = add_training_arguments(train_parser)
train_parser.add_argument('--abundance_file', type=str, help='Input filename')
train_parser.add_argument('--composition_file', type=str, help='Input filename')
train_parser.add_argument('--outdir', type=str, help='Output directory')
train_parser.add_argument('--nepochs', type=int, default=70, help='Number of training epochs (default: 70)')
train_parser.add_argument('--batchsize', type=int, default=64, help='Batchsize')
train_parser.add_argument('--batchsteps', type=int, nargs='+', default=[], help='Epochs at which to update batch statistics (default: none)')
train_parser.add_argument('--nhiddens', type=int, nargs='*', default=None, help='List of hidden layer sizes for the VAE (e.g., --nhiddens 512 256)')
parser.add_argument('--nlatent', type=int, default=32, help='Size of the VAE latent space (default: 32)')
parser.add_argument('--alpha', type=float, default=None, help='Beta-VAE alpha parameter (optional, default: None)')
parser.add_argument('--beta', type=float, default=1.0, help='KL-divergence weight for Beta-VAE (default: 1.0)')
parser.add_argument('--dropout', type=float, default=None, help='Dropout rate for VAE layers (optional, default: None)')


args = parser.parse_args()

if args.subcommand == TAXOMETER:
if args.subcommand == TAXOMETER:
opt = TaxometerOptions.from_args(args)
runner = partial(run_taxonomy_predictor, opt)
run(runner, opt.general)
Expand Down Expand Up @@ -2549,6 +2708,17 @@ def main():
opt = PartialAbundanceOptions.from_args(args)
runner = partial(run_partial_abundance, opt)
run(runner, opt.general)
elif args.partial_part == "train":
starting_time = time.time()
opt = PartialTrainingOptions.from_args(args)
os.makedirs(args.outdir, exist_ok=False)
composition, abundance = load_composition_and_abundance_train_only(opt)
run_train_vae(opt, composition, abundance)
logger.info(f"Saved latent.npz to /{args.outdir}")
ending_time = time.time()
elapsed = ending_time - starting_time
logger.info(f"Completed training in {elapsed:.2f} seconds")

else:
# TODO: Add abundance
# TODO: Add encoding w. VAE
Expand Down