diff --git a/scikitplot/metrics.py b/scikitplot/metrics.py
index c82675d..017d3a7 100644
--- a/scikitplot/metrics.py
+++ b/scikitplot/metrics.py
@@ -22,6 +22,7 @@
 from sklearn.metrics import silhouette_score
 from sklearn.metrics import silhouette_samples
 from sklearn.calibration import calibration_curve
+from sklearn.utils import deprecated
 
 from scipy import interp
 
@@ -31,7 +32,7 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
                           pred_labels=None, title=None, normalize=False,
-                          hide_zeros=False, x_tick_rotation=0, ax=None,
+                          hide_zeros=False, hide_counts=False, x_tick_rotation=0, ax=None,
                           figsize=None, cmap='Blues', title_fontsize="large",
                           text_fontsize="medium"):
     """Generates confusion matrix plot from predictions and true labels
@@ -64,6 +65,9 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
         hide_zeros (bool, optional): If True, does not plot cells containing a
             value of zero. Defaults to False.
 
+        hide_counts (bool, optional): If True, does not overlay counts.
+            Defaults to False.
+
         x_tick_rotation (int, optional): Rotates x-axis tick labels by the
             specified angle. This is useful in cases where there are numerous
             categories and the labels overlap each other.
@@ -159,21 +163,25 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None,
     ax.set_yticklabels(true_classes, fontsize=text_fontsize)
 
     thresh = cm.max() / 2.
-    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
-        if not (hide_zeros and cm[i, j] == 0):
-            ax.text(j, i, cm[i, j],
-                    horizontalalignment="center",
-                    verticalalignment="center",
-                    fontsize=text_fontsize,
-                    color="white" if cm[i, j] > thresh else "black")
+
+    if not hide_counts:
+        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
+            if not (hide_zeros and cm[i, j] == 0):
+                ax.text(j, i, cm[i, j],
+                        horizontalalignment="center",
+                        verticalalignment="center",
+                        fontsize=text_fontsize,
+                        color="white" if cm[i, j] > thresh else "black")
 
     ax.set_ylabel('True label', fontsize=text_fontsize)
     ax.set_xlabel('Predicted label', fontsize=text_fontsize)
-    ax.grid('off')
+    ax.grid(False)
 
     return ax
 
 
+@deprecated('This will be removed in v0.5.0. Please use '
+            'scikitplot.metrics.plot_roc instead.')
 def plot_roc_curve(y_true, y_probas, title='ROC Curves',
                    curves=('micro', 'macro', 'each_class'),
                    ax=None, figsize=None, cmap='nipy_spectral',
@@ -321,9 +329,144 @@ def plot_roc_curve(y_true, y_probas, title='ROC Curves',
     return ax
 
 
+def plot_roc(y_true, y_probas, title='ROC Curves',
+             plot_micro=True, plot_macro=True, classes_to_plot=None,
+             ax=None, figsize=None, cmap='nipy_spectral',
+             title_fontsize="large", text_fontsize="medium", digits=2):
+    """Generates the ROC curves from labels and predicted scores/probabilities
+
+    Args:
+        y_true (array-like, shape (n_samples)):
+            Ground truth (correct) target values.
+
+        y_probas (array-like, shape (n_samples, n_classes)):
+            Prediction probabilities for each class returned by a classifier.
+
+        title (string, optional): Title of the generated plot. Defaults to
+            "ROC Curves".
+
+        plot_micro (boolean, optional): Plot the micro average ROC curve.
+            Defaults to ``True``.
+
+        plot_macro (boolean, optional): Plot the macro average ROC curve.
+            Defaults to ``True``.
+
+        classes_to_plot (list-like, optional): Classes for which the ROC
+            curve should be plotted. e.g. [0, 'cold']. If given class does not exist,
+            it will be ignored. If ``None``, all classes will be plotted.
+            Defaults to ``None``.
+
+        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
+            plot the curve. If None, the plot is drawn on a new set of axes.
+
+        figsize (2-tuple, optional): Tuple denoting figure size of the plot
+            e.g. (6, 6). Defaults to ``None``.
+
+        cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
+            Colormap used for plotting the projection. View Matplotlib Colormap
+            documentation for available options.
+            https://matplotlib.org/users/colormaps.html
+
+        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
+            Use e.g. "small", "medium", "large" or integer-values. Defaults to
+            "large".
+
+        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
+            Use e.g. "small", "medium", "large" or integer-values. Defaults to
+            "medium".
+
+        digits (int, optional): Number of digits for formatting output floating point values.
+            Use e.g. 2 or 4. Defaults to 2.
+
+    Returns:
+        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
+            drawn.
+
+    Example:
+        >>> import scikitplot as skplt
+        >>> nb = GaussianNB()
+        >>> nb = nb.fit(X_train, y_train)
+        >>> y_probas = nb.predict_proba(X_test)
+        >>> skplt.metrics.plot_roc(y_test, y_probas)
+
+        >>> plt.show()
+
+    .. image:: _static/examples/plot_roc_curve.png
+       :align: center
+       :alt: ROC Curves
+    """
+    y_true = np.array(y_true)
+    y_probas = np.array(y_probas)
+
+    classes = np.unique(y_true)
+    probas = y_probas
+
+    if classes_to_plot is None:
+        classes_to_plot = classes
+
+    if ax is None:
+        fig, ax = plt.subplots(1, 1, figsize=figsize)
+
+    ax.set_title(title, fontsize=title_fontsize)
+
+    fpr_dict = dict()
+    tpr_dict = dict()
+
+    indices_to_plot = np.in1d(classes, classes_to_plot)
+    for i, to_plot in enumerate(indices_to_plot):
+        fpr_dict[i], tpr_dict[i], _ = roc_curve(y_true, probas[:, i],
+                                                pos_label=classes[i])
+        if to_plot:
+            roc_auc = auc(fpr_dict[i], tpr_dict[i])
+            color = plt.cm.get_cmap(cmap)(float(i) / len(classes))
+            ax.plot(fpr_dict[i], tpr_dict[i], lw=2, color=color,
+                    label='ROC curve of class {0} (area = {1:.{digits}f})'
+                          ''.format(classes[i], roc_auc, digits=digits))
+
+    if plot_micro:
+        binarized_y_true = label_binarize(y_true, classes=classes)
+        if len(classes) == 2:
+            binarized_y_true = np.hstack(
+                (1 - binarized_y_true, binarized_y_true))
+        fpr, tpr, _ = roc_curve(binarized_y_true.ravel(), probas.ravel())
+        roc_auc = auc(fpr, tpr)
+        ax.plot(fpr, tpr,
+                label='micro-average ROC curve '
+                      '(area = {0:.{digits}f})'.format(roc_auc, digits=digits),
+                color='deeppink', linestyle=':', linewidth=4)
+
+    if plot_macro:
+        # Compute macro-average ROC curve and ROC area
+        # First aggregate all false positive rates
+        all_fpr = np.unique(np.concatenate([fpr_dict[x] for x in range(len(classes))]))
+
+        # Then interpolate all ROC curves at these points
+        mean_tpr = np.zeros_like(all_fpr)
+        for i in range(len(classes)):
+            mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i])
+
+        # Finally average it and compute AUC
+        mean_tpr /= len(classes)
+        roc_auc = auc(all_fpr, mean_tpr)
+
+        ax.plot(all_fpr, mean_tpr,
+                label='macro-average ROC curve '
+                      '(area = {0:.{digits}f})'.format(roc_auc, digits=digits),
+                color='navy', linestyle=':', linewidth=4)
+
+    ax.plot([0, 1], [0, 1], 'k--', lw=2)
+    ax.set_xlim([0.0, 1.0])
+    ax.set_ylim([0.0, 1.05])
+    ax.set_xlabel('False Positive Rate', fontsize=text_fontsize)
+    ax.set_ylabel('True Positive Rate', fontsize=text_fontsize)
+    ax.tick_params(labelsize=text_fontsize)
+    ax.legend(loc='lower right', fontsize=text_fontsize)
+    return ax
+
+
 def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
                       ax=None, figsize=None, title_fontsize="large",
-                      text_fontsize="medium"):
+                      text_fontsize="medium", digits=2):
     """Generates the KS Statistic plot from labels and scores/probabilities
 
     Args:
@@ -351,6 +494,9 @@ def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
             Use e.g. "small", "medium", "large" or integer-values. Defaults to
             "medium".
 
+        digits (int, optional): Number of digits for formatting output floating point values.
+            Use e.g. 2 or 4. Defaults to 2.
+
     Returns:
         ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
             drawn.
@@ -391,8 +537,8 @@ def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
     ax.plot(thresholds, pct2, lw=3, label='Class {}'.format(classes[1]))
     idx = np.where(thresholds == max_distance_at)[0][0]
     ax.axvline(max_distance_at, *sorted([pct1[idx], pct2[idx]]),
-               label='KS Statistic: {:.3f} at {:.3f}'.format(ks_statistic,
-                                                             max_distance_at),
+               label='KS Statistic: {:.{digits}f} at {:.{digits}f}'.format(ks_statistic,
+                                                             max_distance_at, digits=digits),
                linestyle=':', lw=3, color='black')
 
     ax.set_xlim([0.0, 1.0])
@@ -406,6 +552,8 @@ def plot_ks_statistic(y_true, y_probas, title='KS Statistic Plot',
     return ax
 
 
+@deprecated('This will be removed in v0.5.0. Please use '
+            'scikitplot.metrics.plot_precision_recall instead.')
 def plot_precision_recall_curve(y_true, y_probas,
                                 title='Precision-Recall Curve',
                                 curves=('micro', 'each_class'), ax=None,
@@ -531,10 +679,132 @@ def plot_precision_recall_curve(y_true, y_probas,
     return ax
 
 
+def plot_precision_recall(y_true, y_probas,
+                          title='Precision-Recall Curve',
+                          plot_micro=True,
+                          classes_to_plot=None, ax=None,
+                          figsize=None, cmap='nipy_spectral',
+                          title_fontsize="large",
+                          text_fontsize="medium",
+                          digits=2):
+    """Generates the Precision Recall Curve from labels and probabilities
+
+    Args:
+        y_true (array-like, shape (n_samples)):
+            Ground truth (correct) target values.
+
+        y_probas (array-like, shape (n_samples, n_classes)):
+            Prediction probabilities for each class returned by a classifier.
+
+        title (string, optional): Title of the generated plot. Defaults to
+            "Precision-Recall Curve".
+
+        plot_micro (boolean, optional): Plot the micro average Precision-Recall
+            curve. Defaults to ``True``.
+
+        classes_to_plot (list-like, optional): Classes for which the precision-recall
+            curve should be plotted. e.g. [0, 'cold']. If given class does not exist,
+            it will be ignored. If ``None``, all classes will be plotted. Defaults to
+            ``None``.
+
+        ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
+            plot the curve. If None, the plot is drawn on a new set of axes.
+
+        figsize (2-tuple, optional): Tuple denoting figure size of the plot
+            e.g. (6, 6). Defaults to ``None``.
+
+        cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
+            Colormap used for plotting the projection. View Matplotlib Colormap
+            documentation for available options.
+            https://matplotlib.org/users/colormaps.html
+
+        title_fontsize (string or int, optional): Matplotlib-style fontsizes.
+            Use e.g. "small", "medium", "large" or integer-values. Defaults to
+            "large".
+
+        text_fontsize (string or int, optional): Matplotlib-style fontsizes.
+            Use e.g. "small", "medium", "large" or integer-values. Defaults to
+            "medium".
+
+        digits (int, optional): Number of digits for formatting output floating point values.
+            Use e.g. 2 or 4. Defaults to 2.
+
+    Returns:
+        ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
+            drawn.
+
+    Example:
+        >>> import scikitplot as skplt
+        >>> nb = GaussianNB()
+        >>> nb.fit(X_train, y_train)
+        >>> y_probas = nb.predict_proba(X_test)
+        >>> skplt.metrics.plot_precision_recall(y_test, y_probas)
+
+        >>> plt.show()
+
+    .. image:: _static/examples/plot_precision_recall_curve.png
+       :align: center
+       :alt: Precision Recall Curve
+    """
+    y_true = np.array(y_true)
+    y_probas = np.array(y_probas)
+
+    classes = np.unique(y_true)
+    probas = y_probas
+
+    if classes_to_plot is None:
+        classes_to_plot = classes
+
+    binarized_y_true = label_binarize(y_true, classes=classes)
+    if len(classes) == 2:
+        binarized_y_true = np.hstack(
+            (1 - binarized_y_true, binarized_y_true))
+
+    if ax is None:
+        fig, ax = plt.subplots(1, 1, figsize=figsize)
+
+    ax.set_title(title, fontsize=title_fontsize)
+
+    indices_to_plot = np.in1d(classes, classes_to_plot)
+    for i, to_plot in enumerate(indices_to_plot):
+        if to_plot:
+            average_precision = average_precision_score(
+                binarized_y_true[:, i],
+                probas[:, i])
+            precision, recall, _ = precision_recall_curve(
+                y_true, probas[:, i], pos_label=classes[i])
+            color = plt.cm.get_cmap(cmap)(float(i) / len(classes))
+            ax.plot(recall, precision, lw=2,
+                    label='Precision-recall curve of class {0} '
+                          '(area = {1:.{digits}f})'.format(classes[i],
+                                                           average_precision,
+                                                           digits=digits),
+                    color=color)
+
+    if plot_micro:
+        precision, recall, _ = precision_recall_curve(
+            binarized_y_true.ravel(), probas.ravel())
+        average_precision = average_precision_score(binarized_y_true,
+                                                    probas,
+                                                    average='micro')
+        ax.plot(recall, precision,
+                label='micro-average Precision-recall curve '
+                      '(area = {0:.{digits}f})'.format(average_precision, digits=digits),
+                color='navy', linestyle=':', linewidth=4)
+
+    ax.set_xlim([0.0, 1.0])
+    ax.set_ylim([0.0, 1.05])
+    ax.set_xlabel('Recall')
+    ax.set_ylabel('Precision')
+    ax.tick_params(labelsize=text_fontsize)
+    ax.legend(loc='best', fontsize=text_fontsize)
+    return ax
+
+
 def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
                     metric='euclidean', copy=True, ax=None, figsize=None,
                     cmap='nipy_spectral', title_fontsize="large",
-                    text_fontsize="medium"):
+                    text_fontsize="medium", digits=2):
     """Plots silhouette analysis of clusters provided.
 
     Args:
@@ -576,6 +846,9 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
             Use e.g. "small", "medium", "large" or integer-values. Defaults to
             "medium".
 
+        digits (int, optional): Number of digits for formatting output floating point values.
+            Use e.g. 2 or 4. Defaults to 2.
+
     Returns:
         ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was
            drawn.
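For reviewers, a minimal usage sketch (not part of the patch) of the two new entry points introduced above, plot_roc and plot_precision_recall, as they would be called with this branch installed. The iris data and GaussianNB model are illustrative stand-ins borrowed from the docstring examples; classes_to_plot, plot_micro, plot_macro and digits are the keyword arguments this patch adds.

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB

import scikitplot as skplt

# Toy data and model, mirroring the docstring examples.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
y_probas = GaussianNB().fit(X_train, y_train).predict_proba(X_test)

# New-style ROC plot: per-class curves restricted via classes_to_plot,
# micro/macro averages toggled explicitly, AUC labels shown with 3 digits.
skplt.metrics.plot_roc(y_test, y_probas, classes_to_plot=[0, 2],
                       plot_micro=True, plot_macro=False, digits=3)

# Companion precision-recall plot with the same options.
skplt.metrics.plot_precision_recall(y_test, y_probas,
                                    classes_to_plot=[0, 2], digits=3)
plt.show()

Passing an explicit classes_to_plot keeps the legend readable when only a subset of classes is of interest, which is the motivation for threading the same option through both functions.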
@@ -638,7 +911,7 @@ def plot_silhouette(X, cluster_labels, title='Silhouette Analysis',
         y_lower = y_upper + 10
 
     ax.axvline(x=silhouette_avg, color="red", linestyle="--",
-               label='Silhouette score: {0:0.3f}'.format(silhouette_avg))
+               label='Silhouette score: {0:.{digits}f}'.format(silhouette_avg, digits=digits))
 
     ax.set_yticks([])  # Clear the y-axis labels / ticks
     ax.set_xticks(np.arange(-0.1, 1.0, 0.2))
@@ -784,7 +1057,8 @@ def plot_calibration_curve(y_true, probas_list, clf_names=None, n_bins=10,
 
 
 def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
-                         ax=None, figsize=None, title_fontsize="large",
+                         classes_to_plot=None, plot_micro=True, plot_macro=True,
+                         ax=None, figsize=None, title_fontsize="large", cmap='nipy_spectral',
                          text_fontsize="medium"):
     """Generates the Cumulative Gains Plot from labels and scores/probabilities
 
@@ -803,6 +1077,17 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
         title (string, optional): Title of the generated plot. Defaults to
             "Cumulative Gains Curve".
 
+        classes_to_plot (list-like, optional): Classes for which the Cumulative Gain
+            curve should be plotted. e.g. [0, 'cold']. If given class does not exist,
+            it will be ignored. If ``None``, all classes will be plotted. Defaults to
+            ``None``.
+
+        plot_micro (boolean, optional): Plot the micro average Cumulative Gain
+            curve. Defaults to ``True``.
+
+        plot_macro (boolean, optional): Plot the macro average Cumulative Gain
+            curve. Defaults to ``True``.
+
         ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to
             plot the learning curve. If None, the plot is drawn on a new set of
             axes.
@@ -814,6 +1099,11 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
             Use e.g. "small", "medium", "large" or integer-values. Defaults to
             "large".
 
+        cmap (string or :class:`matplotlib.colors.Colormap` instance, optional):
+            Colormap used for plotting the projection. View Matplotlib Colormap
+            documentation for available options.
+            https://matplotlib.org/users/colormaps.html
+
         text_fontsize (string or int, optional): Matplotlib-style fontsizes.
             Use e.g. "small", "medium", "large" or integer-values. Defaults to
             "medium".
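The digits keyword threaded through these signatures relies on Python's nested replacement fields in str.format, the same pattern used in the curve labels above. A quick standalone check of that formatting (plain Python, no scikit-plot required):

score = 0.8376
for digits in (2, 4):
    # The precision of the float field is itself supplied as a format argument.
    print('area = {0:.{digits}f}'.format(score, digits=digits))
# prints: area = 0.84, then area = 0.8376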
@@ -839,28 +1129,56 @@ def plot_cumulative_gain(y_true, y_probas, title='Cumulative Gains Curve',
     y_probas = np.array(y_probas)
 
     classes = np.unique(y_true)
-    if len(classes) != 2:
-        raise ValueError('Cannot calculate Cumulative Gains for data with '
-                         '{} category/ies'.format(len(classes)))
-
-    # Compute Cumulative Gain Curves
-    percentages, gains1 = cumulative_gain_curve(y_true, y_probas[:, 0],
-                                                classes[0])
-    percentages, gains2 = cumulative_gain_curve(y_true, y_probas[:, 1],
-                                                classes[1])
+    if classes_to_plot is None:
+        classes_to_plot = classes
 
     if ax is None:
         fig, ax = plt.subplots(1, 1, figsize=figsize)
 
     ax.set_title(title, fontsize=title_fontsize)
 
-    ax.plot(percentages, gains1, lw=3, label='Class {}'.format(classes[0]))
-    ax.plot(percentages, gains2, lw=3, label='Class {}'.format(classes[1]))
+    perc_dict = dict()
+    gain_dict = dict()
+
+    indices_to_plot = np.isin(classes, classes_to_plot)
+    # Loop over all classes to compute each per-class gain curve
+    for i, to_plot in enumerate(indices_to_plot):
+        perc_dict[i], gain_dict[i] = cumulative_gain_curve(y_true, y_probas[:, i], pos_label=classes[i])
+
+        if to_plot:
+            color = plt.cm.get_cmap(cmap)(float(i) / len(classes))
+            ax.plot(perc_dict[i], gain_dict[i], lw=2, color=color,
+                    label='Class {} Cumulative Gain curve'.format(classes[i]))
+
+    # Whether to plot the micro and/or macro average curves
+    if plot_micro:
+        binarized_y_true = label_binarize(y_true, classes=classes)
+        if len(classes) == 2:
+            binarized_y_true = np.hstack((1 - binarized_y_true, binarized_y_true))
+
+        perc, gain = cumulative_gain_curve(binarized_y_true.ravel(), y_probas.ravel())
+        ax.plot(perc, gain, label='micro-average Cumulative Gain curve',
+                color='deeppink', linestyle=':', linewidth=4)
+
+    if plot_macro:
+        # First aggregate all percentages
+        all_perc = np.unique(np.concatenate([perc_dict[x] for x in range(len(classes))]))
+
+        # Then interpolate all cumulative gain curves at these points
+        mean_gain = np.zeros_like(all_perc)
+        for i in range(len(classes)):
+            mean_gain += interp(all_perc, perc_dict[i], gain_dict[i])
+
+        mean_gain /= len(classes)
+
+        ax.plot(all_perc, mean_gain, label='macro-average Cumulative Gain curve',
+                color='navy', linestyle=':', linewidth=4)
 
     ax.set_xlim([0.0, 1.0])
     ax.set_ylim([0.0, 1.0])
 
-    ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Baseline')
+    ax.plot([0, 1], [0, 1], 'k--', lw=2)
 
     ax.set_xlabel('Percentage of sample', fontsize=text_fontsize)
     ax.set_ylabel('Gain', fontsize=text_fontsize)
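To illustrate the generalised plot_cumulative_gain above, which drops the former two-class restriction, here is a minimal sketch (not part of the patch) run against this branch. The iris data and LogisticRegression model are placeholders; the keyword arguments are the ones added by this hunk.

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

import scikitplot as skplt

# With this change, cumulative gains are no longer restricted to binary
# problems: each requested class gets its own curve, plus optional
# micro/macro averages interpolated over the pooled percentages.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
y_probas = clf.predict_proba(X_test)

skplt.metrics.plot_cumulative_gain(y_test, y_probas,
                                   classes_to_plot=[0, 1, 2],
                                   plot_micro=True, plot_macro=True,
                                   cmap='nipy_spectral')
plt.show()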