From e91332e98e81a568056be0a71e36f1431ade9bcb Mon Sep 17 00:00:00 2001 From: Dieter Blomme Date: Sun, 26 Oct 2025 02:52:36 +0100 Subject: [PATCH 1/5] Add filtering and sorting for custom fields --- client/src/components/column.tsx | 47 +++- client/src/components/dataProvider.ts | 21 ++ client/src/utils/filtering.ts | 81 ++++++- client/src/utils/queryFields.ts | 18 ++ client/src/utils/sorting.ts | 37 ++- spoolman/api/v1/filament.py | 12 +- spoolman/api/v1/spool.py | 12 +- spoolman/api/v1/vendor.py | 12 +- spoolman/database/filament.py | 72 +++++- spoolman/database/spool.py | 118 +++++++--- spoolman/database/utils.py | 222 +++++++++++++++++- spoolman/database/vendor.py | 62 ++++- .../tests/fields/test_filter_sort.py | 176 ++++++++++++++ 13 files changed, 821 insertions(+), 69 deletions(-) create mode 100644 tests_integration/tests/fields/test_filter_sort.py diff --git a/client/src/components/column.tsx b/client/src/components/column.tsx index 059b607f0..7f5103048 100644 --- a/client/src/components/column.tsx +++ b/client/src/components/column.tsx @@ -40,7 +40,7 @@ export interface Action { interface BaseColumnProps { id: string | string[]; - dataId?: keyof Obj & string; + dataId?: keyof Obj & string | string; // Allow string values for custom fields i18ncat?: string; i18nkey?: string; title?: string; @@ -389,13 +389,56 @@ export function NumberRangeColumn(props: NumberColumnProps { + filters.push({ + text: choice, + value: `"${choice}"`, // Exact match + }); + }); + } + + // For boolean fields, add true/false options + if (field.field_type === FieldType.boolean) { + filters.push( + { text: "Yes", value: "true" }, + { text: "No", value: "false" } + ); + } + + // Add empty option for all field types + filters.push({ + text: "", + value: "", + }); + + return filters; +} + export function CustomFieldColumn(props: Omit, "id"> & { field: Field }) { const field = props.field; + const fieldId = `extra.${field.key}`; + + // Get filtered values for this field + const typedFilters = typeFilters(props.tableState.filters); + const filteredValue = getFiltersForField(typedFilters, fieldId); + + // Create filters based on field type + const filters = createCustomFieldFilters(field); + const commonProps = { ...props, id: ["extra", field.key], title: field.name, - sorter: false, + sorter: true, // Enable sorting for custom fields + dataId: fieldId, // Set the dataId for sorting + filters: filters, // Add filters + filteredValue: filteredValue, // Set filtered values transform: (value: unknown) => { if (value === null || value === undefined) { return undefined; diff --git a/client/src/components/dataProvider.ts b/client/src/components/dataProvider.ts index 95efd97a9..d2c439af5 100644 --- a/client/src/components/dataProvider.ts +++ b/client/src/components/dataProvider.ts @@ -2,6 +2,9 @@ import { DataProvider } from "@refinedev/core"; import { axiosInstance } from "@refinedev/simple-rest"; import { AxiosInstance } from "axios"; import { stringify } from "query-string"; +import { getCustomFieldFilters } from "../utils/filtering"; +import { isCustomField } from "../utils/queryFields"; +import { getCustomFieldSorters, isCustomFieldSorter } from "../utils/sorting"; type MethodTypes = "get" | "delete" | "head" | "options"; type MethodTypesWithBody = "post" | "put" | "patch"; @@ -25,20 +28,30 @@ const dataProvider = ( } if (sorters && sorters.length > 0) { + // Map all sorters, including custom field sorters queryParams["sort"] = sorters .map((sort) => { const field = sort.field; + // Custom field sorters are already in the correct format (extra.field_key) return `${field}:${sort.order}`; }) .join(","); } if (filters && filters.length > 0) { + // Process regular filters filters.forEach((filter) => { if (!("field" in filter)) { throw Error("Filter must be a LogicalFilter."); } + const field = filter.field; + + // Skip custom fields, they'll be handled separately + if (typeof field === 'string' && isCustomField(field)) { + return; + } + if (filter.value.length > 0) { const filterValueArray = Array.isArray(filter.value) ? filter.value : [filter.value]; @@ -54,6 +67,14 @@ const dataProvider = ( queryParams[field] = filterValue; } }); + + // Process custom field filters + const customFieldFilters = getCustomFieldFilters(filters); + Object.entries(customFieldFilters).forEach(([key, values]) => { + if (values.length > 0) { + queryParams[`extra.${key}`] = values.join(","); + } + }); } const { data, headers } = await httpClient[requestMethod](`${url}`, { diff --git a/client/src/utils/filtering.ts b/client/src/utils/filtering.ts index da99b43c4..177ab96a7 100644 --- a/client/src/utils/filtering.ts +++ b/client/src/utils/filtering.ts @@ -1,7 +1,8 @@ import { CrudFilter, CrudOperators } from "@refinedev/core"; +import { Field, FieldType, getCustomFieldKey, isCustomField } from "./queryFields"; interface TypedCrudFilter { - field: keyof Obj; + field: keyof Obj | string; operator: Exclude; value: string[]; } @@ -16,9 +17,9 @@ export function typeFilters(filters: CrudFilter[]): TypedCrudFilter[] * @param field The field to get the filter values for. * @returns An array of filter values for the given field. */ -export function getFiltersForField( +export function getFiltersForField( filters: TypedCrudFilter[], - field: Field, + field: Field | string, ): string[] { const filterValues: string[] = []; filters.forEach((filter) => { @@ -29,6 +30,80 @@ export function getFiltersForField( return filterValues; } +/** + * Creates a filter value for a custom field based on its type + * @param field The custom field definition + * @param value The value to filter by + * @returns The formatted filter value + */ +export function formatCustomFieldFilterValue(field: Field, value: any): string { + switch (field.field_type) { + case FieldType.text: + case FieldType.choice: + // For text and choice fields, we can use the value directly + // If it's an exact match, surround with quotes + if (typeof value === "string" && !value.startsWith('"') && !value.endsWith('"')) { + // Check if we need an exact match (no wildcards) + if (!value.includes("*") && !value.includes("?")) { + return `"${value}"`; + } + } + return value; + + case FieldType.integer: + case FieldType.float: + // For numeric fields, we can use the value directly + return value.toString(); + + case FieldType.boolean: + // For boolean fields, convert to "true" or "false" + return value ? "true" : "false"; + + case FieldType.datetime: + // For datetime fields, format as ISO string + if (value instanceof Date) { + return value.toISOString(); + } + return value; + + case FieldType.integer_range: + case FieldType.float_range: + // For range fields, format as min:max + if (Array.isArray(value) && value.length === 2) { + return `${value[0] ?? ""}:${value[1] ?? ""}`; + } + return value; + + default: + return value; + } +} + +/** + * Extracts all custom field filters from a list of filters + * @param filters The list of filters + * @returns An object with custom field keys and their filter values + */ +export function getCustomFieldFilters( + filters: CrudFilter[] | TypedCrudFilter[] +): Record { + const customFieldFilters: Record = {}; + + filters.forEach((filter) => { + if (!("field" in filter)) { + return; // Skip non-field filters + } + + const field = filter.field.toString(); + if (isCustomField(field)) { + const key = getCustomFieldKey(field); + customFieldFilters[key] = filter.value as string[]; + } + }); + + return customFieldFilters; +} + /** * Function that returns an array with all undefined values removed. */ diff --git a/client/src/utils/queryFields.ts b/client/src/utils/queryFields.ts index 7cde38c05..f7a048a74 100644 --- a/client/src/utils/queryFields.ts +++ b/client/src/utils/queryFields.ts @@ -110,6 +110,24 @@ export function useSetField(entity_type: EntityType) { }); } +/** + * Checks if a field is a custom field (starts with "extra.") + * @param field The field to check + * @returns True if the field is a custom field + */ +export function isCustomField(field: string): boolean { + return field.startsWith("extra."); +} + +/** + * Extracts the key from a custom field (removes the "extra." prefix) + * @param field The custom field + * @returns The key of the custom field + */ +export function getCustomFieldKey(field: string): string { + return field.substring(6); // Remove "extra." prefix +} + export function useDeleteField(entity_type: EntityType) { const queryClient = useQueryClient(); diff --git a/client/src/utils/sorting.ts b/client/src/utils/sorting.ts index 543d72546..c75603d2a 100644 --- a/client/src/utils/sorting.ts +++ b/client/src/utils/sorting.ts @@ -1,8 +1,9 @@ import { CrudSort } from "@refinedev/core"; import { SortOrder } from "antd/es/table/interface"; +import { getCustomFieldKey, isCustomField } from "./queryFields"; interface TypedCrudSort { - field: keyof Obj; + field: keyof Obj | string; order: "asc" | "desc"; } @@ -12,9 +13,9 @@ interface TypedCrudSort { * @param field The field to get the sort order for. * @returns The sort order for the given field, or undefined if the field is not being sorted. */ -export function getSortOrderForField( +export function getSortOrderForField( sorters: TypedCrudSort[], - field: Field, + field: Field | string, ): SortOrder | undefined { const sorter = sorters.find((s) => s.field === field); if (sorter) { @@ -26,3 +27,33 @@ export function getSortOrderForField( export function typeSorters(sorters: CrudSort[]): TypedCrudSort[] { return sorters as TypedCrudSort[]; // <-- Unsafe cast } + +/** + * Checks if a sorter is for a custom field + * @param sorter The sorter to check + * @returns True if the sorter is for a custom field + */ +export function isCustomFieldSorter(sorter: TypedCrudSort | CrudSort): boolean { + return typeof sorter.field === 'string' && isCustomField(sorter.field); +} + +/** + * Extracts all custom field sorters from a list of sorters + * @param sorters The list of sorters + * @returns An object with custom field keys and their sort orders + */ +export function getCustomFieldSorters( + sorters: TypedCrudSort[] | CrudSort[] +): Record { + const customFieldSorters: Record = {}; + + sorters.forEach((sorter) => { + if (isCustomFieldSorter(sorter)) { + const field = sorter.field.toString(); + const key = getCustomFieldKey(field); + customFieldSorters[key] = sorter.order; + } + }); + + return customFieldSorters; +} diff --git a/spoolman/api/v1/filament.py b/spoolman/api/v1/filament.py index 3e3f859af..e10963fb7 100644 --- a/spoolman/api/v1/filament.py +++ b/spoolman/api/v1/filament.py @@ -4,7 +4,7 @@ import logging from typing import Annotated -from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, field_validator, model_validator @@ -201,6 +201,7 @@ def prevent_none(cls: type["FilamentUpdateParameters"], v: float | None) -> floa ) async def find( *, + request: Request, db: Annotated[AsyncSession, Depends(get_db_session)], vendor_name_old: Annotated[ str | None, @@ -342,6 +343,14 @@ async def find( else: filter_by_ids = None + # Extract custom field filters from query parameters + extra_field_filters = {} + query_params = request.query_params + for key, value in query_params.items(): + if key.startswith("extra."): + field_key = key[6:] # Remove "extra." prefix + extra_field_filters[field_key] = value + db_items, total_count = await filament.find( db=db, ids=filter_by_ids, @@ -351,6 +360,7 @@ async def find( material=material, article_number=article_number, external_id=external_id, + extra_field_filters=extra_field_filters if extra_field_filters else None, sort_by=sort_by, limit=limit, offset=offset, diff --git a/spoolman/api/v1/spool.py b/spoolman/api/v1/spool.py index 8f667e3da..ef659a327 100644 --- a/spoolman/api/v1/spool.py +++ b/spoolman/api/v1/spool.py @@ -5,7 +5,7 @@ from datetime import datetime from typing import Annotated -from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, field_validator @@ -127,6 +127,7 @@ class SpoolMeasureParameters(BaseModel): ) async def find( *, + request: Request, db: Annotated[AsyncSession, Depends(get_db_session)], filament_name_old: Annotated[ str | None, @@ -285,6 +286,14 @@ async def find( else: filament_vendor_ids = None + # Extract custom field filters from query parameters + extra_field_filters = {} + query_params = request.query_params + for key, value in query_params.items(): + if key.startswith("extra."): + field_key = key[6:] # Remove "extra." prefix + extra_field_filters[field_key] = value + db_items, total_count = await spool.find( db=db, filament_name=filament_name if filament_name is not None else filament_name_old, @@ -295,6 +304,7 @@ async def find( location=location, lot_nr=lot_nr, allow_archived=allow_archived, + extra_field_filters=extra_field_filters if extra_field_filters else None, sort_by=sort_by, limit=limit, offset=offset, diff --git a/spoolman/api/v1/vendor.py b/spoolman/api/v1/vendor.py index 9216fba30..54601228a 100644 --- a/spoolman/api/v1/vendor.py +++ b/spoolman/api/v1/vendor.py @@ -3,7 +3,7 @@ import asyncio from typing import Annotated -from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect +from fastapi import APIRouter, Depends, Query, Request, WebSocket, WebSocketDisconnect from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, field_validator @@ -79,6 +79,7 @@ def prevent_none(cls: type["VendorUpdateParameters"], v: str | None) -> str | No }, ) async def find( + request: Request, db: Annotated[AsyncSession, Depends(get_db_session)], name: Annotated[ str | None, @@ -124,10 +125,19 @@ async def find( field, direction = sort_item.split(":") sort_by[field] = SortOrder[direction.upper()] + # Extract custom field filters from query parameters + extra_field_filters = {} + query_params = request.query_params + for key, value in query_params.items(): + if key.startswith("extra."): + field_key = key[6:] # Remove "extra." prefix + extra_field_filters[field_key] = value + db_items, total_count = await vendor.find( db=db, name=name, external_id=external_id, + extra_field_filters=extra_field_filters if extra_field_filters else None, sort_by=sort_by, limit=limit, offset=offset, diff --git a/spoolman/database/filament.py b/spoolman/database/filament.py index e2d742758..8818cec70 100644 --- a/spoolman/database/filament.py +++ b/spoolman/database/filament.py @@ -12,14 +12,7 @@ from spoolman.api.v1.models import EventType, Filament, FilamentEvent, MultiColorDirection from spoolman.database import models, vendor -from spoolman.database.utils import ( - SortOrder, - add_where_clause_int_in, - add_where_clause_int_opt, - add_where_clause_str, - add_where_clause_str_opt, - parse_nested_field, -) +from spoolman.database.utils import SortOrder from spoolman.exceptions import ItemDeleteError, ItemNotFoundError from spoolman.math import delta_e, hex_to_rgb, rgb_to_lab from spoolman.ws import websocket_manager @@ -102,6 +95,7 @@ async def find( material: str | None = None, article_number: str | None = None, external_id: str | None = None, + extra_field_filters: dict[str, str] | None = None, sort_by: dict[str, SortOrder] | None = None, limit: int | None = None, offset: int = 0, @@ -113,6 +107,17 @@ async def find( Returns a tuple containing the list of items and the total count of matching items. """ + # Import here to avoid circular imports + from spoolman.database.utils import ( + add_where_clause_int_in, + add_where_clause_int_opt, + add_where_clause_str, + add_where_clause_str_opt, + add_where_clause_extra_field, + add_order_by_extra_field, + parse_nested_field, + ) + stmt = ( select(models.Filament) .options(contains_eager(models.Filament.vendor)) @@ -135,13 +140,54 @@ async def find( stmt = stmt.offset(offset).limit(limit) + # Apply extra field filters if provided + if extra_field_filters: + # Get all extra fields for filaments + from spoolman.extra_fields import EntityType, get_extra_fields + + extra_fields = await get_extra_fields(db, EntityType.filament) + extra_fields_dict = {field.key: field for field in extra_fields} + + for field_key, value in extra_field_filters.items(): + if field_key in extra_fields_dict: + field = extra_fields_dict[field_key] + stmt = add_where_clause_extra_field( + stmt, + models.Filament, + EntityType.filament, + field_key, + field.field_type, + value, + field.multi_choice if field.field_type == "choice" else None + ) + if sort_by is not None: for fieldstr, order in sort_by.items(): - field = parse_nested_field(models.Filament, fieldstr) - if order == SortOrder.ASC: - stmt = stmt.order_by(field.asc()) - elif order == SortOrder.DESC: - stmt = stmt.order_by(field.desc()) + # Check if this is a custom field sort + if fieldstr.startswith("extra."): + field_key = fieldstr[6:] # Remove "extra." prefix + + # Get the field definition + from spoolman.extra_fields import EntityType, get_extra_fields + + extra_fields = await get_extra_fields(db, EntityType.filament) + extra_field = next((f for f in extra_fields if f.key == field_key), None) + + if extra_field: + stmt = add_order_by_extra_field( + stmt, + models.Filament, + EntityType.filament, + field_key, + extra_field.field_type, + order + ) + else: + field = parse_nested_field(models.Filament, fieldstr) + if order == SortOrder.ASC: + stmt = stmt.order_by(field.asc()) + elif order == SortOrder.DESC: + stmt = stmt.order_by(field.desc()) rows = await db.execute( stmt, diff --git a/spoolman/database/spool.py b/spoolman/database/spool.py index 5c190ce65..a4541b72e 100644 --- a/spoolman/database/spool.py +++ b/spoolman/database/spool.py @@ -13,14 +13,7 @@ from spoolman.api.v1.models import EventType, Spool, SpoolEvent from spoolman.database import filament, models -from spoolman.database.utils import ( - SortOrder, - add_where_clause_int, - add_where_clause_int_opt, - add_where_clause_str, - add_where_clause_str_opt, - parse_nested_field, -) +from spoolman.database.utils import SortOrder from spoolman.exceptions import ItemCreateError, ItemNotFoundError, SpoolMeasureError from spoolman.math import weight_from_length from spoolman.ws import websocket_manager @@ -122,6 +115,7 @@ async def find( # noqa: C901, PLR0912 location: str | None = None, lot_nr: str | None = None, allow_archived: bool = False, + extra_field_filters: dict[str, str] | None = None, sort_by: dict[str, SortOrder] | None = None, limit: int | None = None, offset: int = 0, @@ -133,6 +127,17 @@ async def find( # noqa: C901, PLR0912 Returns a tuple containing the list of items and the total count of matching items. """ + # Import here to avoid circular imports + from spoolman.database.utils import ( + add_where_clause_int, + add_where_clause_int_opt, + add_where_clause_str, + add_where_clause_str_opt, + add_where_clause_extra_field, + add_order_by_extra_field, + parse_nested_field, + ) + stmt = ( sqlalchemy.select(models.Spool) .join(models.Spool.filament, isouter=True) @@ -165,37 +170,78 @@ async def find( # noqa: C901, PLR0912 stmt = stmt.offset(offset).limit(limit) + # Apply extra field filters if provided + if extra_field_filters: + # Get all extra fields for spools + from spoolman.extra_fields import EntityType, get_extra_fields + + extra_fields = await get_extra_fields(db, EntityType.spool) + extra_fields_dict = {field.key: field for field in extra_fields} + + for field_key, value in extra_field_filters.items(): + if field_key in extra_fields_dict: + field = extra_fields_dict[field_key] + stmt = add_where_clause_extra_field( + stmt, + models.Spool, + EntityType.spool, + field_key, + field.field_type, + value, + field.multi_choice if field.field_type == "choice" else None + ) + if sort_by is not None: for fieldstr, order in sort_by.items(): - sorts = [] - if fieldstr == "remaining_weight": - sorts.append(coalesce(models.Spool.initial_weight, models.Filament.weight) - models.Spool.used_weight) - elif fieldstr == "remaining_length": - # Simplified weight -> length formula. Absolute value is not correct but the proportionality is still - # kept, which means the sort order is correct. - sorts.append( - (coalesce(models.Spool.initial_weight, models.Filament.weight) - models.Spool.used_weight) - / models.Filament.density - / (models.Filament.diameter * models.Filament.diameter), - ) - elif fieldstr == "used_length": - sorts.append( - models.Spool.used_weight - / models.Filament.density - / (models.Filament.diameter * models.Filament.diameter), - ) - elif fieldstr == "filament.combined_name": - sorts.append(models.Vendor.name) - sorts.append(models.Filament.name) - elif fieldstr == "price": - sorts.append(coalesce(models.Spool.price, models.Filament.price)) + # Check if this is a custom field sort + if fieldstr.startswith("extra."): + field_key = fieldstr[6:] # Remove "extra." prefix + + # Get the field definition + from spoolman.extra_fields import EntityType, get_extra_fields + + extra_fields = await get_extra_fields(db, EntityType.spool) + extra_field = next((f for f in extra_fields if f.key == field_key), None) + + if extra_field: + stmt = add_order_by_extra_field( + stmt, + models.Spool, + EntityType.spool, + field_key, + extra_field.field_type, + order + ) else: - sorts.append(parse_nested_field(models.Spool, fieldstr)) - - if order == SortOrder.ASC: - stmt = stmt.order_by(*(f.asc() for f in sorts)) - elif order == SortOrder.DESC: - stmt = stmt.order_by(*(f.desc() for f in sorts)) + sorts = [] + if fieldstr == "remaining_weight": + sorts.append(coalesce(models.Spool.initial_weight, models.Filament.weight) - models.Spool.used_weight) + elif fieldstr == "remaining_length": + # Simplified weight -> length formula. Absolute value is not correct but the proportionality is still + # kept, which means the sort order is correct. + sorts.append( + (coalesce(models.Spool.initial_weight, models.Filament.weight) - models.Spool.used_weight) + / models.Filament.density + / (models.Filament.diameter * models.Filament.diameter), + ) + elif fieldstr == "used_length": + sorts.append( + models.Spool.used_weight + / models.Filament.density + / (models.Filament.diameter * models.Filament.diameter), + ) + elif fieldstr == "filament.combined_name": + sorts.append(models.Vendor.name) + sorts.append(models.Filament.name) + elif fieldstr == "price": + sorts.append(coalesce(models.Spool.price, models.Filament.price)) + else: + sorts.append(parse_nested_field(models.Spool, fieldstr)) + + if order == SortOrder.ASC: + stmt = stmt.order_by(*(f.asc() for f in sorts)) + elif order == SortOrder.DESC: + stmt = stmt.order_by(*(f.desc() for f in sorts)) rows = await db.execute( stmt, diff --git a/spoolman/database/utils.py b/spoolman/database/utils.py index 2d8776c00..888d9deaf 100644 --- a/spoolman/database/utils.py +++ b/spoolman/database/utils.py @@ -1,14 +1,20 @@ """Utility functions for the database module.""" from collections.abc import Sequence +import json from enum import Enum -from typing import Any, TypeVar +from typing import Any, Dict, Tuple, Type, TypeVar import sqlalchemy -from sqlalchemy import Select -from sqlalchemy.orm import attributes +from sqlalchemy import Select, and_, cast, func, or_, text +from sqlalchemy.orm import attributes, aliased +from sqlalchemy.sql import expression from spoolman.database import models +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from spoolman.extra_fields import EntityType, ExtraField, ExtraFieldType class SortOrder(Enum): @@ -129,3 +135,213 @@ def add_where_clause_int_in( if value is not None: stmt = stmt.where(field.in_(value)) return stmt + + +def get_field_table_for_entity(entity_type: Any) -> Type[models.Base]: + """Get the field table class for a given entity type.""" + # Import here to avoid circular imports + from spoolman.extra_fields import EntityType + + if entity_type == EntityType.spool: + return models.SpoolField + elif entity_type == EntityType.filament: + return models.FilamentField + elif entity_type == EntityType.vendor: + return models.VendorField + else: + raise ValueError(f"Unknown entity type: {entity_type}") + + +def get_entity_id_column(field_table: Type[models.Base]) -> attributes.InstrumentedAttribute[int]: + """Get the entity ID column for a given field table.""" + if field_table == models.SpoolField: + return models.SpoolField.spool_id + elif field_table == models.FilamentField: + return models.FilamentField.filament_id + elif field_table == models.VendorField: + return models.VendorField.vendor_id + else: + raise ValueError(f"Unknown field table: {field_table}") + + +def add_where_clause_extra_field( + stmt: Select, + base_obj: Type[models.Base], + entity_type: Any, + field_key: str, + field_type: Any, + value: str, + multi_choice: bool | None = None, +) -> Select: + """Add a where clause to a select statement for an extra field. + Args: + stmt: The select statement to add the where clause to + base_obj: The base object type (Spool, Filament, Vendor) + entity_type: The entity type + field_key: The key of the extra field + field_type: The type of the extra field + value: The value to filter by + multi_choice: Whether the field is a multi-choice field (only for choice fields) + Returns: + The modified select statement + """ + # Import here to avoid circular imports + from spoolman.extra_fields import ExtraFieldType + + field_table = get_field_table_for_entity(entity_type) + entity_id_column = get_entity_id_column(field_table) + + value_parts = value.split(",") + + # Handle filtering for empty values + if any(p == "" or len(p) == 0 for p in value_parts): + # An item is considered "empty" if: + # A) A row exists in the field table, and its value is null, 'null', or 'false' for booleans. + # B) No row exists in the field table for this item and field_key. + + # Condition A subquery + empty_conditions = [ + field_table.value.is_(None), + field_table.value == "null", + ] + if field_type == ExtraFieldType.boolean: + empty_conditions.append(field_table.value == json.dumps(False)) + + subq_a = sqlalchemy.select(entity_id_column).where( + sqlalchemy.and_(field_table.key == field_key, sqlalchemy.or_(*empty_conditions)) + ) + + # Condition B subquery + subq_b = sqlalchemy.select(base_obj.id).where( + getattr(base_obj, "id").not_in(sqlalchemy.select(entity_id_column).where(field_table.key == field_key)) + ) + + return stmt.where( + sqlalchemy.or_( + getattr(base_obj, "id").in_(subq_a), + getattr(base_obj, "id").in_(subq_b), + ) + ) + + # Handle filtering for specific values + conditions = [] + for value_part in value_parts: + exact_match = value_part.startswith('"') and value_part.endswith('"') + if exact_match: + value_part = value_part[1:-1] + + if field_type == ExtraFieldType.text: + if exact_match: + conditions.append(field_table.value == json.dumps(value_part)) + else: + conditions.append(field_table.value.ilike(f"%{value_part}%")) + elif field_type == ExtraFieldType.integer: + try: + conditions.append(field_table.value == json.dumps(int(value_part))) + except ValueError: + pass + elif field_type == ExtraFieldType.float: + try: + conditions.append(field_table.value == json.dumps(float(value_part))) + except ValueError: + pass + elif field_type == ExtraFieldType.boolean: + bool_value = value_part.lower() in ("true", "1", "yes") + conditions.append(field_table.value == json.dumps(bool_value)) + elif field_type == ExtraFieldType.choice: + if multi_choice: + conditions.append(field_table.value.like(f'%"{value_part}"%')) + else: + conditions.append(field_table.value == json.dumps(value_part)) + elif field_type == ExtraFieldType.datetime: + conditions.append(field_table.value == json.dumps(value_part)) + elif field_type in (ExtraFieldType.integer_range, ExtraFieldType.float_range): + if ":" in value_part: + min_val_str, max_val_str = value_part.split(":", 1) + converter = int if field_type == ExtraFieldType.integer_range else float + try: + if min_val_str: + conditions.append(func.json_extract(field_table.value, "$[0]") >= converter(min_val_str)) + if max_val_str: + conditions.append(func.json_extract(field_table.value, "$[1]") <= converter(max_val_str)) + except (ValueError, TypeError): + pass + + if not conditions: + return stmt + + subq = sqlalchemy.select(entity_id_column).where( + sqlalchemy.and_(field_table.key == field_key, sqlalchemy.or_(*conditions)) + ) + + return stmt.where(getattr(base_obj, "id").in_(subq)) + + +def add_order_by_extra_field( + stmt: Select, + base_obj: Type[models.Base], + entity_type: Any, + field_key: str, + field_type: Any, + order: SortOrder, +) -> Select: + """Add an order by clause to a select statement for an extra field. + + Args: + stmt: The select statement to add the order by clause to + base_obj: The base object type (Spool, Filament, Vendor) + entity_type: The entity type + field_key: The key of the extra field + field_type: The type of the extra field + order: The sort order + + Returns: + The modified select statement + """ + # Import here to avoid circular imports + from spoolman.extra_fields import EntityType, ExtraFieldType + + # Use a subquery approach instead of joins + field_table = get_field_table_for_entity(entity_type) + entity_id_column = get_entity_id_column(field_table) + + # Create a subquery that selects the value for each entity + value_subq = ( + sqlalchemy.select(field_table.value) + .where( + sqlalchemy.and_( + field_table.key == field_key, + entity_id_column == getattr(base_obj, "id") + ) + ) + .scalar_subquery() + .correlate(base_obj) + ) + + # Create a sort expression based on the field type + if field_type == ExtraFieldType.integer: + # Cast the JSON value to an integer for sorting + sort_expr = func.cast(func.json_extract(value_subq, '$'), sqlalchemy.Integer) + elif field_type == ExtraFieldType.float: + # Cast the JSON value to a float for sorting + sort_expr = func.cast(func.json_extract(value_subq, '$'), sqlalchemy.Float) + elif field_type == ExtraFieldType.datetime: + # For datetime fields, we can sort by the ISO string + sort_expr = value_subq + elif field_type == ExtraFieldType.boolean: + # For boolean fields, true comes after false + sort_expr = value_subq + elif field_type in (ExtraFieldType.integer_range, ExtraFieldType.float_range): + # For range fields, sort by the first value in the range + sort_expr = func.json_extract(value_subq, '$[0]') + else: + # For text and choice fields, sort by the string value + sort_expr = value_subq + + # Apply the sort order + if order == SortOrder.ASC: + stmt = stmt.order_by(sort_expr.asc()) + else: + stmt = stmt.order_by(sort_expr.desc()) + + return stmt diff --git a/spoolman/database/vendor.py b/spoolman/database/vendor.py index f2e83018e..51ed2e509 100644 --- a/spoolman/database/vendor.py +++ b/spoolman/database/vendor.py @@ -9,7 +9,7 @@ from spoolman.api.v1.models import EventType, Vendor, VendorEvent from spoolman.database import models -from spoolman.database.utils import SortOrder, add_where_clause_str, add_where_clause_str_opt +from spoolman.database.utils import SortOrder from spoolman.exceptions import ItemNotFoundError from spoolman.ws import websocket_manager @@ -53,6 +53,7 @@ async def find( db: AsyncSession, name: str | None = None, external_id: str | None = None, + extra_field_filters: dict[str, str] | None = None, sort_by: dict[str, SortOrder] | None = None, limit: int | None = None, offset: int = 0, @@ -61,6 +62,14 @@ async def find( Returns a tuple containing the list of items and the total count of matching items. """ + # Import here to avoid circular imports + from spoolman.database.utils import ( + add_where_clause_str, + add_where_clause_str_opt, + add_where_clause_extra_field, + add_order_by_extra_field + ) + stmt = select(models.Vendor) stmt = add_where_clause_str(stmt, models.Vendor.name, name) @@ -74,13 +83,54 @@ async def find( stmt = stmt.offset(offset).limit(limit) + # Apply extra field filters if provided + if extra_field_filters: + # Get all extra fields for vendors + from spoolman.extra_fields import EntityType, get_extra_fields + + extra_fields = await get_extra_fields(db, EntityType.vendor) + extra_fields_dict = {field.key: field for field in extra_fields} + + for field_key, value in extra_field_filters.items(): + if field_key in extra_fields_dict: + field = extra_fields_dict[field_key] + stmt = add_where_clause_extra_field( + stmt, + models.Vendor, + EntityType.vendor, + field_key, + field.field_type, + value, + field.multi_choice if field.field_type == "choice" else None + ) + if sort_by is not None: for fieldstr, order in sort_by.items(): - field = getattr(models.Vendor, fieldstr) - if order == SortOrder.ASC: - stmt = stmt.order_by(field.asc()) - elif order == SortOrder.DESC: - stmt = stmt.order_by(field.desc()) + # Check if this is a custom field sort + if fieldstr.startswith("extra."): + field_key = fieldstr[6:] # Remove "extra." prefix + + # Get the field definition + from spoolman.extra_fields import EntityType, get_extra_fields + + extra_fields = await get_extra_fields(db, EntityType.vendor) + extra_field = next((f for f in extra_fields if f.key == field_key), None) + + if extra_field: + stmt = add_order_by_extra_field( + stmt, + models.Vendor, + EntityType.vendor, + field_key, + extra_field.field_type, + order + ) + else: + field = getattr(models.Vendor, fieldstr) + if order == SortOrder.ASC: + stmt = stmt.order_by(field.asc()) + elif order == SortOrder.DESC: + stmt = stmt.order_by(field.desc()) rows = await db.execute( stmt, diff --git a/tests_integration/tests/fields/test_filter_sort.py b/tests_integration/tests/fields/test_filter_sort.py new file mode 100644 index 000000000..0b8077c55 --- /dev/null +++ b/tests_integration/tests/fields/test_filter_sort.py @@ -0,0 +1,176 @@ +"""Tests for filtering and sorting by custom fields.""" + +import json +import pytest +from httpx import AsyncClient + + +@pytest.mark.asyncio +async def test_filter_by_custom_field(client: AsyncClient, setup_extra_fields): + """Test filtering by custom field.""" + # Create a spool with a custom field + spool_data = { + "filament_id": 1, + "extra": { + "test_field": json.dumps("test_value") + } + } + response = await client.post("/api/v1/spool", json=spool_data) + assert response.status_code == 200 + spool_id = response.json()["id"] + + # Create another spool with a different custom field value + spool_data2 = { + "filament_id": 1, + "extra": { + "test_field": json.dumps("other_value") + } + } + response = await client.post("/api/v1/spool", json=spool_data2) + assert response.status_code == 200 + spool_id2 = response.json()["id"] + + # Filter by custom field + response = await client.get("/api/v1/spool", params={"extra.test_field": "test_value"}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["id"] == spool_id + + # Filter by custom field with exact match + response = await client.get("/api/v1/spool", params={"extra.test_field": '"test_value"'}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["id"] == spool_id + + # Filter by custom field with multiple values + response = await client.get("/api/v1/spool", params={"extra.test_field": "test_value,other_value"}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert {item["id"] for item in data} == {spool_id, spool_id2} + + +@pytest.mark.asyncio +async def test_sort_by_custom_field(client: AsyncClient, setup_extra_fields): + """Test sorting by custom field.""" + # Create spools with custom fields of different types + # Text field + spool_data1 = { + "filament_id": 1, + "extra": { + "text_field": json.dumps("B value") + } + } + response = await client.post("/api/v1/spool", json=spool_data1) + assert response.status_code == 200 + spool_id1 = response.json()["id"] + + spool_data2 = { + "filament_id": 1, + "extra": { + "text_field": json.dumps("A value") + } + } + response = await client.post("/api/v1/spool", json=spool_data2) + assert response.status_code == 200 + spool_id2 = response.json()["id"] + + # Sort by custom field ascending + response = await client.get("/api/v1/spool", params={"sort": "extra.text_field:asc"}) + assert response.status_code == 200 + data = response.json() + assert len(data) >= 2 + # Find our test spools in the results + test_spools = [item for item in data if item["id"] in (spool_id1, spool_id2)] + assert len(test_spools) == 2 + assert test_spools[0]["id"] == spool_id2 # A value should come first + assert test_spools[1]["id"] == spool_id1 # B value should come second + + # Sort by custom field descending + response = await client.get("/api/v1/spool", params={"sort": "extra.text_field:desc"}) + assert response.status_code == 200 + data = response.json() + assert len(data) >= 2 + # Find our test spools in the results + test_spools = [item for item in data if item["id"] in (spool_id1, spool_id2)] + assert len(test_spools) == 2 + assert test_spools[0]["id"] == spool_id1 # B value should come first + assert test_spools[1]["id"] == spool_id2 # A value should come second + + +@pytest.mark.asyncio +async def test_filter_by_numeric_custom_field(client: AsyncClient, setup_extra_fields): + """Test filtering by numeric custom field.""" + # Create a spool with a numeric custom field + spool_data = { + "filament_id": 1, + "extra": { + "numeric_field": json.dumps(100) + } + } + response = await client.post("/api/v1/spool", json=spool_data) + assert response.status_code == 200 + spool_id = response.json()["id"] + + # Create another spool with a different numeric value + spool_data2 = { + "filament_id": 1, + "extra": { + "numeric_field": json.dumps(200) + } + } + response = await client.post("/api/v1/spool", json=spool_data2) + assert response.status_code == 200 + spool_id2 = response.json()["id"] + + # Filter by numeric custom field + response = await client.get("/api/v1/spool", params={"extra.numeric_field": "100"}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["id"] == spool_id + + # Sort by numeric custom field ascending + response = await client.get("/api/v1/spool", params={"sort": "extra.numeric_field:asc"}) + assert response.status_code == 200 + data = response.json() + # Find our test spools in the results + test_spools = [item for item in data if item["id"] in (spool_id, spool_id2)] + assert len(test_spools) == 2 + assert test_spools[0]["id"] == spool_id # 100 should come first + assert test_spools[1]["id"] == spool_id2 # 200 should come second + + +@pytest.mark.asyncio +async def test_filter_by_boolean_custom_field(client: AsyncClient, setup_extra_fields): + """Test filtering by boolean custom field.""" + # Create a spool with a boolean custom field + spool_data = { + "filament_id": 1, + "extra": { + "bool_field": json.dumps(True) + } + } + response = await client.post("/api/v1/spool", json=spool_data) + assert response.status_code == 200 + spool_id = response.json()["id"] + + # Create another spool with a different boolean value + spool_data2 = { + "filament_id": 1, + "extra": { + "bool_field": json.dumps(False) + } + } + response = await client.post("/api/v1/spool", json=spool_data2) + assert response.status_code == 200 + spool_id2 = response.json()["id"] + + # Filter by boolean custom field + response = await client.get("/api/v1/spool", params={"extra.bool_field": "true"}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["id"] == spool_id \ No newline at end of file From 33c669e9e3657d543e8ea24302b34377df756960 Mon Sep 17 00:00:00 2001 From: Dieter Blomme Date: Sun, 8 Feb 2026 23:34:39 +0100 Subject: [PATCH 2/5] Fix tests and make code postgres compatible --- spoolman/database/utils.py | 10 +- .../tests/fields/test_filter_sort.py | 302 +++++++++++------- 2 files changed, 198 insertions(+), 114 deletions(-) diff --git a/spoolman/database/utils.py b/spoolman/database/utils.py index 888d9deaf..41744cb79 100644 --- a/spoolman/database/utils.py +++ b/spoolman/database/utils.py @@ -261,9 +261,9 @@ def add_where_clause_extra_field( converter = int if field_type == ExtraFieldType.integer_range else float try: if min_val_str: - conditions.append(func.json_extract(field_table.value, "$[0]") >= converter(min_val_str)) + conditions.append(field_table.value[0].as_integer() >= converter(min_val_str)) if max_val_str: - conditions.append(func.json_extract(field_table.value, "$[1]") <= converter(max_val_str)) + conditions.append(field_table.value[1].as_integer() <= converter(min_val_str)) except (ValueError, TypeError): pass @@ -321,10 +321,10 @@ def add_order_by_extra_field( # Create a sort expression based on the field type if field_type == ExtraFieldType.integer: # Cast the JSON value to an integer for sorting - sort_expr = func.cast(func.json_extract(value_subq, '$'), sqlalchemy.Integer) + sort_expr = func.cast(value_subq, sqlalchemy.Integer) elif field_type == ExtraFieldType.float: # Cast the JSON value to a float for sorting - sort_expr = func.cast(func.json_extract(value_subq, '$'), sqlalchemy.Float) + sort_expr = func.cast(value_subq, sqlalchemy.Float) elif field_type == ExtraFieldType.datetime: # For datetime fields, we can sort by the ISO string sort_expr = value_subq @@ -333,7 +333,7 @@ def add_order_by_extra_field( sort_expr = value_subq elif field_type in (ExtraFieldType.integer_range, ExtraFieldType.float_range): # For range fields, sort by the first value in the range - sort_expr = func.json_extract(value_subq, '$[0]') + sort_expr = value_subq[0] else: # For text and choice fields, sort by the string value sort_expr = value_subq diff --git a/tests_integration/tests/fields/test_filter_sort.py b/tests_integration/tests/fields/test_filter_sort.py index 0b8077c55..141260582 100644 --- a/tests_integration/tests/fields/test_filter_sort.py +++ b/tests_integration/tests/fields/test_filter_sort.py @@ -1,86 +1,124 @@ """Tests for filtering and sorting by custom fields.""" +import httpx import json import pytest -from httpx import AsyncClient +from typing import Any + +from ..conftest import URL, assert_httpx_success, assert_lists_compatible @pytest.mark.asyncio -async def test_filter_by_custom_field(client: AsyncClient, setup_extra_fields): +async def test_filter_by_custom_field(random_filament: dict[str, Any]): + """Add a custom text field""" + result = httpx.post( + f"{URL}/api/v1/field/spool/test_field", + json={ + "name": "Test field", + "field_type": "text", + "default_value": json.dumps("Hello World"), + }, + ) + assert_httpx_success(result) + """Test filtering by custom field.""" # Create a spool with a custom field - spool_data = { - "filament_id": 1, - "extra": { - "test_field": json.dumps("test_value") - } - } - response = await client.post("/api/v1/spool", json=spool_data) - assert response.status_code == 200 - spool_id = response.json()["id"] + result = httpx.post( + f"{URL}/api/v1/spool", + json={ + "filament_id": random_filament["id"], + "extra": { + "test_field": json.dumps("test_value") + } + }, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] # Create another spool with a different custom field value - spool_data2 = { - "filament_id": 1, - "extra": { - "test_field": json.dumps("other_value") - } - } - response = await client.post("/api/v1/spool", json=spool_data2) - assert response.status_code == 200 - spool_id2 = response.json()["id"] + result = httpx.post( + f"{URL}/api/v1/spool", + json={ + "filament_id": random_filament["id"], + "extra": { + "test_field": json.dumps("other_value") + } + }, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] # Filter by custom field - response = await client.get("/api/v1/spool", params={"extra.test_field": "test_value"}) - assert response.status_code == 200 - data = response.json() + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.test_field": "test_value"}) + assert_httpx_success(result) + data = result.json() assert len(data) == 1 - assert data[0]["id"] == spool_id + assert data[0]["id"] == spool_id1 # Filter by custom field with exact match - response = await client.get("/api/v1/spool", params={"extra.test_field": '"test_value"'}) - assert response.status_code == 200 - data = response.json() + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.test_field": '"test_value"'}) + assert_httpx_success(result) + data = result.json() assert len(data) == 1 - assert data[0]["id"] == spool_id + assert data[0]["id"] == spool_id1 # Filter by custom field with multiple values - response = await client.get("/api/v1/spool", params={"extra.test_field": "test_value,other_value"}) - assert response.status_code == 200 - data = response.json() + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.test_field": "test_value,other_value"}) + assert_httpx_success(result) + data = result.json() assert len(data) == 2 - assert {item["id"] for item in data} == {spool_id, spool_id2} + assert {item["id"] for item in data} == {spool_id1, spool_id2} + + # Clean up + result = httpx.delete(f"{URL}/api/v1/field/spool/test_field") + assert_httpx_success(result) + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() @pytest.mark.asyncio -async def test_sort_by_custom_field(client: AsyncClient, setup_extra_fields): +async def test_sort_by_custom_field(random_filament: dict[str, Any]): + """Add a custom text field""" + result = httpx.post( + f"{URL}/api/v1/field/spool/text_field", + json={ + "name": "Text field", + "field_type": "text", + }, + ) + assert_httpx_success(result) + """Test sorting by custom field.""" # Create spools with custom fields of different types # Text field - spool_data1 = { - "filament_id": 1, - "extra": { - "text_field": json.dumps("B value") - } - } - response = await client.post("/api/v1/spool", json=spool_data1) - assert response.status_code == 200 - spool_id1 = response.json()["id"] - - spool_data2 = { - "filament_id": 1, - "extra": { - "text_field": json.dumps("A value") - } - } - response = await client.post("/api/v1/spool", json=spool_data2) - assert response.status_code == 200 - spool_id2 = response.json()["id"] + result = httpx.post( + f"{URL}/api/v1/spool", + json={ + "filament_id": random_filament["id"], + "extra": { + "text_field": json.dumps("B value") + } + }, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + result = httpx.post( + f"{URL}/api/v1/spool", + json={ + "filament_id": random_filament["id"], + "extra": { + "text_field": json.dumps("A value") + } + }, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] # Sort by custom field ascending - response = await client.get("/api/v1/spool", params={"sort": "extra.text_field:asc"}) - assert response.status_code == 200 - data = response.json() + result = httpx.get(f"{URL}/api/v1/spool", params={"sort": "extra.text_field:asc"}) + assert_httpx_success(result) + data = result.json() assert len(data) >= 2 # Find our test spools in the results test_spools = [item for item in data if item["id"] in (spool_id1, spool_id2)] @@ -89,9 +127,9 @@ async def test_sort_by_custom_field(client: AsyncClient, setup_extra_fields): assert test_spools[1]["id"] == spool_id1 # B value should come second # Sort by custom field descending - response = await client.get("/api/v1/spool", params={"sort": "extra.text_field:desc"}) - assert response.status_code == 200 - data = response.json() + result = httpx.get(f"{URL}/api/v1/spool", params={"sort": "extra.text_field:desc"}) + assert_httpx_success(result) + data = result.json() assert len(data) >= 2 # Find our test spools in the results test_spools = [item for item in data if item["id"] in (spool_id1, spool_id2)] @@ -99,78 +137,124 @@ async def test_sort_by_custom_field(client: AsyncClient, setup_extra_fields): assert test_spools[0]["id"] == spool_id1 # B value should come first assert test_spools[1]["id"] == spool_id2 # A value should come second + # Clean up + result = httpx.delete(f"{URL}/api/v1/field/spool/text_field") + assert_httpx_success(result) + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + @pytest.mark.asyncio -async def test_filter_by_numeric_custom_field(client: AsyncClient, setup_extra_fields): +async def test_filter_by_numeric_custom_field(random_filament: dict[str, Any]): + """Add a custom numeric field""" + result = httpx.post( + f"{URL}/api/v1/field/spool/numeric_field", + json={ + "name": "Numeric field", + "field_type": "integer", + }, + ) + assert_httpx_success(result) + """Test filtering by numeric custom field.""" # Create a spool with a numeric custom field - spool_data = { - "filament_id": 1, - "extra": { - "numeric_field": json.dumps(100) - } - } - response = await client.post("/api/v1/spool", json=spool_data) - assert response.status_code == 200 - spool_id = response.json()["id"] + result = httpx.post( + f"{URL}/api/v1/spool", + json={ + "filament_id": random_filament["id"], + "extra": { + "numeric_field": json.dumps(100) + } + }, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] # Create another spool with a different numeric value - spool_data2 = { - "filament_id": 1, - "extra": { - "numeric_field": json.dumps(200) - } - } - response = await client.post("/api/v1/spool", json=spool_data2) - assert response.status_code == 200 - spool_id2 = response.json()["id"] + result = httpx.post( + f"{URL}/api/v1/spool", + json={ + "filament_id": random_filament["id"], + "extra": { + "numeric_field": json.dumps(200) + } + }, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] # Filter by numeric custom field - response = await client.get("/api/v1/spool", params={"extra.numeric_field": "100"}) - assert response.status_code == 200 - data = response.json() + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.numeric_field": "100"}) + assert_httpx_success(result) + data = result.json() assert len(data) == 1 - assert data[0]["id"] == spool_id + assert data[0]["id"] == spool_id1 # Sort by numeric custom field ascending - response = await client.get("/api/v1/spool", params={"sort": "extra.numeric_field:asc"}) - assert response.status_code == 200 - data = response.json() + result = httpx.get(f"{URL}/api/v1/spool", params={"sort": "extra.numeric_field:asc"}) + assert_httpx_success(result) + data = result.json() # Find our test spools in the results - test_spools = [item for item in data if item["id"] in (spool_id, spool_id2)] + test_spools = [item for item in data if item["id"] in (spool_id1, spool_id2)] assert len(test_spools) == 2 - assert test_spools[0]["id"] == spool_id # 100 should come first + assert test_spools[0]["id"] == spool_id1 # 100 should come first assert test_spools[1]["id"] == spool_id2 # 200 should come second + # Clean up + result = httpx.delete(f"{URL}/api/v1/field/spool/numeric_field") + assert_httpx_success(result) + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + @pytest.mark.asyncio -async def test_filter_by_boolean_custom_field(client: AsyncClient, setup_extra_fields): +async def test_filter_by_boolean_custom_field(random_filament: dict[str, Any]): + """Add a custom boolean field""" + result = httpx.post( + f"{URL}/api/v1/field/spool/boolean_field", + json={ + "name": "Boolean field", + "field_type": "boolean", + }, + ) + assert_httpx_success(result) + """Test filtering by boolean custom field.""" # Create a spool with a boolean custom field - spool_data = { - "filament_id": 1, - "extra": { - "bool_field": json.dumps(True) - } - } - response = await client.post("/api/v1/spool", json=spool_data) - assert response.status_code == 200 - spool_id = response.json()["id"] + result = httpx.post( + f"{URL}/api/v1/spool", + json={ + "filament_id": random_filament["id"], + "extra": { + "boolean_field": json.dumps(True) + } + }, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] # Create another spool with a different boolean value - spool_data2 = { - "filament_id": 1, - "extra": { - "bool_field": json.dumps(False) - } - } - response = await client.post("/api/v1/spool", json=spool_data2) - assert response.status_code == 200 - spool_id2 = response.json()["id"] + result = httpx.post( + f"{URL}/api/v1/spool", + json={ + "filament_id": random_filament["id"], + "extra": { + "boolean_field": json.dumps(False) + } + }, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] # Filter by boolean custom field - response = await client.get("/api/v1/spool", params={"extra.bool_field": "true"}) - assert response.status_code == 200 - data = response.json() + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.boolean_field": "true"}) + assert_httpx_success(result) + data = result.json() assert len(data) == 1 - assert data[0]["id"] == spool_id \ No newline at end of file + assert data[0]["id"] == spool_id1 + + # Clean up + result = httpx.delete(f"{URL}/api/v1/field/spool/boolean_field") + assert_httpx_success(result) + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() \ No newline at end of file From 4f24594569fad45160729ca97f33c0157843cb12 Mon Sep 17 00:00:00 2001 From: Dieter Blomme Date: Sun, 22 Feb 2026 22:35:24 +0100 Subject: [PATCH 3/5] Fix bugs and expand test coverage for custom field filter/sort --- spoolman/database/filament.py | 8 +- spoolman/database/spool.py | 8 +- spoolman/database/utils.py | 17 +- spoolman/database/vendor.py | 8 +- .../tests/fields/test_filter_sort.py | 310 +++++++++++++++++- 5 files changed, 327 insertions(+), 24 deletions(-) diff --git a/spoolman/database/filament.py b/spoolman/database/filament.py index 8818cec70..130981ae8 100644 --- a/spoolman/database/filament.py +++ b/spoolman/database/filament.py @@ -143,11 +143,11 @@ async def find( # Apply extra field filters if provided if extra_field_filters: # Get all extra fields for filaments - from spoolman.extra_fields import EntityType, get_extra_fields - + from spoolman.extra_fields import EntityType, ExtraFieldType, get_extra_fields + extra_fields = await get_extra_fields(db, EntityType.filament) extra_fields_dict = {field.key: field for field in extra_fields} - + for field_key, value in extra_field_filters.items(): if field_key in extra_fields_dict: field = extra_fields_dict[field_key] @@ -158,7 +158,7 @@ async def find( field_key, field.field_type, value, - field.multi_choice if field.field_type == "choice" else None + field.multi_choice if field.field_type == ExtraFieldType.choice else None, ) if sort_by is not None: diff --git a/spoolman/database/spool.py b/spoolman/database/spool.py index a4541b72e..89c58b291 100644 --- a/spoolman/database/spool.py +++ b/spoolman/database/spool.py @@ -173,11 +173,11 @@ async def find( # noqa: C901, PLR0912 # Apply extra field filters if provided if extra_field_filters: # Get all extra fields for spools - from spoolman.extra_fields import EntityType, get_extra_fields - + from spoolman.extra_fields import EntityType, ExtraFieldType, get_extra_fields + extra_fields = await get_extra_fields(db, EntityType.spool) extra_fields_dict = {field.key: field for field in extra_fields} - + for field_key, value in extra_field_filters.items(): if field_key in extra_fields_dict: field = extra_fields_dict[field_key] @@ -188,7 +188,7 @@ async def find( # noqa: C901, PLR0912 field_key, field.field_type, value, - field.multi_choice if field.field_type == "choice" else None + field.multi_choice if field.field_type == ExtraFieldType.choice else None, ) if sort_by is not None: diff --git a/spoolman/database/utils.py b/spoolman/database/utils.py index 41744cb79..54b96c739 100644 --- a/spoolman/database/utils.py +++ b/spoolman/database/utils.py @@ -3,18 +3,13 @@ from collections.abc import Sequence import json from enum import Enum -from typing import Any, Dict, Tuple, Type, TypeVar +from typing import Any, Type, TypeVar import sqlalchemy -from sqlalchemy import Select, and_, cast, func, or_, text -from sqlalchemy.orm import attributes, aliased -from sqlalchemy.sql import expression +from sqlalchemy import Select +from sqlalchemy.orm import attributes from spoolman.database import models -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from spoolman.extra_fields import EntityType, ExtraField, ExtraFieldType class SortOrder(Enum): @@ -263,7 +258,7 @@ def add_where_clause_extra_field( if min_val_str: conditions.append(field_table.value[0].as_integer() >= converter(min_val_str)) if max_val_str: - conditions.append(field_table.value[1].as_integer() <= converter(min_val_str)) + conditions.append(field_table.value[1].as_integer() <= converter(max_val_str)) except (ValueError, TypeError): pass @@ -321,10 +316,10 @@ def add_order_by_extra_field( # Create a sort expression based on the field type if field_type == ExtraFieldType.integer: # Cast the JSON value to an integer for sorting - sort_expr = func.cast(value_subq, sqlalchemy.Integer) + sort_expr = sqlalchemy.cast(value_subq, sqlalchemy.Integer) elif field_type == ExtraFieldType.float: # Cast the JSON value to a float for sorting - sort_expr = func.cast(value_subq, sqlalchemy.Float) + sort_expr = sqlalchemy.cast(value_subq, sqlalchemy.Float) elif field_type == ExtraFieldType.datetime: # For datetime fields, we can sort by the ISO string sort_expr = value_subq diff --git a/spoolman/database/vendor.py b/spoolman/database/vendor.py index 51ed2e509..aa84a7c3d 100644 --- a/spoolman/database/vendor.py +++ b/spoolman/database/vendor.py @@ -86,11 +86,11 @@ async def find( # Apply extra field filters if provided if extra_field_filters: # Get all extra fields for vendors - from spoolman.extra_fields import EntityType, get_extra_fields - + from spoolman.extra_fields import EntityType, ExtraFieldType, get_extra_fields + extra_fields = await get_extra_fields(db, EntityType.vendor) extra_fields_dict = {field.key: field for field in extra_fields} - + for field_key, value in extra_field_filters.items(): if field_key in extra_fields_dict: field = extra_fields_dict[field_key] @@ -101,7 +101,7 @@ async def find( field_key, field.field_type, value, - field.multi_choice if field.field_type == "choice" else None + field.multi_choice if field.field_type == ExtraFieldType.choice else None, ) if sort_by is not None: diff --git a/tests_integration/tests/fields/test_filter_sort.py b/tests_integration/tests/fields/test_filter_sort.py index 141260582..80d6cc605 100644 --- a/tests_integration/tests/fields/test_filter_sort.py +++ b/tests_integration/tests/fields/test_filter_sort.py @@ -257,4 +257,312 @@ async def test_filter_by_boolean_custom_field(random_filament: dict[str, Any]): result = httpx.delete(f"{URL}/api/v1/field/spool/boolean_field") assert_httpx_success(result) httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() - httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() \ No newline at end of file + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_and_sort_float_custom_field(random_filament: dict[str, Any]): + """Test filtering and sorting by a float custom field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/float_field", + json={"name": "Float field", "field_type": "float"}, + ) + assert_httpx_success(result) + + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"float_field": json.dumps(1.5)}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"float_field": json.dumps(2.5)}}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Filter by exact float value + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.float_field": "1.5"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 in ids + assert spool_id2 not in ids + + # Sort ascending: 1.5 before 2.5 + result = httpx.get(f"{URL}/api/v1/spool", params={"sort": "extra.float_field:asc"}) + assert_httpx_success(result) + test_spools = [item for item in result.json() if item["id"] in (spool_id1, spool_id2)] + assert len(test_spools) == 2 + assert test_spools[0]["id"] == spool_id1 + assert test_spools[1]["id"] == spool_id2 + + # Sort descending: 2.5 before 1.5 + result = httpx.get(f"{URL}/api/v1/spool", params={"sort": "extra.float_field:desc"}) + assert_httpx_success(result) + test_spools = [item for item in result.json() if item["id"] in (spool_id1, spool_id2)] + assert len(test_spools) == 2 + assert test_spools[0]["id"] == spool_id2 + assert test_spools[1]["id"] == spool_id1 + + # Clean up + httpx.delete(f"{URL}/api/v1/field/spool/float_field").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_single_choice_custom_field(random_filament: dict[str, Any]): + """Test filtering by a single-choice custom field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/choice_field", + json={ + "name": "Choice field", + "field_type": "choice", + "choices": ["OptionA", "OptionB", "OptionC"], + "multi_choice": False, + }, + ) + assert_httpx_success(result) + + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"choice_field": json.dumps("OptionA")}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"choice_field": json.dumps("OptionB")}}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Filter by a single choice value + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.choice_field": "OptionA"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 in ids + assert spool_id2 not in ids + + # Filter by multiple choices (OR) — both should be returned + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.choice_field": "OptionA,OptionB"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 in ids + assert spool_id2 in ids + + # Clean up + httpx.delete(f"{URL}/api/v1/field/spool/choice_field").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_multi_choice_custom_field(random_filament: dict[str, Any]): + """Test filtering by a multi-choice custom field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/multi_choice_field", + json={ + "name": "Multi-choice field", + "field_type": "choice", + "choices": ["A", "B", "C"], + "multi_choice": True, + }, + ) + assert_httpx_success(result) + + # Spool 1 has choices A and B + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"multi_choice_field": json.dumps(["A", "B"])}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + # Spool 2 has only choice C + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"multi_choice_field": json.dumps(["C"])}}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Filter by A — only spool 1 has A + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.multi_choice_field": "A"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 in ids + assert spool_id2 not in ids + + # Filter by C — only spool 2 has C + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.multi_choice_field": "C"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id2 in ids + assert spool_id1 not in ids + + # Filter by A,C (OR) — both should be returned + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.multi_choice_field": "A,C"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 in ids + assert spool_id2 in ids + + # Clean up + httpx.delete(f"{URL}/api/v1/field/spool/multi_choice_field").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_empty_custom_field(random_filament: dict[str, Any]): + """Test the filter returns items that have no value set for a custom field.""" + result = httpx.post( + f"{URL}/api/v1/field/spool/optional_field", + json={"name": "Optional field", "field_type": "text"}, + ) + assert_httpx_success(result) + + # Spool 1 has the field set + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"], "extra": {"optional_field": json.dumps("has_value")}}, + ) + assert_httpx_success(result) + spool_id1 = result.json()["id"] + + # Spool 2 does NOT have the field set + result = httpx.post( + f"{URL}/api/v1/spool", + json={"filament_id": random_filament["id"]}, + ) + assert_httpx_success(result) + spool_id2 = result.json()["id"] + + # Filter by — spool 2 (no field row) should appear, spool 1 should not + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.optional_field": ""}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id2 in ids + assert spool_id1 not in ids + + # Filter by the value — spool 1 should appear, spool 2 should not + result = httpx.get(f"{URL}/api/v1/spool", params={"extra.optional_field": "has_value"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert spool_id1 in ids + assert spool_id2 not in ids + + # Clean up + httpx.delete(f"{URL}/api/v1/field/spool/optional_field").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/spool/{spool_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_sort_filament_custom_field(random_filament: dict[str, Any]): + """Test filtering and sorting filaments by a custom field.""" + vendor_id = random_filament["vendor"]["id"] + + result = httpx.post( + f"{URL}/api/v1/field/filament/filament_tag", + json={"name": "Filament tag", "field_type": "text"}, + ) + assert_httpx_success(result) + + result = httpx.post( + f"{URL}/api/v1/filament", + json={"vendor_id": vendor_id, "density": 1.24, "diameter": 1.75, "extra": {"filament_tag": json.dumps("beta")}}, + ) + assert_httpx_success(result) + filament_id1 = result.json()["id"] + + result = httpx.post( + f"{URL}/api/v1/filament", + json={"vendor_id": vendor_id, "density": 1.24, "diameter": 1.75, "extra": {"filament_tag": json.dumps("alpha")}}, + ) + assert_httpx_success(result) + filament_id2 = result.json()["id"] + + # Filter by custom field — only filament with "beta" should appear + result = httpx.get(f"{URL}/api/v1/filament", params={"extra.filament_tag": "beta"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert filament_id1 in ids + assert filament_id2 not in ids + + # Sort ascending: alpha before beta + result = httpx.get(f"{URL}/api/v1/filament", params={"sort": "extra.filament_tag:asc"}) + assert_httpx_success(result) + test_filaments = [item for item in result.json() if item["id"] in (filament_id1, filament_id2)] + assert len(test_filaments) == 2 + assert test_filaments[0]["id"] == filament_id2 # alpha first + assert test_filaments[1]["id"] == filament_id1 # beta second + + # Sort descending: beta before alpha + result = httpx.get(f"{URL}/api/v1/filament", params={"sort": "extra.filament_tag:desc"}) + assert_httpx_success(result) + test_filaments = [item for item in result.json() if item["id"] in (filament_id1, filament_id2)] + assert len(test_filaments) == 2 + assert test_filaments[0]["id"] == filament_id1 # beta first + assert test_filaments[1]["id"] == filament_id2 # alpha second + + # Clean up + httpx.delete(f"{URL}/api/v1/field/filament/filament_tag").raise_for_status() + httpx.delete(f"{URL}/api/v1/filament/{filament_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/filament/{filament_id2}").raise_for_status() + + +@pytest.mark.asyncio +async def test_filter_sort_vendor_custom_field(): + """Test filtering and sorting vendors by a custom field.""" + result = httpx.post( + f"{URL}/api/v1/field/vendor/vendor_tier", + json={"name": "Vendor tier", "field_type": "text"}, + ) + assert_httpx_success(result) + + result = httpx.post( + f"{URL}/api/v1/vendor", + json={"name": "Vendor Gold", "extra": {"vendor_tier": json.dumps("gold")}}, + ) + assert_httpx_success(result) + vendor_id1 = result.json()["id"] + + result = httpx.post( + f"{URL}/api/v1/vendor", + json={"name": "Vendor Silver", "extra": {"vendor_tier": json.dumps("silver")}}, + ) + assert_httpx_success(result) + vendor_id2 = result.json()["id"] + + # Filter by vendor custom field — only gold vendor should appear + result = httpx.get(f"{URL}/api/v1/vendor", params={"extra.vendor_tier": "gold"}) + assert_httpx_success(result) + ids = {item["id"] for item in result.json()} + assert vendor_id1 in ids + assert vendor_id2 not in ids + + # Sort ascending: gold before silver + result = httpx.get(f"{URL}/api/v1/vendor", params={"sort": "extra.vendor_tier:asc"}) + assert_httpx_success(result) + test_vendors = [item for item in result.json() if item["id"] in (vendor_id1, vendor_id2)] + assert len(test_vendors) == 2 + assert test_vendors[0]["id"] == vendor_id1 # gold first + assert test_vendors[1]["id"] == vendor_id2 # silver second + + # Sort descending: silver before gold + result = httpx.get(f"{URL}/api/v1/vendor", params={"sort": "extra.vendor_tier:desc"}) + assert_httpx_success(result) + test_vendors = [item for item in result.json() if item["id"] in (vendor_id1, vendor_id2)] + assert len(test_vendors) == 2 + assert test_vendors[0]["id"] == vendor_id2 # silver first + assert test_vendors[1]["id"] == vendor_id1 # gold second + + # Clean up + httpx.delete(f"{URL}/api/v1/field/vendor/vendor_tier").raise_for_status() + httpx.delete(f"{URL}/api/v1/vendor/{vendor_id1}").raise_for_status() + httpx.delete(f"{URL}/api/v1/vendor/{vendor_id2}").raise_for_status() From f70d801cbd98036328a9f11806573c376634a8eb Mon Sep 17 00:00:00 2001 From: akira69 Date: Tue, 24 Feb 2026 01:15:28 -0600 Subject: [PATCH 4/5] fix(ts): align custom-field filter/sort typings with current master --- client/src/components/column.tsx | 9 +- client/src/utils/sorting.ts | 2 +- spoolman/database/extra_field_query.py | 170 ++++++++++++++++++++ spoolman/database/filament.py | 21 ++- spoolman/database/spool.py | 21 ++- spoolman/database/utils.py | 213 +------------------------ spoolman/database/vendor.py | 8 +- 7 files changed, 198 insertions(+), 246 deletions(-) create mode 100644 spoolman/database/extra_field_query.py diff --git a/client/src/components/column.tsx b/client/src/components/column.tsx index 7f5103048..6e2a7b9e8 100644 --- a/client/src/components/column.tsx +++ b/client/src/components/column.tsx @@ -98,9 +98,10 @@ function Column( // Sorting if (props.sorter) { columnProps.sorter = true; + const sortField = props.dataId ?? (Array.isArray(props.id) ? props.id.join(".") : props.id); columnProps.sortOrder = getSortOrderForField( typeSorters(props.tableState.sorters), - props.dataId ?? (props.id as keyof Obj), + sortField, ); } @@ -211,7 +212,8 @@ export function FilteredQueryColumn(props: FilteredQueryColu }); const typedFilters = typeFilters(props.tableState.filters); - const filteredValue = getFiltersForField(typedFilters, props.dataId ?? (props.id as keyof Obj)); + const filterField = props.dataId ?? (Array.isArray(props.id) ? props.id.join(".") : props.id); + const filteredValue = getFiltersForField(typedFilters, filterField); const onFilterDropdownOpen = () => { query.refetch(); @@ -325,7 +327,8 @@ export function SpoolIconColumn(props: SpoolIconColumnProps< }); const typedFilters = typeFilters(props.tableState.filters); - const filteredValue = getFiltersForField(typedFilters, props.dataId ?? (props.id as keyof Obj)); + const filterField = props.dataId ?? (Array.isArray(props.id) ? props.id.join(".") : props.id); + const filteredValue = getFiltersForField(typedFilters, filterField); const onFilterDropdownOpen = () => { query.refetch(); diff --git a/client/src/utils/sorting.ts b/client/src/utils/sorting.ts index c75603d2a..d58546ab1 100644 --- a/client/src/utils/sorting.ts +++ b/client/src/utils/sorting.ts @@ -1,6 +1,6 @@ import { CrudSort } from "@refinedev/core"; import { SortOrder } from "antd/es/table/interface"; -import { getCustomFieldKey, isCustomField } from "./queryFields"; +import { Field, getCustomFieldKey, isCustomField } from "./queryFields"; interface TypedCrudSort { field: keyof Obj | string; diff --git a/spoolman/database/extra_field_query.py b/spoolman/database/extra_field_query.py new file mode 100644 index 000000000..32b87c59a --- /dev/null +++ b/spoolman/database/extra_field_query.py @@ -0,0 +1,170 @@ +"""Helpers for filtering and sorting extra fields.""" + +import json +from typing import Any + +import sqlalchemy +from sqlalchemy import Select +from sqlalchemy.orm import attributes + +from spoolman.database import models +from spoolman.database.utils import SortOrder + + +def _get_field_table_for_entity(entity_type: Any) -> type[models.Base]: + """Get the field table class for a given entity type.""" + # Import here to avoid circular imports. + from spoolman.extra_fields import EntityType + + if entity_type == EntityType.spool: + return models.SpoolField + if entity_type == EntityType.filament: + return models.FilamentField + if entity_type == EntityType.vendor: + return models.VendorField + raise ValueError(f"Unknown entity type: {entity_type}") + + +def _get_entity_id_column(field_table: type[models.Base]) -> attributes.InstrumentedAttribute[int]: + """Get the entity ID column for a given field table.""" + if field_table == models.SpoolField: + return models.SpoolField.spool_id + if field_table == models.FilamentField: + return models.FilamentField.filament_id + if field_table == models.VendorField: + return models.VendorField.vendor_id + raise ValueError(f"Unknown field table: {field_table}") + + +def add_where_clause_extra_field( + stmt: Select, + base_obj: type[models.Base], + entity_type: Any, + field_key: str, + field_type: Any, + value: str, + multi_choice: bool | None = None, +) -> Select: + """Add a where clause to a select statement for an extra field.""" + # Import here to avoid circular imports. + from spoolman.extra_fields import ExtraFieldType + + field_table = _get_field_table_for_entity(entity_type) + entity_id_column = _get_entity_id_column(field_table) + value_parts = value.split(",") + + # An item is considered "empty" if: + # A) the row exists and value is null/'null' (or false for bool), or + # B) no row exists for this item + key. + if any(part == "" or len(part) == 0 for part in value_parts): + empty_conditions = [ + field_table.value.is_(None), + field_table.value == "null", + ] + if field_type == ExtraFieldType.boolean: + empty_conditions.append(field_table.value == json.dumps(False)) + + subq_a = sqlalchemy.select(entity_id_column).where( + sqlalchemy.and_(field_table.key == field_key, sqlalchemy.or_(*empty_conditions)) + ) + subq_b = sqlalchemy.select(base_obj.id).where( + getattr(base_obj, "id").not_in(sqlalchemy.select(entity_id_column).where(field_table.key == field_key)) + ) + return stmt.where( + sqlalchemy.or_( + getattr(base_obj, "id").in_(subq_a), + getattr(base_obj, "id").in_(subq_b), + ) + ) + + conditions = [] + for value_part in value_parts: + exact_match = value_part.startswith('"') and value_part.endswith('"') + if exact_match: + value_part = value_part[1:-1] + + if field_type == ExtraFieldType.text: + if exact_match: + conditions.append(field_table.value == json.dumps(value_part)) + else: + conditions.append(field_table.value.ilike(f"%{value_part}%")) + elif field_type == ExtraFieldType.integer: + try: + conditions.append(field_table.value == json.dumps(int(value_part))) + except ValueError: + pass + elif field_type == ExtraFieldType.float: + try: + conditions.append(field_table.value == json.dumps(float(value_part))) + except ValueError: + pass + elif field_type == ExtraFieldType.boolean: + bool_value = value_part.lower() in ("true", "1", "yes") + conditions.append(field_table.value == json.dumps(bool_value)) + elif field_type == ExtraFieldType.choice: + if multi_choice: + conditions.append(field_table.value.like(f'%"{value_part}"%')) + else: + conditions.append(field_table.value == json.dumps(value_part)) + elif field_type == ExtraFieldType.datetime: + conditions.append(field_table.value == json.dumps(value_part)) + elif field_type in (ExtraFieldType.integer_range, ExtraFieldType.float_range) and ":" in value_part: + min_val_str, max_val_str = value_part.split(":", 1) + converter = int if field_type == ExtraFieldType.integer_range else float + try: + if min_val_str: + conditions.append(field_table.value[0].as_integer() >= converter(min_val_str)) + if max_val_str: + conditions.append(field_table.value[1].as_integer() <= converter(max_val_str)) + except (ValueError, TypeError): + pass + + if not conditions: + return stmt + + subq = sqlalchemy.select(entity_id_column).where( + sqlalchemy.and_(field_table.key == field_key, sqlalchemy.or_(*conditions)) + ) + return stmt.where(getattr(base_obj, "id").in_(subq)) + + +def add_order_by_extra_field( + stmt: Select, + base_obj: type[models.Base], + entity_type: Any, + field_key: str, + field_type: Any, + order: SortOrder, +) -> Select: + """Add an order-by clause to a select statement for an extra field.""" + # Import here to avoid circular imports. + from spoolman.extra_fields import ExtraFieldType + + field_table = _get_field_table_for_entity(entity_type) + entity_id_column = _get_entity_id_column(field_table) + + value_subq = ( + sqlalchemy.select(field_table.value) + .where( + sqlalchemy.and_( + field_table.key == field_key, + entity_id_column == getattr(base_obj, "id"), + ) + ) + .scalar_subquery() + .correlate(base_obj) + ) + + if field_type == ExtraFieldType.integer: + sort_expr = sqlalchemy.cast(value_subq, sqlalchemy.Integer) + elif field_type == ExtraFieldType.float: + sort_expr = sqlalchemy.cast(value_subq, sqlalchemy.Float) + elif field_type in (ExtraFieldType.integer_range, ExtraFieldType.float_range): + # Sort ranges by low-end value. + sort_expr = value_subq[0] + else: + sort_expr = value_subq + + if order == SortOrder.ASC: + return stmt.order_by(sort_expr.asc()) + return stmt.order_by(sort_expr.desc()) diff --git a/spoolman/database/filament.py b/spoolman/database/filament.py index 130981ae8..320e030c3 100644 --- a/spoolman/database/filament.py +++ b/spoolman/database/filament.py @@ -12,7 +12,15 @@ from spoolman.api.v1.models import EventType, Filament, FilamentEvent, MultiColorDirection from spoolman.database import models, vendor -from spoolman.database.utils import SortOrder +from spoolman.database.extra_field_query import add_order_by_extra_field, add_where_clause_extra_field +from spoolman.database.utils import ( + SortOrder, + add_where_clause_int_in, + add_where_clause_int_opt, + add_where_clause_str, + add_where_clause_str_opt, + parse_nested_field, +) from spoolman.exceptions import ItemDeleteError, ItemNotFoundError from spoolman.math import delta_e, hex_to_rgb, rgb_to_lab from spoolman.ws import websocket_manager @@ -107,17 +115,6 @@ async def find( Returns a tuple containing the list of items and the total count of matching items. """ - # Import here to avoid circular imports - from spoolman.database.utils import ( - add_where_clause_int_in, - add_where_clause_int_opt, - add_where_clause_str, - add_where_clause_str_opt, - add_where_clause_extra_field, - add_order_by_extra_field, - parse_nested_field, - ) - stmt = ( select(models.Filament) .options(contains_eager(models.Filament.vendor)) diff --git a/spoolman/database/spool.py b/spoolman/database/spool.py index 89c58b291..f24e6a0fe 100644 --- a/spoolman/database/spool.py +++ b/spoolman/database/spool.py @@ -13,7 +13,15 @@ from spoolman.api.v1.models import EventType, Spool, SpoolEvent from spoolman.database import filament, models -from spoolman.database.utils import SortOrder +from spoolman.database.extra_field_query import add_order_by_extra_field, add_where_clause_extra_field +from spoolman.database.utils import ( + SortOrder, + add_where_clause_int, + add_where_clause_int_opt, + add_where_clause_str, + add_where_clause_str_opt, + parse_nested_field, +) from spoolman.exceptions import ItemCreateError, ItemNotFoundError, SpoolMeasureError from spoolman.math import weight_from_length from spoolman.ws import websocket_manager @@ -127,17 +135,6 @@ async def find( # noqa: C901, PLR0912 Returns a tuple containing the list of items and the total count of matching items. """ - # Import here to avoid circular imports - from spoolman.database.utils import ( - add_where_clause_int, - add_where_clause_int_opt, - add_where_clause_str, - add_where_clause_str_opt, - add_where_clause_extra_field, - add_order_by_extra_field, - parse_nested_field, - ) - stmt = ( sqlalchemy.select(models.Spool) .join(models.Spool.filament, isouter=True) diff --git a/spoolman/database/utils.py b/spoolman/database/utils.py index 54b96c739..2d8776c00 100644 --- a/spoolman/database/utils.py +++ b/spoolman/database/utils.py @@ -1,9 +1,8 @@ """Utility functions for the database module.""" from collections.abc import Sequence -import json from enum import Enum -from typing import Any, Type, TypeVar +from typing import Any, TypeVar import sqlalchemy from sqlalchemy import Select @@ -130,213 +129,3 @@ def add_where_clause_int_in( if value is not None: stmt = stmt.where(field.in_(value)) return stmt - - -def get_field_table_for_entity(entity_type: Any) -> Type[models.Base]: - """Get the field table class for a given entity type.""" - # Import here to avoid circular imports - from spoolman.extra_fields import EntityType - - if entity_type == EntityType.spool: - return models.SpoolField - elif entity_type == EntityType.filament: - return models.FilamentField - elif entity_type == EntityType.vendor: - return models.VendorField - else: - raise ValueError(f"Unknown entity type: {entity_type}") - - -def get_entity_id_column(field_table: Type[models.Base]) -> attributes.InstrumentedAttribute[int]: - """Get the entity ID column for a given field table.""" - if field_table == models.SpoolField: - return models.SpoolField.spool_id - elif field_table == models.FilamentField: - return models.FilamentField.filament_id - elif field_table == models.VendorField: - return models.VendorField.vendor_id - else: - raise ValueError(f"Unknown field table: {field_table}") - - -def add_where_clause_extra_field( - stmt: Select, - base_obj: Type[models.Base], - entity_type: Any, - field_key: str, - field_type: Any, - value: str, - multi_choice: bool | None = None, -) -> Select: - """Add a where clause to a select statement for an extra field. - Args: - stmt: The select statement to add the where clause to - base_obj: The base object type (Spool, Filament, Vendor) - entity_type: The entity type - field_key: The key of the extra field - field_type: The type of the extra field - value: The value to filter by - multi_choice: Whether the field is a multi-choice field (only for choice fields) - Returns: - The modified select statement - """ - # Import here to avoid circular imports - from spoolman.extra_fields import ExtraFieldType - - field_table = get_field_table_for_entity(entity_type) - entity_id_column = get_entity_id_column(field_table) - - value_parts = value.split(",") - - # Handle filtering for empty values - if any(p == "" or len(p) == 0 for p in value_parts): - # An item is considered "empty" if: - # A) A row exists in the field table, and its value is null, 'null', or 'false' for booleans. - # B) No row exists in the field table for this item and field_key. - - # Condition A subquery - empty_conditions = [ - field_table.value.is_(None), - field_table.value == "null", - ] - if field_type == ExtraFieldType.boolean: - empty_conditions.append(field_table.value == json.dumps(False)) - - subq_a = sqlalchemy.select(entity_id_column).where( - sqlalchemy.and_(field_table.key == field_key, sqlalchemy.or_(*empty_conditions)) - ) - - # Condition B subquery - subq_b = sqlalchemy.select(base_obj.id).where( - getattr(base_obj, "id").not_in(sqlalchemy.select(entity_id_column).where(field_table.key == field_key)) - ) - - return stmt.where( - sqlalchemy.or_( - getattr(base_obj, "id").in_(subq_a), - getattr(base_obj, "id").in_(subq_b), - ) - ) - - # Handle filtering for specific values - conditions = [] - for value_part in value_parts: - exact_match = value_part.startswith('"') and value_part.endswith('"') - if exact_match: - value_part = value_part[1:-1] - - if field_type == ExtraFieldType.text: - if exact_match: - conditions.append(field_table.value == json.dumps(value_part)) - else: - conditions.append(field_table.value.ilike(f"%{value_part}%")) - elif field_type == ExtraFieldType.integer: - try: - conditions.append(field_table.value == json.dumps(int(value_part))) - except ValueError: - pass - elif field_type == ExtraFieldType.float: - try: - conditions.append(field_table.value == json.dumps(float(value_part))) - except ValueError: - pass - elif field_type == ExtraFieldType.boolean: - bool_value = value_part.lower() in ("true", "1", "yes") - conditions.append(field_table.value == json.dumps(bool_value)) - elif field_type == ExtraFieldType.choice: - if multi_choice: - conditions.append(field_table.value.like(f'%"{value_part}"%')) - else: - conditions.append(field_table.value == json.dumps(value_part)) - elif field_type == ExtraFieldType.datetime: - conditions.append(field_table.value == json.dumps(value_part)) - elif field_type in (ExtraFieldType.integer_range, ExtraFieldType.float_range): - if ":" in value_part: - min_val_str, max_val_str = value_part.split(":", 1) - converter = int if field_type == ExtraFieldType.integer_range else float - try: - if min_val_str: - conditions.append(field_table.value[0].as_integer() >= converter(min_val_str)) - if max_val_str: - conditions.append(field_table.value[1].as_integer() <= converter(max_val_str)) - except (ValueError, TypeError): - pass - - if not conditions: - return stmt - - subq = sqlalchemy.select(entity_id_column).where( - sqlalchemy.and_(field_table.key == field_key, sqlalchemy.or_(*conditions)) - ) - - return stmt.where(getattr(base_obj, "id").in_(subq)) - - -def add_order_by_extra_field( - stmt: Select, - base_obj: Type[models.Base], - entity_type: Any, - field_key: str, - field_type: Any, - order: SortOrder, -) -> Select: - """Add an order by clause to a select statement for an extra field. - - Args: - stmt: The select statement to add the order by clause to - base_obj: The base object type (Spool, Filament, Vendor) - entity_type: The entity type - field_key: The key of the extra field - field_type: The type of the extra field - order: The sort order - - Returns: - The modified select statement - """ - # Import here to avoid circular imports - from spoolman.extra_fields import EntityType, ExtraFieldType - - # Use a subquery approach instead of joins - field_table = get_field_table_for_entity(entity_type) - entity_id_column = get_entity_id_column(field_table) - - # Create a subquery that selects the value for each entity - value_subq = ( - sqlalchemy.select(field_table.value) - .where( - sqlalchemy.and_( - field_table.key == field_key, - entity_id_column == getattr(base_obj, "id") - ) - ) - .scalar_subquery() - .correlate(base_obj) - ) - - # Create a sort expression based on the field type - if field_type == ExtraFieldType.integer: - # Cast the JSON value to an integer for sorting - sort_expr = sqlalchemy.cast(value_subq, sqlalchemy.Integer) - elif field_type == ExtraFieldType.float: - # Cast the JSON value to a float for sorting - sort_expr = sqlalchemy.cast(value_subq, sqlalchemy.Float) - elif field_type == ExtraFieldType.datetime: - # For datetime fields, we can sort by the ISO string - sort_expr = value_subq - elif field_type == ExtraFieldType.boolean: - # For boolean fields, true comes after false - sort_expr = value_subq - elif field_type in (ExtraFieldType.integer_range, ExtraFieldType.float_range): - # For range fields, sort by the first value in the range - sort_expr = value_subq[0] - else: - # For text and choice fields, sort by the string value - sort_expr = value_subq - - # Apply the sort order - if order == SortOrder.ASC: - stmt = stmt.order_by(sort_expr.asc()) - else: - stmt = stmt.order_by(sort_expr.desc()) - - return stmt diff --git a/spoolman/database/vendor.py b/spoolman/database/vendor.py index aa84a7c3d..810aca028 100644 --- a/spoolman/database/vendor.py +++ b/spoolman/database/vendor.py @@ -9,6 +9,7 @@ from spoolman.api.v1.models import EventType, Vendor, VendorEvent from spoolman.database import models +from spoolman.database.extra_field_query import add_order_by_extra_field, add_where_clause_extra_field from spoolman.database.utils import SortOrder from spoolman.exceptions import ItemNotFoundError from spoolman.ws import websocket_manager @@ -63,12 +64,7 @@ async def find( Returns a tuple containing the list of items and the total count of matching items. """ # Import here to avoid circular imports - from spoolman.database.utils import ( - add_where_clause_str, - add_where_clause_str_opt, - add_where_clause_extra_field, - add_order_by_extra_field - ) + from spoolman.database.utils import add_where_clause_str, add_where_clause_str_opt stmt = select(models.Vendor) From 035e61a6743b271e518aa4dc261731286ae18b12 Mon Sep 17 00:00:00 2001 From: akira69 Date: Tue, 3 Mar 2026 10:15:54 -0600 Subject: [PATCH 5/5] docs(custom-fields): clarify filter and sort semantics --- client/src/components/column.tsx | 4 ++++ spoolman/api/v1/filament.py | 3 ++- spoolman/api/v1/spool.py | 3 ++- spoolman/api/v1/vendor.py | 3 ++- spoolman/database/extra_field_query.py | 3 ++- 5 files changed, 12 insertions(+), 4 deletions(-) diff --git a/client/src/components/column.tsx b/client/src/components/column.tsx index 6e2a7b9e8..cb339df67 100644 --- a/client/src/components/column.tsx +++ b/client/src/components/column.tsx @@ -199,6 +199,8 @@ export function FilteredQueryColumn(props: FilteredQueryColu filters = query.data.map((item) => { if (typeof item === "string") { return { + // Wrap plain strings so the backend can distinguish exact-value picks from the + // loose text-match syntax used by custom text fields. text: item, value: '"' + item + '"', }; @@ -212,6 +214,8 @@ export function FilteredQueryColumn(props: FilteredQueryColu }); const typedFilters = typeFilters(props.tableState.filters); + // Custom columns often render through a synthetic id array, so prefer dataId when present + // to keep the table state key aligned with the API field name (for example "extra.foo"). const filterField = props.dataId ?? (Array.isArray(props.id) ? props.id.join(".") : props.id); const filteredValue = getFiltersForField(typedFilters, filterField); diff --git a/spoolman/api/v1/filament.py b/spoolman/api/v1/filament.py index e10963fb7..1e49af277 100644 --- a/spoolman/api/v1/filament.py +++ b/spoolman/api/v1/filament.py @@ -343,7 +343,8 @@ async def find( else: filter_by_ids = None - # Extract custom field filters from query parameters + # Custom field filters arrive as dynamic "extra." params, so they have to be + # collected from the raw query string instead of the static endpoint signature. extra_field_filters = {} query_params = request.query_params for key, value in query_params.items(): diff --git a/spoolman/api/v1/spool.py b/spoolman/api/v1/spool.py index ef659a327..450af0766 100644 --- a/spoolman/api/v1/spool.py +++ b/spoolman/api/v1/spool.py @@ -286,7 +286,8 @@ async def find( else: filament_vendor_ids = None - # Extract custom field filters from query parameters + # Custom field filters arrive as dynamic "extra." params, so they have to be + # collected from the raw query string instead of the static endpoint signature. extra_field_filters = {} query_params = request.query_params for key, value in query_params.items(): diff --git a/spoolman/api/v1/vendor.py b/spoolman/api/v1/vendor.py index 54601228a..72a5b210c 100644 --- a/spoolman/api/v1/vendor.py +++ b/spoolman/api/v1/vendor.py @@ -125,7 +125,8 @@ async def find( field, direction = sort_item.split(":") sort_by[field] = SortOrder[direction.upper()] - # Extract custom field filters from query parameters + # Custom field filters arrive as dynamic "extra." params, so they have to be + # collected from the raw query string instead of the static endpoint signature. extra_field_filters = {} query_params = request.query_params for key, value in query_params.items(): diff --git a/spoolman/database/extra_field_query.py b/spoolman/database/extra_field_query.py index 32b87c59a..1cb1b8733 100644 --- a/spoolman/database/extra_field_query.py +++ b/spoolman/database/extra_field_query.py @@ -160,7 +160,8 @@ def add_order_by_extra_field( elif field_type == ExtraFieldType.float: sort_expr = sqlalchemy.cast(value_subq, sqlalchemy.Float) elif field_type in (ExtraFieldType.integer_range, ExtraFieldType.float_range): - # Sort ranges by low-end value. + # Range columns need a stable scalar sort key; using the low-end keeps similar + # ranges grouped predictably without inventing a second synthetic value. sort_expr = value_subq[0] else: sort_expr = value_subq