Skip to content
14 changes: 1 addition & 13 deletions simpeg_drivers/utils/nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,21 +535,9 @@ def tile_locations(
from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=n_tiles, random_state=0, n_init="auto")
cluster_size = int(np.ceil(grid_locs.shape[0] / n_tiles))
kmeans.fit(grid_locs)

if labels is not None:
cluster_id = kmeans.labels_
else:
# Redistribute cluster centers to even out the number of points
centers = kmeans.cluster_centers_
centers = (
centers.reshape(-1, 1, grid_locs.shape[1])
.repeat(cluster_size, 1)
.reshape(-1, grid_locs.shape[1])
)
distance_matrix = cdist(grid_locs, centers)
cluster_id = linear_sum_assignment(distance_matrix)[1] // cluster_size
cluster_id = kmeans.labels_
Comment thread
domfournier marked this conversation as resolved.

tiles = []
for tid in set(cluster_id):
Expand Down
58 changes: 30 additions & 28 deletions tests/locations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,34 +113,36 @@ def test_filter(tmp_path: Path):
assert np.all(filtered_data["key"] == [2, 3, 4])


def test_tile_locations(tmp_path: Path):
with Workspace.create(tmp_path / f"{__name__}.geoh5") as ws:
grid_x, grid_y = np.meshgrid(np.arange(100), np.arange(100))
choices = np.c_[grid_x.ravel(), grid_y.ravel(), np.zeros(grid_x.size)]
inds = np.random.randint(0, 10000, 1000)
pts = Points.create(
ws,
name="test-points",
vertices=choices[inds],
)
tiles = tile_locations(pts.vertices[:, :2], n_tiles=8)

values = np.zeros(pts.n_vertices)
pop = []
for ind, tile in enumerate(tiles):
values[tile] = ind
pop.append(len(tile))

pts.add_data(
{
"values": {
"values": values,
}
}
)
assert np.std(pop) / np.mean(pop) < 0.02, (
"Population of tiles are not almost equal."
)
# TODO Find a scalable algo better than linear_sum_assignment to do even split
# The tiling strategy should yield even "densities" (area x n_receivers)
# def test_tile_locations(tmp_path: Path):
# with Workspace.create(tmp_path / f"{__name__}.geoh5") as ws:
# grid_x, grid_y = np.meshgrid(np.arange(100), np.arange(100))
# choices = np.c_[grid_x.ravel(), grid_y.ravel(), np.zeros(grid_x.size)]
# inds = np.random.randint(0, 10000, 1000)
# pts = Points.create(
# ws,
# name="test-points",
# vertices=choices[inds],
# )
# tiles = tile_locations(pts.vertices[:, :2], n_tiles=8)
#
# values = np.zeros(pts.n_vertices)
# pop = []
# for ind, tile in enumerate(tiles):
# values[tile] = ind
# pop.append(len(tile))
#
# pts.add_data(
# {
# "values": {
# "values": values,
# }
# }
# )
# assert np.std(pop) / np.mean(pop) < 0.02, (
# "Population of tiles are not almost equal {}."
# )


Comment thread
domfournier marked this conversation as resolved.
Outdated
def test_tile_locations_labels(tmp_path: Path):
Expand Down
Loading