Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More new plots #18

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 200 additions & 0 deletions genetools/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
import pandas as pd
import seaborn as sns

# from https://matplotlib.org/stable/gallery/color/named_colors.html
from matplotlib.patches import Rectangle
import matplotlib.colors as mcolors
import textwrap


def savefig(fig, *args, **kwargs):
"""
Expand Down Expand Up @@ -257,6 +262,201 @@ def horizontal_stacked_bar_plot(
return fig, ax


def make_axis_limits_consistent(axarr):
"""
make x and y axis limits match across array of axes. set limits to outermost extents encountered in any axes.
"""
left = min([ax.get_xlim()[0] for ax in axarr])
right = max([ax.get_xlim()[1] for ax in axarr])

bottom = min([ax.get_ylim()[0] for ax in axarr])
top = max([ax.get_ylim()[1] for ax in axarr])

for ax in axarr:
ax.set_xlim(left, right)
ax.set_ylim(bottom, top)

return axarr, (left, right), (bottom, top)


# def _verify_or_create_palette(palette, data, hue_key):
# """Verify that a particular discrete color palette has the right number of colors (subsetting if necessary).
# Or if no color palette was provided by the user, return a default color palette.

# :param palette: color palette for discrete hues if the user has supplied one, otherwise None
# :type palette: matplotlib palette name, list of colors, or dict mapping hue values to colors, or None
# :param data: dataframe containing observations and associated hues
# :type data: pandas.DataFrame
# :param hue_key: name of column in dataframe that lists hues
# :type hue_key: str
# :raises ValueError: if user-supplied palette has fewer colors than the number of hues in the data
# :return: a color palette for plotting
# :rtype: list of colors
# """
# n_colors = data[hue_key].nunique()

# if not palette:
# # create colors
# palette = sns.color_palette("Spectral", n_colors=n_colors)

# # confirm number of colors
# if len(palette) < n_colors:
# raise ValueError("Not enough colors in palette")

# # subset to exact number of colors we need (otherwise seaborn throws error)
# # TODO: make this work with matplotlib palette names or dicts mapping hue values to colors
# return palette[:n_colors]


def plot_colortable(colors, title, sort_colors=True, emptycols=0):

cell_width = 212
cell_height = 22
swatch_width = 48
margin = 12
topmargin = 40

# Sort colors by hue, saturation, value and name.
if sort_colors is True:
by_hsv = sorted(
(tuple(mcolors.rgb_to_hsv(mcolors.to_rgb(color))), name)
for name, color in colors.items()
)
names = [name for hsv, name in by_hsv]
else:
names = list(colors)

n = len(names)
ncols = 4 - emptycols
nrows = n // ncols + int(n % ncols > 0)

width = cell_width * 4 + 2 * margin
height = cell_height * nrows + margin + topmargin
dpi = 72

fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
fig.subplots_adjust(
margin / width,
margin / height,
(width - margin) / width,
(height - topmargin) / height,
)
ax.set_xlim(0, cell_width * 4)
ax.set_ylim(cell_height * (nrows - 0.5), -cell_height / 2.0)
ax.yaxis.set_visible(False)
ax.xaxis.set_visible(False)
ax.set_axis_off()
ax.set_title(title, fontsize=24, loc="left", pad=10)

for i, name in enumerate(names):
row = i % nrows
col = i // nrows
y = row * cell_height

swatch_start_x = cell_width * col
text_pos_x = cell_width * col + swatch_width + 7

ax.text(
text_pos_x,
y,
name,
fontsize=14,
horizontalalignment="left",
verticalalignment="center",
)

ax.add_patch(
Rectangle(
xy=(swatch_start_x, y - 9),
width=swatch_width,
height=18,
facecolor=colors[name],
edgecolor="0.7",
)
)

return fig


####


def make_confusion_matrix(
y_true, y_pred, true_label, pred_label, label_order=None
) -> pd.DataFrame:
# rows ground truth label - columns predicted label
cm = pd.crosstab(y_true, y_pred, rownames=[true_label], colnames=[pred_label])

# reorder so columns and index match
if label_order is None:
label_order = cm.index.union(cm.columns)

resulting_row_order, resulting_col_order = [
pd.Index(label_order).intersection(source_list).tolist()
for source_list in [cm.index, cm.columns]
]

cm = cm.loc[resulting_row_order][resulting_col_order]

return cm


def plot_confusion_matrix(
df, figsize=(6, 4), outside_borders=True, inside_border_width=0.5
):
# TODO: write test case with: df = pd.crosstab(pd.Series(['a', 'a', 'b', 'b', 'c']), pd.Series([1, 2, 3, 4, 1]))
with sns.axes_style("white"):
fig, ax = plt.subplots(figsize=figsize)
# add text with numeric values (annot=True), but without scientific notation (overriding fmt with "g" or "d")
sns.heatmap(
df, annot=True, fmt="g", cmap="Blues", ax=ax, linewidth=inside_border_width
)
plt.setp(ax.get_yticklabels(), rotation="horizontal", va="center")
plt.setp(ax.get_xticklabels(), rotation="horizontal")

if outside_borders:
# Activate outside borders
for _, spine in ax.spines.items():
spine.set_visible(True)

return fig, ax


def wrap_tick_labels(ax, wrap_x_axis=True, wrap_y_axis=True, wrap_amount=20):
"""Add text wrapping to tick labels on x and/or y axes."""
# TODO: Breaks when run on numerical, non-categorical axes. Fix.

def wrap_labels(labels):
for label in labels:
label.set_text("\n".join(textwrap.wrap(label.get_text(), wrap_amount)))
return labels

if wrap_x_axis:
# Wrap x-axis text labels
ax.set_xticklabels(wrap_labels(ax.get_xticklabels()))

if wrap_y_axis:
# Wrap y-axis text labels
ax.set_yticklabels(wrap_labels(ax.get_yticklabels()))

return ax


def add_sample_size_to_labels(labels, data, hue_key):
"""
Return new categorical labels with sample size for these groups added. We look up sample size of each label in data[hue_key] column.

Example usage:
ax.set_xticklabels(add_sample_size_to_labels(ax.get_xticklabels(), df, "Patient Type"))
"""

def _make_label(hue_value):
sample_size = data[data[hue_key] == hue_value].shape[0]
return f"{hue_value}\n($n={sample_size}$)"

return [_make_label(label.get_text()) for label in labels]


####


Expand Down