Skip to content

Commit 06c55bd

Browse files
committed
ruff
1 parent 9dc49fe commit 06c55bd

File tree

3 files changed

+77
-57
lines changed

3 files changed

+77
-57
lines changed

allocator/core/algorithms.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
try:
1414
from sklearn.cluster import KMeans
1515
from sklearn.utils.validation import check_array
16+
1617
HAS_SKLEARN = True
1718
except ImportError:
1819
HAS_SKLEARN = False
@@ -70,7 +71,14 @@ class CustomKMeans(KMeans if HAS_SKLEARN else object):
7071
including haversine, OSRM, and Google Maps API distances.
7172
"""
7273

73-
def __init__(self, n_clusters=8, distance_method="euclidean", max_iter=300, random_state=None, **distance_kwargs):
74+
def __init__(
75+
self,
76+
n_clusters=8,
77+
distance_method="euclidean",
78+
max_iter=300,
79+
random_state=None,
80+
**distance_kwargs,
81+
):
7482
if HAS_SKLEARN:
7583
# Initialize sklearn KMeans with all parameters
7684
super().__init__(n_clusters=n_clusters, max_iter=max_iter, random_state=random_state)
@@ -81,12 +89,14 @@ def __init__(self, n_clusters=8, distance_method="euclidean", max_iter=300, rand
8189
def _transform(self, X):
8290
"""Override sklearn's distance calculation to use custom metrics."""
8391
if not HAS_SKLEARN:
84-
raise ImportError("sklearn is required for CustomKMeans. Install with: pip install 'allocator[algorithms]'")
92+
raise ImportError(
93+
"sklearn is required for CustomKMeans. Install with: pip install 'allocator[algorithms]'"
94+
)
8595

8696
# Use our custom distance factory instead of sklearn's euclidean
87-
distances = get_distance_matrix(X, self.cluster_centers_,
88-
method=self.distance_method,
89-
**self.distance_kwargs)
97+
distances = get_distance_matrix(
98+
X, self.cluster_centers_, method=self.distance_method, **self.distance_kwargs
99+
)
90100
return distances
91101

92102
def _update_centroids(self, X, labels):
@@ -111,17 +121,17 @@ def fit(self, X, y=None, sample_weight=None):
111121
# Fallback to original implementation if sklearn not available
112122
return self._fit_custom_implementation(X)
113123

114-
X = check_array(X, accept_sparse='csr', dtype=[np.float64, np.float32])
124+
X = check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
115125

116126
# Initialize using sklearn's initialization logic
117127
super().fit(X)
118128

119129
# Now run our custom iterations
120130
for iteration in range(self.max_iter):
121131
# Calculate distances using custom metric
122-
distances = get_distance_matrix(X, self.cluster_centers_,
123-
method=self.distance_method,
124-
**self.distance_kwargs)
132+
distances = get_distance_matrix(
133+
X, self.cluster_centers_, method=self.distance_method, **self.distance_kwargs
134+
)
125135

126136
# Assign points to nearest centroids
127137
labels = np.argmin(distances, axis=1)
@@ -145,9 +155,9 @@ def fit(self, X, y=None, sample_weight=None):
145155

146156
def _fit_custom_implementation(self, X):
147157
"""Fallback to original implementation when sklearn is not available."""
148-
result = _kmeans_cluster_original(X, self.n_clusters,
149-
distance_method=self.distance_method,
150-
**self.distance_kwargs)
158+
result = _kmeans_cluster_original(
159+
X, self.n_clusters, distance_method=self.distance_method, **self.distance_kwargs
160+
)
151161
self.cluster_centers_ = result["centroids"]
152162
self.labels_ = result["labels"]
153163
self.n_iter_ = result["iterations"]
@@ -184,7 +194,7 @@ def kmeans_cluster(
184194
distance_method=distance_method,
185195
max_iter=max_iter,
186196
random_state=random_state,
187-
**distance_kwargs
197+
**distance_kwargs,
188198
)
189199
kmeans.fit(X)
190200

@@ -196,7 +206,9 @@ def kmeans_cluster(
196206
}
197207

198208
# Fall back to original implementation
199-
return _kmeans_cluster_original(X, n_clusters, distance_method, max_iter, random_state, **distance_kwargs)
209+
return _kmeans_cluster_original(
210+
X, n_clusters, distance_method, max_iter, random_state, **distance_kwargs
211+
)
200212

201213

