Skip to content

Commit c5faedb

Browse files
Feature/label dots reiinakano#111
1 parent 21cd5c6 commit c5faedb

File tree

2 files changed

+44
-11
lines changed

2 files changed

+44
-11
lines changed

scikitplot/decomposition.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
properties shared by scikit-learn estimators. The specific requirements are
66
documented per function.
77
"""
8-
from __future__ import absolute_import, division, print_function, \
9-
unicode_literals
8+
from __future__ import (
9+
absolute_import, division, print_function, unicode_literals
10+
)
1011

1112
import matplotlib.pyplot as plt
1213
import numpy as np
@@ -97,7 +98,8 @@ def plot_pca_component_variance(clf, title='PCA Component Explained Variances',
9798
def plot_pca_2d_projection(clf, X, y, title='PCA 2-D Projection',
9899
biplot=False, feature_labels=None,
99100
ax=None, figsize=None, cmap='Spectral',
100-
title_fontsize="large", text_fontsize="medium"):
101+
title_fontsize="large", text_fontsize="medium",
102+
dimensions=[0, 1], label_dots=False, ):
101103
"""Plots the 2-dimensional projection of PCA on a given dataset.
102104
103105
Args:
@@ -163,21 +165,27 @@ def plot_pca_2d_projection(clf, X, y, title='PCA 2-D Projection',
163165
fig, ax = plt.subplots(1, 1, figsize=figsize)
164166

165167
ax.set_title(title, fontsize=title_fontsize)
166-
classes = np.unique(np.array(y))
168+
# Get unique classes from y, preserving order of class occurence in y
169+
_, class_indexes = np.unique(np.array(y), return_index=True)
170+
classes = np.array(y)[np.sort(class_indexes)]
167171

168172
colors = plt.cm.get_cmap(cmap)(np.linspace(0, 1, len(classes)))
169173

170174
for label, color in zip(classes, colors):
171-
ax.scatter(transformed_X[y == label, 0], transformed_X[y == label, 1],
175+
ax.scatter(transformed_X[y == label, dimensions[0]], transformed_X[y == label, dimensions[1]],
172176
alpha=0.8, lw=2, label=label, color=color)
173177

178+
if label_dots:
179+
for dot in transformed_X[y == label][:, dimensions]:
180+
ax.text(*dot, label)
181+
174182
if biplot:
175-
xs = transformed_X[:, 0]
176-
ys = transformed_X[:, 1]
177-
vectors = np.transpose(clf.components_[:2, :])
183+
xs = transformed_X[:, dimensions[0]]
184+
ys = transformed_X[:, dimensions[1]]
185+
vectors = np.transpose(clf.components_[dimensions, :])
178186
vectors_scaled = vectors * [xs.max(), ys.max()]
179187
for i in range(vectors.shape[0]):
180-
ax.annotate("", xy=(vectors_scaled[i, 0], vectors_scaled[i, 1]),
188+
ax.annotate("", xy=(vectors_scaled[i, dimensions[0]], vectors_scaled[i, dimensions[1]]),
181189
xycoords='data', xytext=(0, 0), textcoords='data',
182190
arrowprops={'arrowstyle': '-|>', 'ec': 'r'})
183191

@@ -187,8 +195,8 @@ def plot_pca_2d_projection(clf, X, y, title='PCA 2-D Projection',
187195

188196
ax.legend(loc='best', shadow=False, scatterpoints=1,
189197
fontsize=text_fontsize)
190-
ax.set_xlabel('First Principal Component', fontsize=text_fontsize)
191-
ax.set_ylabel('Second Principal Component', fontsize=text_fontsize)
198+
ax.set_xlabel(f'Principal Component {dimensions[0]+1}', fontsize=text_fontsize)
199+
ax.set_ylabel(f'Principal Component {dimensions[1]+1}', fontsize=text_fontsize)
192200
ax.tick_params(labelsize=text_fontsize)
193201

194202
return ax

scikitplot/tests/test_decomposition.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import matplotlib.pyplot as plt
99

10+
import scikitplot
1011
from scikitplot.decomposition import plot_pca_component_variance
1112
from scikitplot.decomposition import plot_pca_2d_projection
1213

@@ -81,3 +82,27 @@ def test_biplot(self):
8182
clf.fit(self.X)
8283
ax = plot_pca_2d_projection(clf, self.X, self.y, biplot=True,
8384
feature_labels=load_data().feature_names)
85+
86+
def test_label_order(self):
87+
'''
88+
Plot labels should be in the same order as the classes in the provided y-array
89+
'''
90+
np.random.seed(0)
91+
clf = PCA()
92+
clf.fit(self.X)
93+
94+
# define y such that the first entry is 1
95+
y = np.copy(self.y)
96+
y[0] = 1 # load_iris is be default orderer (i.e.: 0 0 0 ... 1 1 1 ... 2 2 2)
97+
98+
# test with len(y) == X.shape[0] with multiple rows belonging to the same class
99+
ax = plot_pca_2d_projection(clf, self.X, y, cmap='Spectral')
100+
legend_labels = ax.get_legend_handles_labels()[1]
101+
self.assertListEqual(['1', '0', '2'], legend_labels)
102+
103+
# test with len(y) == #classes with each row belonging to an individual class
104+
y = list(range(len(y)))
105+
np.random.shuffle(y)
106+
ax = plot_pca_2d_projection(clf, self.X, y, cmap='Spectral')
107+
legend_labels = ax.get_legend_handles_labels()[1]
108+
self.assertListEqual([str(v) for v in y], legend_labels)

0 commit comments

Comments
 (0)