diff --git a/scikitplot/metrics.py b/scikitplot/metrics.py index 08ec693..492078d 100644 --- a/scikitplot/metrics.py +++ b/scikitplot/metrics.py @@ -34,67 +34,52 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, pred_labels=None, title=None, normalize=False, hide_zeros=False, hide_counts=False, x_tick_rotation=0, ax=None, figsize=None, cmap='Blues', title_fontsize="large", - text_fontsize="medium"): + text_fontsize="medium", colorbar=True): """Generates confusion matrix plot from predictions and true labels - Args: y_true (array-like, shape (n_samples)): Ground truth (correct) target values. - y_pred (array-like, shape (n_samples)): Estimated targets as returned by a classifier. - labels (array-like, shape (n_classes), optional): List of labels to index the matrix. This may be used to reorder or select a subset of labels. If none is given, those that appear at least once in ``y_true`` or ``y_pred`` are used in sorted order. (new in v0.2.5) - true_labels (array-like, optional): The true labels to display. If none is given, then all of the labels are used. - pred_labels (array-like, optional): The predicted labels to display. If none is given, then all of the labels are used. - title (string, optional): Title of the generated plot. Defaults to "Confusion Matrix" if `normalize` is True. Else, defaults to "Normalized Confusion Matrix. - normalize (bool, optional): If True, normalizes the confusion matrix before plotting. Defaults to False. - hide_zeros (bool, optional): If True, does not plot cells containing a value of zero. Defaults to False. - hide_counts (bool, optional): If True, doe not overlay counts. Defaults to False. - x_tick_rotation (int, optional): Rotates x-axis tick labels by the specified angle. This is useful in cases where there are numerous categories and the labels overlap each other. - ax (:class:`matplotlib.axes.Axes`, optional): The axes upon which to plot the curve. If None, the plot is drawn on a new set of axes. - figsize (2-tuple, optional): Tuple denoting figure size of the plot e.g. (6, 6). Defaults to ``None``. - cmap (string or :class:`matplotlib.colors.Colormap` instance, optional): Colormap used for plotting the projection. View Matplotlib Colormap documentation for available options. https://matplotlib.org/users/colormaps.html - title_fontsize (string or int, optional): Matplotlib-style fontsizes. Use e.g. "small", "medium", "large" or integer-values. Defaults to "large". - text_fontsize (string or int, optional): Matplotlib-style fontsizes. Use e.g. "small", "medium", "large" or integer-values. Defaults to "medium". - + colorbar (bool, optional): If False, does not add colour bar. + Defaults to True. Returns: ax (:class:`matplotlib.axes.Axes`): The axes on which the plot was drawn. - Example: >>> import scikitplot as skplt >>> rf = RandomForestClassifier() @@ -103,7 +88,6 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, >>> skplt.metrics.plot_confusion_matrix(y_test, y_pred, normalize=True) >>> plt.show() - .. image:: _static/examples/plot_confusion_matrix.png :align: center :alt: Confusion matrix @@ -153,7 +137,10 @@ def plot_confusion_matrix(y_true, y_pred, labels=None, true_labels=None, ax.set_title('Confusion Matrix', fontsize=title_fontsize) image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.get_cmap(cmap)) - plt.colorbar(mappable=image) + + if colorbar == True: + plt.colorbar(mappable=image) + x_tick_marks = np.arange(len(pred_classes)) y_tick_marks = np.arange(len(true_classes)) ax.set_xticks(x_tick_marks)