Skip to content

Commit

Permalink
Correct hex lattice and implement hex u-matrix
Browse files Browse the repository at this point in the history
Changed generate_hex_lattice() from sompy/codebook.py such as the
coordinates correspond to a regular hexagon grid (which wasn't the
case previously).

Changed sompy/visualization/umatrix.py such that if the som object has
a hexagonal lattice, then the u-matrix is plotted on a hexagonal grid.
All additional plot features (e.g. contour, blob, etc.) are compatible
with the hexagonal grid. Also added colorbar to the plot.
  • Loading branch information
dfhssilva committed Feb 5, 2021
1 parent 659a5a0 commit 6bd2ced
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 64 deletions.
8 changes: 4 additions & 4 deletions sompy/codebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ class InvalidMapsizeError(Exception):
pass

def generate_hex_lattice(n_rows, n_columns):
x_coord = []
y_coord = []
w, h = np.sqrt(3), 2
x_coord, y_coord = [], []
for i in range(n_rows):
for j in range(n_columns):
x_coord.append(i*1.5)
y_coord.append(np.sqrt(2/3)*(2*j+(1+i)%2))
x_coord.append(((j + 1) * w) - ((i % 2) * (0.5 * w)))
y_coord.append((i * h * (3 / 4)) + (0.5 * h))
coordinates = np.column_stack([x_coord, y_coord])
return coordinates

Expand Down
186 changes: 126 additions & 60 deletions sompy/visualization/umatrix.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,161 @@
from .view import MatplotView
from .mapview import MapView
from .plot_tools import plot_hex_map
from matplotlib import pyplot as plt
from pylab import imshow, contour
from matplotlib import cm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from skimage.feature import blob_log

from warnings import warn
from math import sqrt
import numpy as np
import scipy


class UMatrixView(MatplotView):
def rectxy_to_hexaxy(coord, X, Y):
"""Convert rectangular grid xy coordinates to hexagonal grid xy coordinates.
Useful for plotting additional data on top of hexagonal grid.
Args:
coord (array): array with rectangular grid xy coordinates
X (array): mapsize shaped array with hexagonal grid x coordinates
Y (array): mapsize shaped array with hexagonal grid y coordinates
Returns:
[array]: array of coord's shape with hexagonal grid xy coordinates
"""
out = np.vstack(([X[tuple(i)] for i in coord], [Y[tuple(i)] for i in coord])).T
return out

class UMatrixView(MapView):
def build_u_matrix(self, som, distance=1, row_normalized=False):
UD2 = som.calculate_map_dist()
Umatrix = np.zeros((som.codebook.nnodes, 1))
codebook = som.codebook.matrix
if row_normalized:
vector = som._normalizer.normalize_by(codebook.T, codebook.T,
method='var').T
vector = som._normalizer.normalize_by(codebook.T, codebook.T).T
else:
vector = codebook

for i in range(som.codebook.nnodes):
codebook_i = vector[i][np.newaxis, :]
neighborbor_ind = UD2[i][0:] <= distance
neighborbor_codebooks = vector[neighborbor_ind]
Umatrix[i] = scipy.spatial.distance_matrix(
codebook_i, neighborbor_codebooks).mean()
neighborbor_dists = scipy.spatial.distance_matrix(
codebook_i, neighborbor_codebooks)
Umatrix[i] = np.sum(neighborbor_dists) / (neighborbor_dists.shape[1] - 1)

return Umatrix.reshape(som.codebook.mapsize)

def show(self, som, distance2=1, row_normalized=False, show_data=True,
contooor=True, blob=False, labels=False):
umat = self.build_u_matrix(som, distance=distance2,
row_normalized=row_normalized)
def _set_contour(self, umat, ax, X=None, Y=None, hex=False):
mn = np.min(umat.flatten())
md = np.median(umat.flatten())
if hex:
ax.contour(X, Y, umat, np.linspace(mn, md, 15),
linewidths=0.7, cmap=plt.cm.get_cmap('Blues'))
else:
ax.contour(umat, np.linspace(mn, md, 15),
linewidths=0.7, cmap=plt.cm.get_cmap('Blues'))

def _set_show_data(self, X, Y, ax):
ax.scatter(X, Y, s=2, alpha=1., c='Gray',
marker='o', cmap='jet', linewidths=3, edgecolor='Gray')

def _set_labels(self, labels, X, Y, ax):
for label, x, y in zip(labels, X, Y):
ax.annotate(str(label), xy=(x, y),
horizontalalignment='center',
verticalalignment='center')

def _set_blob(self, umat, coord, ax, X=None, Y=None, hex=False):
# 'Laplacian of Gaussian'
image = 1 / umat
blobs = blob_log(image, max_sigma=5, num_sigma=4, threshold=.152)
blobs[:, 2] = blobs[:, 2] * sqrt(2)
if hex:
blobs[:, :2] = rectxy_to_hexaxy(blobs[:, :2].astype(int), X, Y)
sel_points = list()

for blob in blobs:
row, col, r = blob
c = plt.Circle((row, col), r, color='red', linewidth=2,
fill=False)
ax.add_patch(c)

