|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | from dataclasses import dataclass |
2 | 4 | from functools import partial |
3 | 5 | from typing import Any |
4 | 6 |
|
5 | 7 | import ipywidgets |
6 | 8 | import numpy as np |
7 | 9 | import xarray as xr |
8 | | -from lonboard import Map |
| 10 | +from lonboard import BaseLayer, Map |
9 | 11 |
|
10 | 12 |
|
11 | 13 | def on_slider_change(change, container): |
@@ -39,7 +41,115 @@ def render(self): |
39 | 41 | # add any additional control widgets here |
40 | 42 | control_box = ipywidgets.HBox([self.dimension_sliders]) |
41 | 43 |
|
42 | | - return ipywidgets.VBox([self.map, control_box]) |
| 44 | + return MapWithSliders( |
| 45 | + [self.map, control_box], layout=ipywidgets.Layout(width="100%") |
| 46 | + ) |
| 47 | + |
| 48 | + |
| 49 | +def extract_maps(obj: MapGrid | MapWithSliders | Map): |
| 50 | + if isinstance(obj, Map): |
| 51 | + return obj |
| 52 | + |
| 53 | + return getattr(obj, "maps", (obj.map,)) |
| 54 | + |
| 55 | + |
| 56 | +class MapGrid(ipywidgets.GridBox): |
| 57 | + def __init__( |
| 58 | + self, |
| 59 | + maps: MapWithSliders | Map = None, |
| 60 | + n_columns: int = 2, |
| 61 | + synchronize: bool = False, |
| 62 | + ): |
| 63 | + self.n_columns = n_columns |
| 64 | + self.synchronize = synchronize |
| 65 | + |
| 66 | + column_width = 100 // n_columns |
| 67 | + layout = ipywidgets.Layout( |
| 68 | + width="100%", grid_template_columns=f"repeat({n_columns}, {column_width}%)" |
| 69 | + ) |
| 70 | + |
| 71 | + if maps is None: |
| 72 | + maps = [] |
| 73 | + |
| 74 | + if synchronize and maps: |
| 75 | + all_maps = [getattr(m, "map", m) for m in maps] |
| 76 | + |
| 77 | + first = all_maps[0] |
| 78 | + for second in all_maps[1:]: |
| 79 | + ipywidgets.jslink((first, "view_state"), (second, "view_state")) |
| 80 | + |
| 81 | + super().__init__(maps, layout=layout) |
| 82 | + |
| 83 | + def _replace_maps(self, maps): |
| 84 | + return type(self)(maps, n_columns=self.n_columns, synchronize=self.synchronize) |
| 85 | + |
| 86 | + def add_map(self, map_: MapWithSliders | Map): |
| 87 | + return self._replace_maps(self.maps + (map_,)) |
| 88 | + |
| 89 | + @property |
| 90 | + def maps(self): |
| 91 | + return self.children |
| 92 | + |
| 93 | + def __or__(self, other: MapGrid | MapWithSliders | Map): |
| 94 | + other_maps = extract_maps(other) |
| 95 | + |
| 96 | + return self._replace_maps(self.maps + other_maps) |
| 97 | + |
| 98 | + def __ror__(self, other: MapWithSliders | Map): |
| 99 | + other_maps = extract_maps(other) |
| 100 | + |
| 101 | + return self._replace_maps(self.maps + other_maps) |
| 102 | + |
| 103 | + |
| 104 | +class MapWithSliders(ipywidgets.VBox): |
| 105 | + def change_layout(self, layout): |
| 106 | + return type(self)(self.children, layout=layout) |
| 107 | + |
| 108 | + @property |
| 109 | + def sliders(self) -> list: |
| 110 | + return list(self.children[1:]) if len(self.children) > 1 else [] |
| 111 | + |
| 112 | + @property |
| 113 | + def map(self) -> Map: |
| 114 | + return self.children[0] |
| 115 | + |
| 116 | + @property |
| 117 | + def layers(self) -> list[BaseLayer]: |
| 118 | + return self.map.layers |
| 119 | + |
| 120 | + def __or__(self, other: MapWithSliders | Map): |
| 121 | + [other_map] = extract_maps(other) |
| 122 | + |
| 123 | + return MapGrid([self, other], synchronize=True) |
| 124 | + |
| 125 | + def _merge(self, layers, sliders): |
| 126 | + all_layers = list(self.map.layers) + list(layers) |
| 127 | + new_map = Map(all_layers) |
| 128 | + |
| 129 | + slider_widgets = [] |
| 130 | + if self.sliders: |
| 131 | + slider_widgets.extend(self.sliders) |
| 132 | + if sliders: |
| 133 | + slider_widgets.extend(sliders) |
| 134 | + |
| 135 | + widgets = [new_map] |
| 136 | + if slider_widgets: |
| 137 | + widgets.append(ipywidgets.HBox(slider_widgets)) |
| 138 | + |
| 139 | + return type(self)(widgets, layout=self.layout) |
| 140 | + |
| 141 | + def add_layer(self, layer: BaseLayer): |
| 142 | + self.map.add_layer(layer) |
| 143 | + |
| 144 | + def __and__(self, other: MapWithSliders | Map | BaseLayer): |
| 145 | + if isinstance(other, BaseLayer): |
| 146 | + layers = [other] |
| 147 | + sliders = [] |
| 148 | + else: |
| 149 | + layers = other.layers |
| 150 | + sliders = getattr(other, "sliders", []) |
| 151 | + |
| 152 | + return self._merge(layers, sliders) |
43 | 153 |
|
44 | 154 |
|
45 | 155 | def create_arrow_table(polygons, arr, coords=None): |
|
0 commit comments