diff --git a/pyani/pyani_graphics/__init__.py b/pyani/pyani_graphics/__init__.py index 04c5c0f1..77c501b3 100644 --- a/pyani/pyani_graphics/__init__.py +++ b/pyani/pyani_graphics/__init__.py @@ -54,6 +54,8 @@ from . import mpl # noqa: F401 # matplotlib wrappers from . import sns # noqa: F401 # seaborn wrappers +from . import tree + # Specify matplotlib backend. This *must* be done before pyplot import, but # raises errors with flake8 etc. So we comment out the specific error matplotlib.use("Agg") diff --git a/pyani/pyani_graphics/mpl/__init__.py b/pyani/pyani_graphics/mpl/__init__.py index a2ae3137..6613a7f2 100644 --- a/pyani/pyani_graphics/mpl/__init__.py +++ b/pyani/pyani_graphics/mpl/__init__.py @@ -292,7 +292,7 @@ def add_colorscale(fig, heatmap_gs, ax_map, params, title=None): # Generate Matplotlib heatmap output -def heatmap(dfr, outfilename=None, title=None, params=None): +def heatmap(dfr, outfilename=None, title=None, params=None, format=None, args=None): """Return matplotlib heatmap with cluster dendrograms. 
:param dfr: pandas DataFrame with relevant data @@ -357,7 +357,13 @@ def heatmap(dfr, outfilename=None, title=None, params=None): heatmap_gs.tight_layout(fig, h_pad=0.1, w_pad=0.5) if outfilename: fig.savefig(outfilename) - return fig + + # Tree + newicks = None + if args.tree: + pass + + return fig, newicks def scatter( diff --git a/pyani/pyani_graphics/sns/__init__.py b/pyani/pyani_graphics/sns/__init__.py index fdcddea5..5990f95d 100644 --- a/pyani/pyani_graphics/sns/__init__.py +++ b/pyani/pyani_graphics/sns/__init__.py @@ -42,10 +42,19 @@ import matplotlib # pylint: disable=C0411 import pandas as pd import seaborn as sns +import logging +from pathlib import Path +from scipy.cluster import hierarchy +from ete3 import ClusterTree +from ete3 import Tree, TreeStyle, faces, AttrFace, PhyloTree matplotlib.use("Agg") import matplotlib.pyplot as plt # noqa: E402,E501 # pylint: disable=wrong-import-position,wrong-import-order,ungrouped-imports +LABEL_DICT = {} + +logger = logging.getLogger(__name__) + # Add classes colorbar to Seaborn plot def get_colorbar(dfr, classes): @@ -98,6 +107,28 @@ def add_labels(fig, params): return fig +def build_label_dict(fig, axis, params): + """Label info for tree plots. + + :param fig: a Seaborn clustermap instance + :param axis: one of {'row', 'col'} + :param params: plot parameters; this is where the labels come from + + """ + if axis == "col": + for idx, _ in zip( + fig.dendrogram_col.reordered_ind, fig.ax_heatmap.get_yticklabels() + ): + LABEL_DICT[str(idx + 1)] = params.labels.get(_, _.get_text()) + elif axis == "row": + for idx, _ in zip( + fig.dendrogram_row.reordered_ind, fig.ax_heatmap.get_xticklabels() + ): + LABEL_DICT[str(idx + 1)] = params.labels.get(_, _.get_text()) + logger.debug(f"{LABEL_DICT}") + return LABEL_DICT + + # Return a clustermap def get_clustermap(dfr, params, title=None, annot=True): """Return a Seaborn clustermap for the passed dataframe. 
@@ -151,7 +182,7 @@ def get_clustermap(dfr, params, title=None, annot=True): # Generate Seaborn heatmap output -def heatmap(dfr, outfilename=None, title=None, params=None): +def heatmap(dfr, outfilename=None, title=None, params=None, format=None, args=None): """Return seaborn heatmap with cluster dendrograms. :param dfr: pandas DataFrame with relevant data @@ -185,8 +216,13 @@ def heatmap(dfr, outfilename=None, title=None, params=None): if outfilename: fig.savefig(outfilename) + # Tree + newicks = None + if args.tree: + newicks = tree(dfr, fig, title, format, params, args) + # Return clustermap - return fig + return fig, newicks def distribution(dfr, outfilename, matname, title=None): @@ -284,3 +320,114 @@ def scatter( # Return clustermap return fig + + +def get_newick(node, parentdist, leaf_names, newick=""): + """Generates a newick formatted string from a tree, + using recursion to traverse it. + + :param node: a (portion of a) tree to be traversed + :param parentdist: distance from the parent node + :param leaf_names: labels that will be attached to the terminal nodes + :param newick: the current newick-formatted tree structure + + """ + # logger = logging.getLogger(__name__) + # logger.debug(f"{type(parentdist)}, {parentdist}") + # logger.debug(f"{type(node.dist)}, {node.dist}") + diff = parentdist - node.dist + if node.is_leaf(): + return f"{leaf_names[node.id]}:{diff:.2f}{newick}" + else: + if len(newick) > 0: + newick = f"):{diff:.2f}{newick}" + else: + newick = ");" + newick = get_newick(node.get_left(), node.dist, leaf_names, newick) + newick = get_newick(node.get_right(), node.dist, leaf_names, f",{newick}") + newick = f"({newick}" + return newick + + +def tree(dfr, fig, title, format, params, args): + """Generate a newick file and dendrogram plot for the given dataframe. 
+ + :param dfr: a dataframe + :param fig: a figure produced by sns.clustermap + :param title: name of the matrix plot + :param format: image file format being used + :param params: matrix plot parameters; including labels + :param args: Namespace + + """ + logger = logging.getLogger(__name__) + + # Get matrix name and run_id from the plot title + matname, run_id = title.split("_", 1)[-1].rsplit("_", 1) + + # Dictionary to allow abstraction over axes + sides = { + "col": { + "axis": fig.dendrogram_col, + "names": dfr.columns, # fig.dendrogram_col.reordered_ind, + }, + "row": { + "axis": fig.dendrogram_row, + "names": dfr.index, # fig.dendrogram_row.reordered_ind, + }, + } + + # Create a linkage dendrogram and newick string for both rows and columns + newicks = {} + + for axis in sides.keys(): + # Generate newick format + tree = hierarchy.to_tree(sides[axis]["axis"].linkage, False) + logger.debug(f"Names: {sides[axis]['names']}") + newick = get_newick(tree, tree.dist, sides[axis]["names"], "") + newicks.update({f"[{axis}_newick_{matname}_{run_id}]": newick}) + + # Generate dendrogram + # if 'dendrogram' in args.tree: + # if args.tree: + build_label_dict(fig, axis, params) + # figtree = ClusterTree(newick, text_array=matrix) + figtree = PhyloTree(newick) + figtree.set_species_naming_function(get_species_name) + figtree_file = Path(args.outdir) / f"{axis}_tree_{matname}_{run_id}.{format}" + logger.debug(f"{figtree}") + figtree.render(str(figtree_file), layout=tree_layout) + + # Return the newick strings so we can save them in the database (eventually) + return newicks + + +def tree_layout(node): + + # Add taxonomy to nodes, and align to right + if node.is_leaf(): + # if node.name == "F962_00589": + # faces.add_face_to_node( + # AttrFace("name", fgcolor="white"), + # node, + # column=0, + # position="branch-right", + # ) + # faces.add_face_to_node( + # AttrFace("species", fgcolor="white"), node, column=0, position="aligned" + # ) + # node.img_style["bgcolor"] == 
"darkred" + # else: + + faces.add_face_to_node( + AttrFace("name", fgcolor="black"), + node, + column=0, + position="branch-right", + ) + faces.add_face_to_node(AttrFace("species"), node, column=0, position="aligned") + + +def get_species_name(node_name_string): + """Return `Genus species` (where known) for a node.""" + return LABEL_DICT[node_name_string] diff --git a/pyani/pyani_graphics/tree/__init__.py b/pyani/pyani_graphics/tree/__init__.py new file mode 100644 index 00000000..b25f9854 --- /dev/null +++ b/pyani/pyani_graphics/tree/__init__.py @@ -0,0 +1,166 @@ +import logging +from pyani import pyani_graphics +from scipy.cluster import hierarchy +from ete3 import ClusterTree, Tree, TreeStyle, faces, AttrFace, PhyloTree +from pathlib import Path +import sys +import seaborn as sns + +LABEL_DICT = {} + + +def build_label_dict(fig, axis, params): + """Label info for tree plots. + + :param fig: a Seaborn clustermap instance + :param axis: one of {'row', 'col'} + :param params: plot parameters; this is where the labels come from + + """ + logger = logging.getLogger(__name__) + if axis == "col": + for idx, _ in zip( + fig.dendrogram_col.reordered_ind, fig.ax_heatmap.get_yticklabels() + ): + LABEL_DICT[str(idx + 1)] = params.labels.get(_, _.get_text()) + elif axis == "row": + for idx, _ in zip( + fig.dendrogram_row.reordered_ind, fig.ax_heatmap.get_xticklabels() + ): + LABEL_DICT[str(idx + 1)] = params.labels.get(_, _.get_text()) + logger.debug(f"Label dict: {LABEL_DICT}") + return LABEL_DICT + + +def get_newick(node, parentdist, leaf_names, newick=""): + """Generates a newick formatted file from a tree, + using recursion to traverse it. 
+ + :param node: a (portion of a) tree to be traversed + :param parentdist: distance from the parent node + :param leaf_names: labels that will be attached to the terminal nodes + :param newick: the current newick-formatted tree structure + + """ + # logger = logging.getLogger(__name__) + # logger.debug(f"{type(parentdist)}, {parentdist}") + # logger.debug(f"{type(node.dist)}, {node.dist}") + diff = parentdist - node.dist + if node.is_leaf(): + return f"{leaf_names[node.id]}:{diff:.2f}{newick}" + else: + if len(newick) > 0: + newick = f"):{diff:.2f}{newick}" + else: + newick = ");" + newick = get_newick(node.get_left(), node.dist, leaf_names, newick) + newick = get_newick(node.get_right(), node.dist, leaf_names, f",{newick}") + newick = f"({newick}" + return newick + + +def tree(dfr, outfname, title, params, format, args): + """Generate a newick file and dendrogram plot for the given dataframe. + + :param dfr: a dataframe + :param outfname: output file name passed by the caller (not used in this function) + :param title: name of the matrix plot + :param format: image file format being used + :param params: matrix plot parameters; including labels + :param args: Namespace + + """ + logger = logging.getLogger(__name__) + + # Get matrix name and run_id from the plot title + matname, run_id = title.split("_", 1)[-1].rsplit("_", 1) + + maxfigsize = 120 + calcfigsize = dfr.shape[0] * 1.1 + figsize = min(max(8, calcfigsize), maxfigsize) + if figsize == maxfigsize: + scale = maxfigsize / calcfigsize + sns.set_context("notebook", font_scale=scale) + + # Add a colorbar? 
+ if params.classes is None: + col_cb = None + else: + col_cb = pyani_graphics.sns.get_colorbar(dfr, params.classes) + + params.colorbar = col_cb + params.figsize = figsize + params.linewidths = 0.25 + + fig = pyani_graphics.sns.get_clustermap(dfr, params) + + # Dictionary to allow abstraction over axes + sides = { + "columns": { + "axis": fig.dendrogram_col, + "names": dfr.columns, # fig.dendrogram_col.reordered_ind, + }, + "rows": { + "axis": fig.dendrogram_row, + "names": dfr.index, # fig.dendrogram_row.reordered_ind, + }, + } + + # Create a linkage dendrogram and newick string for both rows and columns + newicks = {} + + for axis in args.axes: + # Generate newick format + tree = hierarchy.to_tree(sides[axis]["axis"].linkage, False) + logger.debug(f"Names: {sides[axis]['names']}") + + newick = get_newick(tree, tree.dist, sides[axis]["names"], "") + newicks.update({f"[{axis}_newick_{matname}_{run_id}]": newick}) + + # Generate dendrogram + # if 'dendrogram' in args.tree: + # if args.tree: + build_label_dict(fig, axis, params) + sys.stderr.write(f"Label dict: {LABEL_DICT}\n") + # figtree = ClusterTree(newick, text_array=matrix) + figtree = PhyloTree(newick) + figtree.set_species_naming_function(get_species_name) + figtree_file = Path(args.outdir) / f"{axis}_tree_{matname}_{run_id}.{format}" + logger.debug(f"{figtree}") + + # Write the tree to file + figtree.render(str(figtree_file), layout=tree_layout) + + # Return the newick strings so we can save them in the database (eventually) + return newicks + + +def tree_layout(node): + + # Add taxonomy to nodes, and align to right + if node.is_leaf(): + # if node.name == "F962_00589": + # faces.add_face_to_node( + # AttrFace("name", fgcolor="white"), + # node, + # column=0, + # position="branch-right", + # ) + # faces.add_face_to_node( + # AttrFace("species", fgcolor="white"), node, column=0, position="aligned" + # ) + # node.img_style["bgcolor"] == "darkred" + # else: + + faces.add_face_to_node( + AttrFace("name", 
fgcolor="black"), + node, + column=0, + position="branch-right", + ) + faces.add_face_to_node(AttrFace("species"), node, column=0, position="aligned") + + +def get_species_name(node_name_string): + """Return `Genus species` (where known) for a node.""" + return LABEL_DICT[node_name_string] diff --git a/pyani/scripts/average_nucleotide_identity.py b/pyani/scripts/average_nucleotide_identity.py index 0fc0dcbe..9f028e95 100755 --- a/pyani/scripts/average_nucleotide_identity.py +++ b/pyani/scripts/average_nucleotide_identity.py @@ -819,11 +819,11 @@ def draw(args: Namespace, filestems: List[str], gformat: str) -> None: ) if args.gmethod == "mpl": pyani_graphics.mpl.heatmap( - dfm, outfilename=outfilename, title=filestem, params=params + dfm, outfilename=outfilename, title=filestem, params=params, args=args ) elif args.gmethod == "seaborn": pyani_graphics.sns.heatmap( - dfm, outfilename=outfilename, title=filestem, params=params + dfm, outfilename=outfilename, title=filestem, params=params, args=args ) diff --git a/pyani/scripts/parsers/__init__.py b/pyani/scripts/parsers/__init__.py index e1fc790c..e6eb8a3d 100644 --- a/pyani/scripts/parsers/__init__.py +++ b/pyani/scripts/parsers/__init__.py @@ -58,6 +58,7 @@ common_parser, run_common_parser, listdeps_parser, + tree_parser, ) @@ -133,6 +134,7 @@ def parse_cmdline(argv: Optional[List] = None) -> Namespace: ) report_parser.build(subparsers, parents=[parser_common]) plot_parser.build(subparsers, parents=[parser_common]) + tree_parser.build(subparsers, parents=[parser_common]) classify_parser.build(subparsers, parents=[parser_common]) listdeps_parser.build(subparsers, parents=[parser_common]) diff --git a/pyani/scripts/parsers/plot_parser.py b/pyani/scripts/parsers/plot_parser.py index b35f8e01..b31136e8 100644 --- a/pyani/scripts/parsers/plot_parser.py +++ b/pyani/scripts/parsers/plot_parser.py @@ -120,4 +120,12 @@ def build( help="Number of worker processes for multiprocessing " "(default zero, meaning use all available 
cores)", ) + parser.add_argument( + "--tree", + dest="tree", + action="store_true", + default=False, + help="tree formats to generate", + # choices=["newick", "dendrogram"] + ) parser.set_defaults(func=subcommands.subcmd_plot) diff --git a/pyani/scripts/parsers/tree_parser.py b/pyani/scripts/parsers/tree_parser.py new file mode 100644 index 00000000..99242f91 --- /dev/null +++ b/pyani/scripts/parsers/tree_parser.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +# (c) The James Hutton Institute 2016-2019 +# (c) University of Strathclyde 2019-2020 +# Author: Leighton Pritchard +# +# Contact: +# leighton.pritchard@strath.ac.uk +# +# Leighton Pritchard, +# Strathclyde Institute for Pharmacy and Biomedical Sciences, +# Cathedral Street, +# Glasgow, +# G4 0RE +# Scotland, +# UK +# +# The MIT License +# +# Copyright (c) 2016-2019 The James Hutton Institute +# Copyright (c) 2019-2020 University of Strathclyde +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. 
+"""Provides parser for tree subcommand.""" + +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, _SubParsersAction +from pathlib import Path +from typing import List, Optional + +from pyani.scripts import subcommands + + +def get_tree_list(tree_string: str): + possible_trees = { + "i": "identity", + "c": "coverage", + "a": "aln_lengths", + "s": "sim_errors", + "h": "hadamard", + } + return [possible_trees[_] for _ in tree_string] + + +def get_axes_list(axes_string: str): + axes = {"c": "columns", "r": "rows"} + return [axes[_] for _ in axes_string] + + +def build( + subps: _SubParsersAction, parents: Optional[List[ArgumentParser]] = None +) -> None: + """Add a command-line parser for the tree subcommand. + + :param subps: collection of subparsers in main parser + :param parents: parsers from which arguments are inherited + + The tree subcommand takes specific arguments: + + --method (graphics method to use) + """ + parser = subps.add_parser( + "tree", parents=parents, formatter_class=ArgumentDefaultsHelpFormatter + ) + # Required arguments: output directory and run ID + parser.add_argument( + "-o", + "--outdir", + action="store", + dest="outdir", + default=None, + type=Path, + help="output directory", + required=True, + ) + parser.add_argument( + "--run_ids", + action="store", + dest="run_ids", + default=None, + metavar="RUN_ID", + nargs="+", + help="run IDs to plot", + required=True, + ) + # Other optional arguments + parser.add_argument( + "--dbpath", + action="store", + dest="dbpath", + default=Path(".pyani/pyanidb"), + type=Path, + help="path to pyani database", + ) + # Graphics methods and formats + parser.add_argument( + "--formats", + dest="formats", + action="store", + default=["png"], + metavar="FORMAT", + nargs="+", + choices=["pdf", "png", "svg", "jpg"], + help="graphics output format; options: (pdf, png, svg, jpg)", + ) + parser.add_argument( + "--method", + dest="method", + action="store", + default="ete3", # "seaborn", 
metavar="METHOD", + nargs=1, + choices=["ete3"], # ["seaborn", "mpl", "plotly"], + help="graphics method to use for plotting; options (ete3)", # "(seaborn, mpl, plotly)", + ) + parser.add_argument( + "--workers", + dest="workers", + action="store", + default=None, + type=int, + help="Number of worker processes for multiprocessing " + "(default zero, meaning use all available cores)", + ) + parser.add_argument( + "--trees", + dest="trees", + # action="store_true", + # default=False, + type=get_tree_list, + metavar="TREES", + help="A string (such as: icash, cah, shi) specifying which trees to generate, made up of their initials: {'i': 'identity', 'c': 'coverage', 'a': 'aln_length', 's': 'sim_errors', 'h': 'hadamard'}", + ) + parser.add_argument( + "--axes", + dest="axes", + default="cr", + type=get_axes_list, + metavar="AXES", + help="A string indicating which axes to plot. One of (c, r, cr); c = columns, r = rows, cr = both", + ) + parser.set_defaults(func=subcommands.subcmd_tree) diff --git a/pyani/scripts/subcommands/__init__.py b/pyani/scripts/subcommands/__init__.py index 7e5e92ef..f301a8a2 100644 --- a/pyani/scripts/subcommands/__init__.py +++ b/pyani/scripts/subcommands/__init__.py @@ -48,4 +48,5 @@ from .subcmd_listdeps import subcmd_listdeps from .subcmd_plot import subcmd_plot from .subcmd_report import subcmd_report +from .subcmd_tree import subcmd_tree from .subcmd_fastani import subcmd_fastani diff --git a/pyani/scripts/subcommands/subcmd_plot.py b/pyani/scripts/subcommands/subcmd_plot.py index 4e3f1ab0..93bd84b1 100644 --- a/pyani/scripts/subcommands/subcmd_plot.py +++ b/pyani/scripts/subcommands/subcmd_plot.py @@ -56,12 +56,15 @@ # Distribution dictionary of matrix graphics methods GMETHODS = {"mpl": pyani_graphics.mpl.heatmap, "seaborn": pyani_graphics.sns.heatmap} SMETHODS = {"mpl": pyani_graphics.mpl.scatter, "seaborn": pyani_graphics.sns.scatter} +# TMETHODS = {"seaborn": pyani_graphics.seaborn.} # Distribution dictionary of distribution graphics 
methods DISTMETHODS = { "mpl": pyani_graphics.mpl.distribution, "seaborn": pyani_graphics.sns.distribution, } +NEWICKS = {} + def subcmd_plot(args: Namespace) -> int: """Produce graphical output for an analysis. @@ -94,6 +97,10 @@ def subcmd_plot(args: Namespace) -> int: for run_id in run_ids: write_run_plots(run_id, session, outfmts, args) + if NEWICKS: + write_newicks(args, run_id) + NEWICKS.clear() + return 0 @@ -163,8 +170,8 @@ def write_run_plots(run_id: int, session, outfmts: List[str], args: Namespace) - # Run the plotting commands logger.debug("Running plotting commands") for func, options in plotting_commands: - logger.debug("Running %s with options %s", func, options) - pool.apply_async(func, args=options) + result = pool.apply_async(func, options, {}, callback=logger.debug) + result.get() # Close worker pool pool.close() @@ -187,7 +194,7 @@ def write_distribution( for fmt in outfmts: outfname = Path(args.outdir) / f"distribution_{matdata.name}_run{run_id}.{fmt}" logger.debug("\tWriting graphics to %s", outfname) - DISTMETHODS[args.method[0]]( + DISTMETHODS[args.method]( matdata.data, outfname, matdata.name, @@ -220,19 +227,28 @@ def write_heatmap( logger.info("Writing %s matrix heatmaps", matdata.name) cmap = pyani_config.get_colormap(matdata.data, matdata.name) for fmt in outfmts: - outfname = Path(args.outdir) / f"matrix_{matdata.name}_run{run_id}.{fmt}" + outfname = ( + Path(args.outdir) / f"matrix_{matdata.name}_run{run_id}_{args.method}.{fmt}" + ) logger.debug("\tWriting graphics to %s", outfname) params = pyani_graphics.Params(cmap, result_labels, result_classes) # Draw heatmap - GMETHODS[args.method[0]]( + _, newicks = GMETHODS[args.method]( matdata.data, outfname, title=f"matrix_{matdata.name}_run{run_id}", params=params, + format=fmt, + args=args, ) + # If Newick strings were generated, add them to NEWICKS. 
+ if newicks: + NEWICKS.update(newicks) + # Be tidy with matplotlib caches plt.close("all") + return def write_scatter( @@ -266,7 +282,7 @@ def write_scatter( logger.debug("\tWriting graphics to %s", outfname) params = pyani_graphics.Params(cmap, result_labels, result_classes) # Draw scatterplot - SMETHODS[args.method[0]]( + SMETHODS[args.method]( matdata1.data, matdata2.data, outfname, @@ -278,3 +294,11 @@ def write_scatter( # Be tidy with matplotlib caches plt.close("all") + + +def write_newicks(args: Namespace, run_id): + # If Newick strings were generated, write them out. + newick_file = Path(args.outdir) / f"newicks_run{run_id}.nw" + with open(newick_file, "w") as nfh: + for name, nw in NEWICKS.items(): + nfh.write(f"{name}\t{nw}\n") diff --git a/pyani/scripts/subcommands/subcmd_tree.py b/pyani/scripts/subcommands/subcmd_tree.py new file mode 100644 index 00000000..8b86bdb6 --- /dev/null +++ b/pyani/scripts/subcommands/subcmd_tree.py @@ -0,0 +1,159 @@ +import logging +import os +import sys +import multiprocessing + +from argparse import Namespace +from pathlib import Path +from typing import Dict, List +import pandas as pd + +from pyani import pyani_config, pyani_orm, pyani_graphics +from pyani.pyani_tools import termcolor, MatrixData + +# TREEMETHODS = {} +TREEMETHODS = {"ete3": pyani_graphics.tree.tree} + +NEWICKS = {} + + +def subcmd_tree(args: Namespace) -> int: + """Produce tree output for an analysis. + + :param args: Namespace of command-line arguments + + This is graphical output for representing the ANI analysis results, and + takes the form of a tree, or dendrogram. 
+ """ + logger = logging.getLogger(__name__) + + # Announce what's going on to the user + logger.info(termcolor("Generating tree output for analyses", "red")) + logger.info("Writing output to: %s", args.outdir) + os.makedirs(args.outdir, exist_ok=True) + logger.info("Rendering method: %s", args.method) + + # Connect to database session + logger.debug("Activating session for database: %s", args.dbpath) + session = pyani_orm.get_session(args.dbpath) + + # Parse output formats + outfmts = args.formats + logger.debug("Requested output formats: %s", outfmts) + logger.debug("Type of formats variable: %s", type(outfmts)) + + # Work on each run: + run_ids = args.run_ids + logger.debug("Generating trees for runs: %s", run_ids) + for run_id in run_ids: + write_run_trees(run_id, session, outfmts, args) + + if NEWICKS: + write_newicks(args, run_id) + NEWICKS.clear() + + return 0 + + +def write_run_trees( + run_id: int, + session, + outfmts: List[str], + args: Namespace, +) -> None: + """Write tree plots for each matrix type. 
+ + :param run_id: int, run_id for this run + :param session: database session used to query the results for this run + :param args: Namespace for command-line arguments + :param outfmts: list of output formats for files + """ + logger = logging.getLogger(__name__) + logger.debug("Retrieving results matrices for run %s", run_id) + + results = ( + session.query(pyani_orm.Run).filter(pyani_orm.Run.run_id == run_id).first() + ) + result_label_dict = pyani_orm.get_matrix_labels_for_run(session, run_id) + result_class_dict = pyani_orm.get_matrix_classes_for_run(session, run_id) + logger.debug( + f"Have {len(result_label_dict)} labels and {len(result_class_dict)} classes" + ) + + # Create worker pool and empty command list + pool = multiprocessing.Pool(processes=args.workers) + plotting_commands = [] + + # Build and collect the plotting commands + for matdata in [ + MatrixData(*_) + for _ in [ + ("identity", pd.read_json(results.df_identity), {}), + ("coverage", pd.read_json(results.df_coverage), {}), + ("aln_lengths", pd.read_json(results.df_alnlength), {}), + ("sim_errors", pd.read_json(results.df_simerrors), {}), + ("hadamard", pd.read_json(results.df_hadamard), {}), + ] + if _[0] in args.trees + ]: + logger.info("Writing tree plot for %s matrix", matdata.name) + plotting_commands.append( + ( + write_tree, + [run_id, matdata, result_label_dict, result_class_dict, outfmts, args], + ) + ) + + sys.stdout.write(str(plotting_commands)) + + # Run the plotting commands + for func, options in plotting_commands: + result = pool.apply_async(func, options, {}, callback=logger.debug) + result.get() + + # Close worker pool + pool.close() + pool.join() + + +def write_tree( + run_id: int, + matdata: MatrixData, + result_labels: Dict, + result_classes: Dict, + outfmts: List[str], + args: Namespace, +) -> None: + """Write a single tree for a pyani run. 
+ + :param run_id: int, run_id for this run + :param matdata: MatrixData object for this heatmap + :param result_labels: dict of result labels + :param result_classes: dict of result classes + :param args: Namespace for command-line arguments + :param outfmts: list of output formats for files + """ + # logger = logging.getLogger(__name__) + cmap = pyani_config.get_colormap(matdata.data, matdata.name) + + for fmt in outfmts: + outfname = Path(args.outdir) / f"distribution_{matdata.name}_run{run_id}.{fmt}" + + params = pyani_graphics.Params(cmap, result_labels, result_classes) + + TREEMETHODS[args.method]( + matdata.data, + outfname, + title=f"matrix_{matdata.name}_run{run_id}", + params=params, + format=fmt, + args=args, + ) + + +def write_newicks(args: Namespace, run_id): + # If Newick strings were generated, write them out. + newick_file = Path(args.outdir) / f"newicks_run{run_id}.nw" + with open(newick_file, "w") as nfh: + for name, nw in NEWICKS.items(): + nfh.write(f"{name}\t{nw}\n") diff --git a/requirements-pip.txt b/requirements-pip.txt index 5467acdc..93098c89 100644 --- a/requirements-pip.txt +++ b/requirements-pip.txt @@ -1,2 +1,3 @@ pytest-ordering sphinx-rtd-theme +PyQt5 diff --git a/tests/conftest.py b/tests/conftest.py index 4b88f323..14769b09 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,6 +46,7 @@ from pathlib import Path from typing import NamedTuple +from argparse import Namespace import pytest @@ -476,6 +477,12 @@ def path_fna_all(dir_seq): return [_ for _ in dir_seq.iterdir() if _.is_file() and _.suffix == ".fna"] +@pytest.fixture +def plot_namespace_no_tree(): + """A namespace for plotting tests with no trees.""" + return Namespace(tree=False) + + @pytest.fixture(autouse=True) def skip_by_unavailable_executable( request, blastall_available, blastn_available, nucmer_available diff --git a/tests/test_graphics.py b/tests/test_graphics.py index f229a443..3cb3f557 100644 --- a/tests/test_graphics.py +++ b/tests/test_graphics.py @@ 
-51,6 +51,7 @@ from pathlib import Path from typing import Dict, NamedTuple +from argparse import Namespace import pytest @@ -65,15 +66,17 @@ class GraphicsTestInputs(NamedTuple): filename: Path labels: Dict[str, str] classes: Dict[str, str] + args: Namespace @pytest.fixture -def graphics_inputs(dir_graphics_in): +def graphics_inputs(dir_graphics_in, plot_namespace_no_tree): """Returns namedtuple of graphics inputs.""" return GraphicsTestInputs( dir_graphics_in / "ANIm_percentage_identity.tab", get_labels(dir_graphics_in / "labels.tab"), get_labels(dir_graphics_in / "classes.tab"), + plot_namespace_no_tree, ) @@ -89,7 +92,11 @@ def draw_format_method(fmt, mth, graphics_inputs, tmp_path): graphics_inputs.classes, ) fn[mth]( - df, tmp_path / f"{mth}.{fmt}", title=f"{mth}:{fmt} test", params=method_params + df, + tmp_path / f"{mth}.{fmt}", + title=f"{mth}:{fmt} test", + params=method_params, + args=graphics_inputs.args, ) sc[mth]( df, diff --git a/tests/test_legacy_scripts.py b/tests/test_legacy_scripts.py index 72966f60..32ab5a7a 100644 --- a/tests/test_legacy_scripts.py +++ b/tests/test_legacy_scripts.py @@ -102,6 +102,7 @@ def legacy_ani_namespace(path_fixtures_base, tmp_path): subsample=None, seed=None, jobprefix="ANI", + tree=False, ) diff --git a/tests/test_subcmd_06_plot.py b/tests/test_subcmd_06_plot.py index d7652d26..c3311283 100644 --- a/tests/test_subcmd_06_plot.py +++ b/tests/test_subcmd_06_plot.py @@ -79,65 +79,73 @@ def setUp(self): outdir=self.outdir / "mpl", run_ids=self.run_id, dbpath=self.dbpath, - formats="pdf", + formats=["pdf"], method="mpl", workers=None, + tree=False, ), "mpl_png": Namespace( outdir=self.outdir / "mpl", run_ids=self.run_id, dbpath=self.dbpath, - formats="png", + formats=["png"], method="mpl", workers=None, + tree=False, ), "mpl_svg": Namespace( outdir=self.outdir / "mpl", run_ids=self.run_id, dbpath=self.dbpath, - formats="svg", + formats=["svg"], method="mpl", workers=None, + tree=False, ), "mpl_jpg": Namespace( 
outdir=self.outdir / "mpl", run_ids=self.run_id, dbpath=self.dbpath, - formats="jpg", + formats=["jpg"], method="mpl", workers=None, + tree=False, ), "seaborn_pdf": Namespace( outdir=self.outdir / "seaborn", run_ids=self.run_id, dbpath=self.dbpath, - formats="pdf", + formats=["pdf"], method="seaborn", workers=None, + tree=False, ), "seaborn_png": Namespace( outdir=self.outdir / "seaborn", run_ids=self.run_id, dbpath=self.dbpath, - formats="png", + formats=["png"], method="seaborn", workers=None, + tree=False, ), "seaborn_svg": Namespace( outdir=self.outdir / "seaborn", run_ids=self.run_id, dbpath=self.dbpath, - formats="svg", + formats=["svg"], method="seaborn", workers=None, + tree=False, ), "seaborn_jpg": Namespace( outdir=self.outdir / "seaborn", run_ids=self.run_id, dbpath=self.dbpath, - formats="jpg", + formats=["jpg"], method="seaborn", workers=None, + tree=False, ), }