dist = scipy.spatial.distance_matrix(
coord, np.array([row, col])[np.newaxis, :])
sel_point = dist <= r
ax.plot(coord[:, 0][sel_point[:, 0]],
coord[:, 1][sel_point[:, 0]], '.r')
sel_points.append(sel_point[:, 0])
if hex:
ax.set_xlim([-0.5, umat.shape[1]])
ax.set_ylim([0.6, -((umat.shape[0] - 1) * sqrt(3)/2 + 0.6)])
else:
ax.set_xlim([-0.5, umat.shape[1] - 0.5])
ax.set_ylim([-0.5, umat.shape[0] - 0.5])


def show(self, som, distance=1, row_normalized=False, show_data=False,
contour=False, blob=False, labels=False):
# Setting figure parameters
org_w = self.width
org_h = self.height
(self.width, self.height, indtoshow, no_row_in_plot, no_col_in_plot,
axis_num) = self._calculate_figure_params(som, 1, 1)
self.width /= (self.width/org_w) if self.width > self.height else (self.height/org_h)
self.height /= (self.width / org_w) if self.width > self.height else (self.height / org_h)
self.prepare()
plt.rc('figure', titlesize=self.text_size)
colormap = plt.get_cmap('RdYlBu_r')

# Setting figure data
if som.codebook.lattice == "hexa" and distance < sqrt(3):
warn("For hexagonal lattice, distance < sqrt(3) produces a null U-matrix.")
umat = self.build_u_matrix(som, distance=distance, row_normalized=row_normalized)
msz = som.codebook.mapsize
proj = som.project_data(som.data_raw)
coord = som.bmu_ind_to_xy(proj)

self._fig, ax = plt.subplots(1, 1)
imshow(umat, cmap=plt.cm.get_cmap('RdYlBu_r'), alpha=1)
coord = som.bmu_ind_to_xy(proj)[:, :2]
sel_points = list()

if contooor:
mn = np.min(umat.flatten())
mx = np.max(umat.flatten())
std = np.std(umat.flatten())
md = np.median(umat.flatten())
mx = md + 0*std
contour(umat, np.linspace(mn, mx, 15), linewidths=0.7,
cmap=plt.cm.get_cmap('Blues'))
if som.codebook.lattice == "rect":
ax = self._fig.add_subplot(111)
ax.imshow(umat, cmap=colormap, alpha=1)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(cm.ScalarMappable(cmap=colormap), cax=cax, orientation='vertical')
coord = np.flip(coord, axis=1)

if contour:
self._set_contour(umat, ax, hex=False)

if blob:
self._set_blob(umat, coord, ax, hex=False)
elif som.codebook.lattice == "hexa":
ax, cents = plot_hex_map(umat, colormap=colormap, fig=self._fig, colorbar=True)
X = np.flip(np.array(cents)[:, 0].reshape(msz[0], msz[1]), axis=1)
Y = np.flip(np.array(cents)[:, 1].reshape(msz[0], msz[1]), axis=1)
coord = rectxy_to_hexaxy(coord, X, Y)

if contour:
self._set_contour(umat, ax, X, Y, hex=True)

if blob:
self._set_blob(umat, coord, ax, X, Y, hex=True)
else:
raise ValueError(
'lattice argument of SOM object should be either "rect" or "hexa".')

if show_data:
plt.scatter(coord[:, 1], coord[:, 0], s=2, alpha=1., c='Gray',
marker='o', cmap='jet', linewidths=3, edgecolor='Gray')
plt.axis('off')
self._set_show_data(coord[:, 0], coord[:, 1], ax)

if labels:
if labels is True:
labels = som.build_data_labels()
for label, x, y in zip(labels, coord[:, 1], coord[:, 0]):
plt.annotate(str(label), xy=(x, y),
horizontalalignment='center',
verticalalignment='center')

labels = som.build_data_labels()
self._set_labels(labels, coord[:, 0], coord[:, 1], ax)

ratio = float(msz[0])/(msz[0]+msz[1])
self._fig.set_size_inches((1-ratio)*15, ratio*15)
plt.tight_layout()
plt.subplots_adjust(hspace=.00, wspace=.000)
sel_points = list()

if blob:
from skimage.color import rgb2gray
from skimage.feature import blob_log

image = 1 / umat
rgb2gray(image)

# 'Laplacian of Gaussian'
blobs = blob_log(image, max_sigma=5, num_sigma=4, threshold=.152)
blobs[:, 2] = blobs[:, 2] * sqrt(2)
imshow(umat, cmap=plt.cm.get_cmap('RdYlBu_r'), alpha=1)
sel_points = list()

for blob in blobs:
row, col, r = blob
c = plt.Circle((col, row), r, color='red', linewidth=2,
fill=False)
ax.add_patch(c)
dist = scipy.spatial.distance_matrix(
coord[:, :2], np.array([row, col])[np.newaxis, :])
sel_point = dist <= r
plt.plot(coord[:, 1][sel_point[:, 0]],
coord[:, 0][sel_point[:, 0]], '.r')
sel_points.append(sel_point[:, 0])
plt.subplots_adjust(top=0.90, hspace=.00, wspace=.000)

plt.show()
return sel_points, umat

0 comments on commit 6bd2ced

Please sign in to comment.