Skip to content

Commit cd2edc9

Browse files
committed
add filtering back to plots
1 parent a0d07df commit cd2edc9

File tree

7 files changed

+132
-87
lines changed

7 files changed

+132
-87
lines changed

packages/app/src/components/data/plot/plots/box-plot.tsx

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import { Group } from "@visx/group";
33
import { Plot, PlotDefinition } from "@common/db/schema/plot";
44
import { DataPoint } from "./common";
55
import { usePlotTheme } from "@/hooks/use-plot-theme";
6-
import { getColorScale, MARGIN } from "./common";
6+
import { MARGIN } from "./common";
77
import { CommonPlotProps } from "./common";
88
import { BaseCartesianPlot } from "./base-plot";
99
import { PlotWrapper } from "./plot-wrapper";
@@ -128,13 +128,15 @@ function BaseBoxPlot({
128128
width,
129129
height,
130130
data,
131+
colorMapping,
131132
plot,
132133
showTooltip,
133134
hideTooltip,
134135
}: {
135136
width: number;
136137
height: number;
137138
data: DataPoint[];
139+
colorMapping?: Map<string, string>;
138140
plot: Props["plot"];
139141
showTooltip: (_args: {
140142
tooltipLeft: number;
@@ -145,7 +147,6 @@ function BaseBoxPlot({
145147
}) {
146148
const theme = usePlotTheme();
147149
const { xAxis, yAxis } = plot.definition;
148-
const colorScale = getColorScale(data);
149150

150151
// Group data by x-axis values
151152
const groupedData = new Map<string | number, number[]>();
@@ -225,7 +226,8 @@ function BaseBoxPlot({
225226
const groupPoints = data.filter((point) => point.x === x);
226227
const firstPoint = groupPoints[0];
227228
const color = firstPoint?.originalColor
228-
? colorScale(String(firstPoint.originalColor))
229+
? colorMapping?.get(String(firstPoint.originalColor)) ||
230+
theme.primary
229231
: theme.primary;
230232

231233
// Debug each box plot
@@ -350,6 +352,7 @@ export function BoxPlot(props: Props) {
350352
{...props}
351353
renderContent={({
352354
processedData,
355+
colorMapping,
353356
width,
354357
height,
355358
showTooltip,
@@ -358,6 +361,7 @@ export function BoxPlot(props: Props) {
358361
<BaseBoxPlot
359362
plot={props.plot}
360363
data={processedData}
364+
colorMapping={colorMapping}
361365
width={width}
362366
height={height}
363367
showTooltip={showTooltip}

packages/app/src/components/data/plot/plots/common.tsx

Lines changed: 32 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import {
33
type PlotDefinition,
44
} from "@common/db/schema/plot";
55

6-
import { scaleLinear, scaleOrdinal } from "@visx/scale";
6+
import { scaleLinear } from "@visx/scale";
77
import { extent } from "@visx/vendor/d3-array";
88
import { usePlotTheme } from "@/hooks/use-plot-theme";
99
import {
@@ -37,6 +37,7 @@ export const OBSERVABLE_COLORS = [
3737
export interface CommonPlotProps {
3838
plot: DbPlot & { definition: PlotDefinition };
3939
data: Record<string, unknown>[];
40+
allData?: Record<string, unknown>[];
4041
className?: string;
4142
height?: number;
4243
width?: number | string;
@@ -250,38 +251,53 @@ export function assignAbbreviationColorsAndSymbols<T extends DataPoint>(
250251

251252
let defaultColorIndex = 0;
252253

253-
points.forEach((point) => {
254-
if (colorColumn && point.color && !processedColors.has(point.color)) {
255-
processedColors.add(point.color);
254+
// Get unique colors in sorted order for consistent assignment
255+
const uniqueColors = colorColumn
256+
? Array.from(new Set(points.map((p) => p.color).filter(Boolean))).sort()
257+
: [];
258+
259+
uniqueColors.forEach((color) => {
260+
if (colorColumn && color && !processedColors.has(color)) {
261+
processedColors.add(color);
262+
263+
// Find the point with this color to check for abbreviation color
264+
const pointWithColor = points.find((p) => p.color === color);
256265

257266
// Check if this point has abbreviation color information
258-
if (point.colorObject?.color) {
259-
colorMap.set(point.color, point.colorObject.color);
267+
if (pointWithColor?.colorObject?.color) {
268+
colorMap.set(color, pointWithColor.colorObject.color);
260269
} else {
261270
// Use default color if no abbreviation color is set
262271
colorMap.set(
263-
point.color,
272+
color,
264273
OBSERVABLE_COLORS[defaultColorIndex % OBSERVABLE_COLORS.length],
265274
);
266275
defaultColorIndex++;
267276
}
268277
}
269278
});
270279

271-
// Get unique values for symbol and line columns
280+
// Get unique values for symbol and line columns (sorted for consistent ordering)
272281
const uniqueSymbols = symbolColumn
273-
? Array.from(new Set(points.map((p) => p.symbol).filter(Boolean)))
282+
? Array.from(new Set(points.map((p) => p.symbol).filter(Boolean))).sort()
274283
: [];
275284
const uniqueLines = lineColumn
276-
? Array.from(new Set(points.map((p) => p.line).filter(Boolean)))
285+
? Array.from(new Set(points.map((p) => p.line).filter(Boolean))).sort()
277286
: [];
278287

279-
const symbolMap = new Map(
280-
uniqueSymbols.map((symbol, index) => [
281-
symbol,
282-
AVAILABLE_SYMBOLS[index % AVAILABLE_SYMBOLS.length],
283-
]),
284-
);
288+
// Create symbol mappings using the same approach as colors
289+
const symbolMap = new Map<string, string>();
290+
let defaultSymbolIndex = 0;
291+
292+
uniqueSymbols.forEach((symbol) => {
293+
if (symbolColumn && symbol) {
294+
symbolMap.set(
295+
symbol,
296+
AVAILABLE_SYMBOLS[defaultSymbolIndex % AVAILABLE_SYMBOLS.length],
297+
);
298+
defaultSymbolIndex++;
299+
}
300+
});
285301

286302
const lineMap = new Map(uniqueLines.map((line, index) => [line, index]));
287303

@@ -425,48 +441,6 @@ export const VISX_SYMBOLS = {
425441

426442
export type SymbolType = keyof typeof VISX_SYMBOLS;
427443

428-
export function getColorScale(data: DataPoint[]) {
429-
// First, collect all unique color values and their custom colors
430-
const colorMap = new Map<string, string>();
431-
const processedColors = new Set<string>();
432-
let defaultColorIndex = 0;
433-
434-
data.forEach((point) => {
435-
if (point.originalColor && !processedColors.has(point.originalColor)) {
436-
processedColors.add(point.originalColor);
437-
438-
// If the point has a custom color from abbreviation, use it
439-
if (point.colorObject?.color) {
440-
colorMap.set(point.originalColor, point.colorObject.color);
441-
} else {
442-
// Otherwise, use the next color from our default palette
443-
colorMap.set(
444-
point.originalColor,
445-
OBSERVABLE_COLORS[defaultColorIndex % OBSERVABLE_COLORS.length],
446-
);
447-
defaultColorIndex++;
448-
}
449-
}
450-
});
451-
452-
// Create a scale that uses our color map
453-
return scaleOrdinal({
454-
domain: Array.from(colorMap.keys()),
455-
range: Array.from(colorMap.values()),
456-
});
457-
}
458-
459-
export function getSymbolScale(data: DataPoint[]) {
460-
const uniqueSymbols = Array.from(
461-
new Set(data.map((d) => d.originalSymbol).filter(Boolean)),
462-
);
463-
464-
return scaleOrdinal({
465-
domain: uniqueSymbols.filter((s): s is string => typeof s === "string"),
466-
range: Object.keys(VISX_SYMBOLS),
467-
});
468-
}
469-
470444
export function PlotContainer({
471445
children,
472446
width,

packages/app/src/components/data/plot/plots/histogram-plot.tsx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { Bar } from "@visx/shape";
22
import { Plot, PlotDefinition } from "@common/db/schema/plot";
33
import { DataPoint } from "./common";
44
import { usePlotTheme } from "@/hooks/use-plot-theme";
5-
import { getColorScale, MARGIN } from "./common";
5+
import { MARGIN } from "./common";
66
import { CommonPlotProps } from "./common";
77
import { BaseCartesianPlot } from "./base-plot";
88
import { PlotWrapper } from "./plot-wrapper";
@@ -64,13 +64,15 @@ function BaseHistogramPlot({
6464
width,
6565
height,
6666
data,
67+
colorMapping,
6768
plot,
6869
showTooltip,
6970
hideTooltip,
7071
}: {
7172
width: number;
7273
height: number;
7374
data: DataPoint[];
75+
colorMapping?: Map<string, string>;
7476
plot: Props["plot"];
7577
showTooltip: (_args: {
7678
tooltipLeft: number;
@@ -81,7 +83,6 @@ function BaseHistogramPlot({
8183
}) {
8284
const theme = usePlotTheme();
8385
const { xAxis, grouping } = plot.definition;
84-
const colorScale = getColorScale(data);
8586

8687
// Create base scales
8788
const xValues = data.map((d) => Number(d.x)).filter((x) => !isNaN(x));
@@ -200,7 +201,7 @@ function BaseHistogramPlot({
200201
fill={
201202
group.color === theme.primary
202203
? theme.primary
203-
: colorScale(group.color)
204+
: colorMapping?.get(group.color) || theme.primary
204205
}
205206
onMouseEnter={(event) => {
206207
const coords = localPoint(
@@ -237,6 +238,7 @@ export function HistogramPlot(props: Props) {
237238
{...props}
238239
renderContent={({
239240
processedData,
241+
colorMapping,
240242
width,
241243
height,
242244
showTooltip,
@@ -245,6 +247,7 @@ export function HistogramPlot(props: Props) {
245247
<BaseHistogramPlot
246248
plot={props.plot}
247249
data={processedData}
250+
colorMapping={colorMapping}
248251
width={width}
249252
height={height}
250253
showTooltip={showTooltip}

packages/app/src/components/data/plot/plots/line-plot.tsx

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { LinePath } from "@visx/shape";
22
import { Plot, PlotDefinition } from "@common/db/schema/plot";
33
import { DataPoint } from "./common";
44
import { usePlotTheme } from "@/hooks/use-plot-theme";
5-
import { getScaleConfig, getColorScale } from "./common";
5+
import { getScaleConfig } from "./common";
66
import { CommonPlotProps } from "./common";
77
import { BaseCartesianPlot } from "./base-plot";
88
import { PlotWrapper } from "./plot-wrapper";
@@ -32,13 +32,15 @@ function BaseLinePlot({
3232
width,
3333
height,
3434
data,
35+
colorMapping,
3536
plot,
3637
showTooltip,
3738
hideTooltip,
3839
}: {
3940
width: number;
4041
height: number;
4142
data: DataPoint[];
43+
colorMapping?: Map<string, string>;
4244
plot: Props["plot"];
4345
showTooltip: (_args: {
4446
tooltipLeft: number;
@@ -51,7 +53,6 @@ function BaseLinePlot({
5153
const { grouping } = plot.definition;
5254
const scaleConfig = getScaleConfig(plot.definition, data, width, height);
5355
const { xAxis, yAxis } = scaleConfig;
54-
const colorScale = getColorScale(data);
5556

5657
// Group data by color if color grouping is enabled
5758
const dataByColor: LineGroup[] = grouping?.color?.column
@@ -82,7 +83,7 @@ function BaseLinePlot({
8283
data={points}
8384
x={(d) => xAxis.scale(Number(d.x)) ?? 0}
8485
y={(d) => yAxis.scale(Number(d.y)) ?? 0}
85-
stroke={colorScale(color)}
86+
stroke={colorMapping?.get(color) || theme.primary}
8687
strokeWidth={2}
8788
curve={
8889
plot.definition.curve === "natural" ? curveNatural : undefined
@@ -112,7 +113,8 @@ function BaseLinePlot({
112113
r={3}
113114
fill={
114115
point.originalColor
115-
? colorScale(String(point.originalColor))
116+
? colorMapping?.get(String(point.originalColor)) ||
117+
theme.primary
116118
: theme.primary
117119
}
118120
onMouseOver={(event) => {
@@ -140,6 +142,7 @@ export function LinePlot(props: Props) {
140142
{...props}
141143
renderContent={({
142144
processedData,
145+
colorMapping,
143146
width,
144147
height,
145148
showTooltip,
@@ -148,6 +151,7 @@ export function LinePlot(props: Props) {
148151
<BaseLinePlot
149152
plot={props.plot}
150153
data={processedData}
154+
colorMapping={colorMapping}
151155
width={width}
152156
height={height}
153157
showTooltip={showTooltip}

packages/app/src/components/data/plot/plots/plot-wrapper.tsx

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
assignAbbreviationColorsAndSymbols,
99
PlotTitle,
1010
getUniqueValues,
11+
useFilteredData,
1112
DataPoint,
1213
OBSERVABLE_COLORS,
1314
} from "./common";
@@ -24,6 +25,8 @@ export interface PlotWrapperProps extends CommonPlotProps {
2425
};
2526
renderContent: (_props: {
2627
processedData: DataPoint[];
28+
colorMapping?: Map<string, string>;
29+
symbolMapping?: Map<string, string>;
2730
width: number;
2831
height: number;
2932
showTooltip: (_args: {
@@ -84,22 +87,31 @@ export function PlotWrapper({
8487
xAxisType,
8588
});
8689

87-
// Assign colors and symbols
90+
// Assign colors and symbols to all points first (for stable mappings)
8891
const processedPoints = assignAbbreviationColorsAndSymbols(
8992
points as DataPoint[],
9093
colorColumn,
9194
symbolColumn,
9295
);
9396

94-
// Get unique values for legends
97+
// Apply filtering to the processed points
98+
const {
99+
filteredPoints,
100+
filteredColors,
101+
filteredSymbols,
102+
handleColorClick,
103+
handleSymbolClick,
104+
} = useFilteredData(processedPoints);
105+
106+
// Get unique values for legends (from all data, not filtered)
95107
const uniqueColors = colorColumn
96108
? getUniqueValues(processedPoints, "color")
97109
: [];
98110
const uniqueSymbols = symbolColumn
99111
? getUniqueValues(processedPoints, "symbol")
100112
: [];
101113

102-
// Get color and symbol mappings for legend
114+
// Get color and symbol mappings for legend (based on all data, not filtered)
103115
const colorMapping = colorColumn
104116
? new Map(
105117
processedPoints
@@ -249,7 +261,9 @@ export function PlotWrapper({
249261
transform={`translate(${(effectiveWidth - actualPlotWidth) / 2}, 0)`}
250262
>
251263
{renderContent({
252-
processedData: processedPoints,
264+
processedData: filteredPoints,
265+
colorMapping,
266+
symbolMapping,
253267
width: actualPlotWidth,
254268
height: plotHeight,
255269
showTooltip,
@@ -273,6 +287,10 @@ export function PlotWrapper({
273287
colorMapping={colorMapping}
274288
symbolMapping={symbolMapping}
275289
width={naturalLegendWidth}
290+
onColorClick={handleColorClick}
291+
onSymbolClick={handleSymbolClick}
292+
filteredColors={filteredColors}
293+
filteredSymbols={filteredSymbols}
276294
/>
277295
</Group>
278296
</Group>

0 commit comments

Comments
 (0)