202214
def _kmeans_cluster_original(

allocator/viz/plotting.py

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
try:
1313
import folium
1414
from folium import plugins
15+
1516
HAS_FOLIUM = True
1617
except ImportError:
1718
HAS_FOLIUM = False
1819

1920
try:
2021
import polyline
22+
2123
HAS_POLYLINE = True
2224
except ImportError:
2325
HAS_POLYLINE = False
@@ -277,18 +279,26 @@ def plot_clusters_interactive(
277279
zoom_start = 8
278280

279281
# Create base map
280-
m = folium.Map(
281-
location=[center_lat, center_lon],
282-
zoom_start=zoom_start,
283-
tiles='OpenStreetMap'
284-
)
282+
m = folium.Map(location=[center_lat, center_lon], zoom_start=zoom_start, tiles="OpenStreetMap")
285283

286284
# Color palette for clusters
287285
n_clusters = len(np.unique(labels))
288286
color_palette = [
289-
'#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57',
290-
'#FF9FF3', '#54A0FF', '#5F27CD', '#00D2D3', '#FF9F43',
291-
'#A55EEA', '#26DE81', '#FD79A8', '#FDCB6E', '#6C5CE7'
287+
"#FF6B6B",
288+
"#4ECDC4",
289+
"#45B7D1",
290+
"#96CEB4",
291+
"#FECA57",
292+
"#FF9FF3",
293+
"#54A0FF",
294+
"#5F27CD",
295+
"#00D2D3",
296+
"#FF9F43",
297+
"#A55EEA",
298+
"#26DE81",
299+
"#FD79A8",
300+
"#FDCB6E",
301+
"#6C5CE7",
292302
]
293303

294304
# Extend palette if needed
@@ -304,10 +314,10 @@ def plot_clusters_interactive(
304314
location=[lat, lon],
305315
radius=6,
306316
popup=f"Point {i}<br>Cluster: {cluster_id}<br>Coords: ({lon:.4f}, {lat:.4f})",
307-
color='white',
317+
color="white",
308318
weight=1,
309319
fillColor=color,
310-
fillOpacity=0.8
320+
fillOpacity=0.8,
311321
).add_to(m)
312322

313323
# Add centroids if provided
@@ -317,8 +327,8 @@ def plot_clusters_interactive(
317327
folium.Marker(
318328
location=[lat, lon],
319329
popup=f"Centroid {k}<br>Coords: ({lon:.4f}, {lat:.4f})",
320-
icon=folium.Icon(color='black', icon='star', prefix='fa'),
321-
tooltip=f"Cluster {k} Centroid"
330+
icon=folium.Icon(color="black", icon="star", prefix="fa"),
331+
tooltip=f"Cluster {k} Centroid",
322332
).add_to(m)
323333

324334
# Add legend
@@ -337,7 +347,7 @@ def plot_clusters_interactive(
337347
legend_html += f'<p><span style="color:{color};">●</span> Cluster {k}</p>'
338348

339349
if n_clusters > 8:
340-
legend_html += f'<p>... and {n_clusters - 8} more</p>'
350+
legend_html += f"<p>... and {n_clusters - 8} more</p>"
341351

342352
legend_html += "</div>"
343353
m.get_root().html.add_child(folium.Element(legend_html))
@@ -403,11 +413,7 @@ def plot_route_interactive(
403413
zoom_start = 8
404414

405415
# Create base map
406-
m = folium.Map(
407-
location=[center_lat, center_lon],
408-
zoom_start=zoom_start,
409-
tiles='OpenStreetMap'
410-
)
416+
m = folium.Map(location=[center_lat, center_lon], zoom_start=zoom_start, tiles="OpenStreetMap")
411417

412418
# Add route line
413419
if route_geometry and HAS_POLYLINE:
@@ -418,11 +424,7 @@ def plot_route_interactive(
418424
route_coords = [[lat, lon] for lat, lon in decoded_coords]
419425

420426
folium.PolyLine(
421-
locations=route_coords,
422-
color='blue',
423-
weight=4,
424-
opacity=0.8,
425-
popup="Optimized Route"
427+
locations=route_coords, color="blue", weight=4, opacity=0.8, popup="Optimized Route"
426428
).add_to(m)
427429
except Exception:
428430
# Fall back to straight line connections if decoding fails
@@ -437,23 +439,23 @@ def plot_route_interactive(
437439

438440
# Color-code start, end, and intermediate points
439441
if i == 0:
440-
icon_color = 'green'
441-
icon_symbol = 'play'
442-
label = 'Start'
442+
icon_color = "green"
443+
icon_symbol = "play"
444+
label = "Start"
443445
elif i == len(route_order) - 1:
444-
icon_color = 'red'
445-
icon_symbol = 'stop'
446-
label = 'End'
446+
icon_color = "red"
447+
icon_symbol = "stop"
448+
label = "End"
447449
else:
448-
icon_color = 'blue'
449-
icon_symbol = f'{i}'
450-
label = f'Stop {i}'
450+
icon_color = "blue"
451+
icon_symbol = f"{i}"
452+
label = f"Stop {i}"
451453

452454
folium.Marker(
453455
location=[lat, lon],
454456
popup=f"{label}<br>Point {point_idx}<br>Coords: ({lon:.4f}, {lat:.4f})",
455-
icon=folium.Icon(color=icon_color, icon=icon_symbol, prefix='fa'),
456-
tooltip=f"{label} (Point {point_idx})"
457+
icon=folium.Icon(color=icon_color, icon=icon_symbol, prefix="fa"),
458+
tooltip=f"{label} (Point {point_idx})",
457459
).add_to(m)
458460

459461
# Calculate route statistics
@@ -467,7 +469,7 @@ def plot_route_interactive(
467469
font-size:14px; padding: 10px">
468470
<h4>{title}</h4>
469471
<p><strong>Points:</strong> {len(route_points)}</p>
470-
<p><strong>Route Type:</strong> {'API Route' if route_geometry else 'Direct Lines'}</p>
472+
<p><strong>Route Type:</strong> {"API Route" if route_geometry else "Direct Lines"}</p>
471473
<p><strong>Est. Distance:</strong> {total_distance:.2f} km</p>
472474
<hr>
473475
<p><span style="color:green;">●</span> Start Point</p>
@@ -487,7 +489,9 @@ def plot_route_interactive(
487489
return m
488490

489491

490-
def _add_straight_line_route(m: folium.Map, route_points: np.ndarray, route_order: list[int]) -> None:
492+
def _add_straight_line_route(
493+
m: folium.Map, route_points: np.ndarray, route_order: list[int]
494+
) -> None:
491495
"""Add straight line connections between route points."""
492496
ordered_points = route_points[route_order]
493497

@@ -497,10 +501,10 @@ def _add_straight_line_route(m: folium.Map, route_points: np.ndarray, route_orde
497501

498502
folium.PolyLine(
499503
locations=route_coords,
500-
color='blue',
504+
color="blue",
501505
weight=3,
502506
opacity=0.7,
503-
popup="Direct Route (Straight Lines)"
507+
popup="Direct Route (Straight Lines)",
504508
).add_to(m)
505509

506510

@@ -516,7 +520,7 @@ def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> fl
516520
# Haversine formula
517521
dlat = lat2 - lat1
518522
dlon = lon2 - lon1
519-
a = sin(dlat/2)**2 + cos(lat1) * cos(lat2) * sin(dlon/2)**2
523+
a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
520524
c = 2 * asin(sqrt(a))
521525

522526
# Earth radius in kilometers
@@ -532,8 +536,10 @@ def haversine_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> fl
532536
next_point = ordered_points[(i + 1) % len(ordered_points)] # Return to start
533537

534538
distance = haversine_distance(
535-
current_point[1], current_point[0], # lat1, lon1
536-
next_point[1], next_point[0] # lat2, lon2
539+
current_point[1],
540+
current_point[0], # lat1, lon1
541+
next_point[1],
542+
next_point[0], # lat2, lon2
537543
)
538544
total_distance += distance
539545

tests/api/test_cluster_api.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ def test_kmeans_reproducibility(self):
9191
labels_identical = np.array_equal(result1.labels, result2.labels)
9292
labels_flipped = np.array_equal(result1.labels, 1 - result2.labels)
9393

94-
self.assertTrue(labels_identical or labels_flipped,
95-
"Clustering results should be reproducible (labels may be flipped)")
94+
self.assertTrue(
95+
labels_identical or labels_flipped,
96+
"Clustering results should be reproducible (labels may be flipped)",
97+
)
9698

9799
# Centroids should be the same (possibly in different order)
98100
centroids1_sorted = np.sort(result1.centroids.flatten())

0 commit comments

Comments
 (0)