diff --git a/beets/dbcore/db.py b/beets/dbcore/db.py index 64e77f8140..2aa0081d79 100755 --- a/beets/dbcore/db.py +++ b/beets/dbcore/db.py @@ -35,8 +35,6 @@ from ..util import cached_classproperty, functemplate from . import types from .query import ( - AndQuery, - FieldQuery, FieldQueryType, FieldSort, MatchQuery, @@ -718,33 +716,6 @@ def set_parse(self, key, string: str): """Set the object's key to a value represented by a string.""" self[key] = self._parse(key, string) - # Convenient queries. - - @classmethod - def field_query( - cls, - field, - pattern, - query_cls: FieldQueryType = MatchQuery, - ) -> FieldQuery: - """Get a `FieldQuery` for this model.""" - return query_cls(field, pattern, field in cls._fields) - - @classmethod - def all_fields_query( - cls: type[Model], - pats: Mapping[str, str], - query_cls: FieldQueryType = MatchQuery, - ): - """Get a query that matches many fields with different patterns. - - `pats` should be a mapping from field names to patterns. The - resulting query is a conjunction ("and") of per-field queries - for all of these field/pattern pairs. - """ - subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()] - return AndQuery(subqueries) - # Database controller and supporting interfaces. diff --git a/beets/dbcore/query.py b/beets/dbcore/query.py index 866162c4ad..c7ca444524 100644 --- a/beets/dbcore/query.py +++ b/beets/dbcore/query.py @@ -97,6 +97,9 @@ def match(self, obj: Model): """ ... + def __and__(self, other: Query) -> AndQuery: + return AndQuery([self, other]) + def __repr__(self) -> str: return f"{self.__class__.__name__}()" @@ -505,50 +508,6 @@ def __hash__(self) -> int: return reduce(mul, map(hash, self.subqueries), 1) -class AnyFieldQuery(CollectionQuery): - """A query that matches if a given FieldQuery subclass matches in - any field. The individual field query class is provided to the - constructor. 
- """ - - @property - def field_names(self) -> set[str]: - """Return a set with field names that this query operates on.""" - return set(self.fields) - - def __init__(self, pattern, fields, cls: FieldQueryType): - self.pattern = pattern - self.fields = fields - self.query_class = cls - - subqueries = [] - for field in self.fields: - subqueries.append(cls(field, pattern, True)) - # TYPING ERROR - super().__init__(subqueries) - - def clause(self) -> tuple[str | None, Sequence[SQLiteType]]: - return self.clause_with_joiner("or") - - def match(self, obj: Model) -> bool: - for subq in self.subqueries: - if subq.match(obj): - return True - return False - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}({self.pattern!r}, {self.fields!r}, " - f"{self.query_class.__name__})" - ) - - def __eq__(self, other) -> bool: - return super().__eq__(other) and self.query_class == other.query_class - - def __hash__(self) -> int: - return hash((self.pattern, tuple(self.fields), self.query_class)) - - class MutableCollectionQuery(CollectionQuery): """A collection query whose subqueries may be modified after the query is initialized. diff --git a/beets/dbcore/queryparse.py b/beets/dbcore/queryparse.py index 2896326680..f84ed74365 100644 --- a/beets/dbcore/queryparse.py +++ b/beets/dbcore/queryparse.py @@ -20,15 +20,17 @@ import re from typing import TYPE_CHECKING -from . import Model, query +from . import query if TYPE_CHECKING: from collections.abc import Collection, Sequence + from ..library import LibModel from .query import FieldQueryType, Sort Prefixes = dict[str, FieldQueryType] + PARSE_QUERY_PART_REGEX = re.compile( # Non-capturing optional segment for the keyword. r"(-|\^)?" # Negation prefixes. 
@@ -112,7 +114,7 @@ def parse_query_part( def construct_query_part( - model_cls: type[Model], + model_cls: type[LibModel], prefixes: Prefixes, query_part: str, ) -> query.Query: @@ -147,28 +149,14 @@ def construct_query_part( query_part, query_classes, prefixes ) - # If there's no key (field name) specified, this is a "match - # anything" query. if key is None: - # The query type matches a specific field, but none was - # specified. So we use a version of the query that matches - # any field. - out_query = query.AnyFieldQuery( - pattern, model_cls._search_fields, query_class - ) - - # Field queries get constructed according to the name of the field - # they are querying. + # If there's no key (field name) specified, this is a "match anything" + # query. + out_query = model_cls.any_field_query(pattern, query_class) else: - field = table = key.lower() - if field in model_cls.shared_db_fields: - # This field exists in both tables, so SQLite will encounter - # an OperationalError if we try to query it in a join. - # Using an explicit table name resolves this. - table = f"{model_cls._table}.{field}" - - field_in_db = field in model_cls.all_db_fields - out_query = query_class(table, pattern, field_in_db) + # Field queries get constructed according to the name of the field + # they are querying. + out_query = model_cls.field_query(key.lower(), pattern, query_class) # Apply negation. 
if negate: @@ -180,7 +168,7 @@ def construct_query_part( # TYPING ERROR def query_from_strings( query_cls: type[query.CollectionQuery], - model_cls: type[Model], + model_cls: type[LibModel], prefixes: Prefixes, query_parts: Collection[str], ) -> query.Query: @@ -197,7 +185,7 @@ def query_from_strings( def construct_sort_part( - model_cls: type[Model], + model_cls: type[LibModel], part: str, case_insensitive: bool = True, ) -> Sort: @@ -228,7 +216,7 @@ def construct_sort_part( def sort_from_strings( - model_cls: type[Model], + model_cls: type[LibModel], sort_parts: Sequence[str], case_insensitive: bool = True, ) -> Sort: @@ -247,7 +235,7 @@ def sort_from_strings( def parse_sorted_query( - model_cls: type[Model], + model_cls: type[LibModel], parts: list[str], prefixes: Prefixes = {}, case_insensitive: bool = True, diff --git a/beets/importer.py b/beets/importer.py index ab2382c9fd..b30e6399b3 100644 --- a/beets/importer.py +++ b/beets/importer.py @@ -707,9 +707,7 @@ def find_duplicates(self, lib): # use a temporary Album object to generate any computed fields. tmp_album = library.Album(lib, **info) keys = config["import"]["duplicate_keys"]["album"].as_str_seq() - dup_query = library.Album.all_fields_query( - {key: tmp_album.get(key) for key in keys} - ) + dup_query = tmp_album.duplicates_query(keys) # Don't count albums with the same files as duplicates. task_paths = {i.path for i in self.items if i} @@ -1025,9 +1023,7 @@ def find_duplicates(self, lib): # temporary `Item` object to generate any computed fields. 
tmp_item = library.Item(lib, **info) keys = config["import"]["duplicate_keys"]["item"].as_str_seq() - dup_query = library.Album.all_fields_query( - {key: tmp_item.get(key) for key in keys} - ) + dup_query = tmp_item.duplicates_query(keys) found_items = [] for other_item in lib.items(dup_query): diff --git a/beets/library.py b/beets/library.py index 2430f71258..d4ec63200d 100644 --- a/beets/library.py +++ b/beets/library.py @@ -25,6 +25,7 @@ import unicodedata from functools import cached_property from pathlib import Path +from typing import TYPE_CHECKING import platformdirs from mediafile import MediaFile, UnreadableFileError @@ -42,6 +43,9 @@ ) from beets.util.functemplate import Template, template +if TYPE_CHECKING: + from .dbcore.query import FieldQuery, FieldQueryType + # To use the SQLite "blob" type, it doesn't suffice to provide a byte # string; SQLite treats that as encoded text. Wrapping it in a # `memoryview` tells it that we actually mean non-text data. @@ -346,6 +350,10 @@ class LibModel(dbcore.Model["Library"]): # Config key that specifies how an instance should be formatted. _format_config_key: str + @cached_classproperty + def writable_media_fields(cls) -> set[str]: + return set(MediaFile.fields()) & cls._fields.keys() + def _template_funcs(self): funcs = DefaultTemplateFunctions(self, self._db).functions() funcs.update(plugins.template_funcs()) @@ -375,6 +383,44 @@ def __str__(self): def __bytes__(self): return self.__str__().encode("utf-8") + # Convenient queries. + + @classmethod + def field_query( + cls, field: str, pattern: str, query_cls: FieldQueryType + ) -> FieldQuery: + """Get a `FieldQuery` for the given field on this model.""" + fast = field in cls.all_db_fields + if field in cls.shared_db_fields: + # This field exists in both tables, so SQLite will encounter + # an OperationalError if we try to use it in a query. + # Using an explicit table name resolves this. 
+ field = f"{cls._table}.{field}" + + return query_cls(field, pattern, fast) + + @classmethod + def any_field_query(cls, *args, **kwargs) -> dbcore.OrQuery: + return dbcore.OrQuery( + [cls.field_query(f, *args, **kwargs) for f in cls._search_fields] + ) + + @classmethod + def any_writable_media_field_query(cls, *args, **kwargs) -> dbcore.OrQuery: + fields = cls.writable_media_fields + return dbcore.OrQuery( + [cls.field_query(f, *args, **kwargs) for f in fields] + ) + + def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery: + """Return a query for entities with same values in the given fields.""" + return dbcore.AndQuery( + [ + self.field_query(f, self.get(f), dbcore.MatchQuery) + for f in fields + ] + ) + class FormattedItemMapping(dbcore.db.FormattedMapping): """Add lookup for album-level fields. @@ -648,6 +694,12 @@ def _getters(cls): getters["filesize"] = Item.try_filesize # In bytes. return getters + def duplicates_query(self, fields: list[str]) -> dbcore.AndQuery: + """Return a query for entities with same values in the given fields.""" + return super().duplicates_query(fields) & dbcore.query.NoneQuery( + "album_id" + ) + @classmethod def from_path(cls, path): """Create a new item from the media file at the specified path.""" @@ -1866,7 +1918,6 @@ def tmpl_sunique(self, keys=None, disam=None, bracket=None): Item.all_keys(), # Do nothing for non singletons. lambda i: i.album_id is not None, - initial_subqueries=[dbcore.query.NoneQuery("album_id", True)], ) def _tmpl_unique_memokey(self, name, keys, disam, item_id): @@ -1885,7 +1936,6 @@ def _tmpl_unique( db_item, item_keys, skip_item, - initial_subqueries=None, ): """Generate a string that is guaranteed to be unique among all items of the same type as "db_item" who share the same set of keys. @@ -1932,15 +1982,7 @@ def _tmpl_unique( bracket_r = "" # Find matching items to disambiguate with. 
- subqueries = [] - if initial_subqueries is not None: - subqueries.extend(initial_subqueries) - for key in keys: - value = db_item.get(key, "") - # Use slow queries for flexible attributes. - fast = key in item_keys - subqueries.append(dbcore.MatchQuery(key, value, fast)) - query = dbcore.AndQuery(subqueries) + query = db_item.duplicates_query(keys) ambigous_items = ( self.lib.items(query) if isinstance(db_item, Item) diff --git a/beetsplug/aura.py b/beetsplug/aura.py index a9b270657f..e7034c1e9e 100644 --- a/beetsplug/aura.py +++ b/beetsplug/aura.py @@ -186,7 +186,9 @@ def translate_filters(self): value = converter(value) # Add exact match query to list # Use a slow query so it works with all fields - queries.append(MatchQuery(beets_attr, value, fast=False)) + queries.append( + self.model_cls.field_query(beets_attr, value, MatchQuery) + ) # NOTE: AURA doesn't officially support multiple queries return AndQuery(queries) @@ -318,13 +320,12 @@ def all_resources(self): sort = self.translate_sorts(sort_arg) # For each sort field add a query which ensures all results # have a non-empty, non-zero value for that field. 
- for s in sort.sorts: - query.subqueries.append( - NotQuery( - # Match empty fields (^$) or zero fields, (^0$) - RegexpQuery(s.field, "(^$|^0$)", fast=False) - ) + query.subqueries.extend( + NotQuery( + self.model_cls.field_query(s.field, "(^$|^0$)", RegexpQuery) ) + for s in sort.sorts + ) else: sort = None # Get information from the library diff --git a/beetsplug/bpd/__init__.py b/beetsplug/bpd/__init__.py index da6c2eb468..9d8b4142b3 100644 --- a/beetsplug/bpd/__init__.py +++ b/beetsplug/bpd/__init__.py @@ -26,8 +26,7 @@ import time import traceback from string import Template - -from mediafile import MediaFile +from typing import TYPE_CHECKING import beets import beets.ui @@ -36,6 +35,9 @@ from beets.plugins import BeetsPlugin from beets.util import bluelet +if TYPE_CHECKING: + from beets.dbcore.query import Query + PROTOCOL_VERSION = "0.16.0" BUFSIZE = 1024 @@ -91,8 +93,6 @@ "partition", ] -ITEM_KEYS_WRITABLE = set(MediaFile.fields()).intersection(Item._fields.keys()) - # Gstreamer import error. class NoGstreamerError(Exception): @@ -1399,29 +1399,29 @@ def _tagtype_lookup(self, tag): return test_tag, key raise BPDError(ERROR_UNKNOWN, "no such tagtype") - def _metadata_query(self, query_type, any_query_type, kv): + def _metadata_query(self, query_type, kv, allow_any_query: bool = False): """Helper function returns a query object that will find items according to the library query type provided and the key-value pairs specified. The any_query_type is used for queries of type "any"; if None, then an error is thrown. """ if kv: # At least one key-value pair. - queries = [] + queries: list[Query] = [] # Iterate pairwise over the arguments. 
it = iter(kv) for tag, value in zip(it, it): if tag.lower() == "any": - if any_query_type: + if allow_any_query: queries.append( - any_query_type( - value, ITEM_KEYS_WRITABLE, query_type + Item.any_writable_media_field_query( + query_type, value ) ) else: raise BPDError(ERROR_UNKNOWN, "no such tagtype") else: _, key = self._tagtype_lookup(tag) - queries.append(query_type(key, value)) + queries.append(Item.field_query(key, value, query_type)) return dbcore.query.AndQuery(queries) else: # No key-value pairs. return dbcore.query.TrueQuery() @@ -1429,14 +1429,14 @@ def _metadata_query(self, query_type, any_query_type, kv): def cmd_search(self, conn, *kv): """Perform a substring match for items.""" query = self._metadata_query( - dbcore.query.SubstringQuery, dbcore.query.AnyFieldQuery, kv + dbcore.query.SubstringQuery, kv, allow_any_query=True ) for item in self.lib.items(query): yield self._item_info(item) def cmd_find(self, conn, *kv): """Perform an exact match for items.""" - query = self._metadata_query(dbcore.query.MatchQuery, None, kv) + query = self._metadata_query(dbcore.query.MatchQuery, kv) for item in self.lib.items(query): yield self._item_info(item) @@ -1456,7 +1456,7 @@ def cmd_list(self, conn, show_tag, *kv): raise BPDError(ERROR_ARG, 'should be "Album" for 3 arguments') elif len(kv) % 2 != 0: raise BPDError(ERROR_ARG, "Incorrect number of filter arguments") - query = self._metadata_query(dbcore.query.MatchQuery, None, kv) + query = self._metadata_query(dbcore.query.MatchQuery, kv) clause, subvals = query.clause() statement = ( @@ -1484,7 +1484,9 @@ def cmd_count(self, conn, tag, value): _, key = self._tagtype_lookup(tag) songs = 0 playtime = 0.0 - for item in self.lib.items(dbcore.query.MatchQuery(key, value)): + for item in self.lib.items( + Item.field_query(key, value, dbcore.query.MatchQuery) + ): songs += 1 playtime += item.length yield "songs: " + str(songs) diff --git a/docs/changelog.rst b/docs/changelog.rst index 186749d461..f38a40a295 100644 --- 
a/docs/changelog.rst +++ b/docs/changelog.rst @@ -38,6 +38,9 @@ Bug fixes: request their own last.fm genre. Also log messages regarding what's been tagged are now more polished. :bug:`5582` +* Fix ambiguous column name ``sqlite3.OperationalError`` that occurred in album + queries that filtered album track titles, for example ``beet list -a keyword + title:foo``. For packagers: diff --git a/test/test_dbcore.py b/test/test_dbcore.py index ba2b84ad2a..2ff20c3a33 100644 --- a/test/test_dbcore.py +++ b/test/test_dbcore.py @@ -23,6 +23,7 @@ import pytest from beets import dbcore +from beets.library import LibModel from beets.test import _common # Fixture: concrete database and model classes. For migration tests, we @@ -44,7 +45,7 @@ def match(self): return True -class ModelFixture1(dbcore.Model): +class ModelFixture1(LibModel): _table = "test" _flex_table = "testflex" _fields = { @@ -587,7 +588,7 @@ def test_two_parts(self): q = self.qfs(["foo", "bar:baz"]) assert isinstance(q, dbcore.query.AndQuery) assert len(q.subqueries) == 2 - assert isinstance(q.subqueries[0], dbcore.query.AnyFieldQuery) + assert isinstance(q.subqueries[0], dbcore.query.OrQuery) assert isinstance(q.subqueries[1], dbcore.query.SubstringQuery) def test_parse_fixed_type_query(self): diff --git a/test/test_query.py b/test/test_query.py index 6f7fe4da72..f85e5c6370 100644 --- a/test/test_query.py +++ b/test/test_query.py @@ -56,40 +56,6 @@ def assertNotInResult(self, item, results): assert item.id not in result_ids -class AnyFieldQueryTest(ItemInDBTestCase): - def test_no_restriction(self): - q = dbcore.query.AnyFieldQuery( - "title", - beets.library.Item._fields.keys(), - dbcore.query.SubstringQuery, - ) - assert self.lib.items(q).get().title == "the title" - - def test_restriction_completeness(self): - q = dbcore.query.AnyFieldQuery( - "title", ["title"], dbcore.query.SubstringQuery - ) - assert self.lib.items(q).get().title == "the title" - - def test_restriction_soundness(self): - q = 
dbcore.query.AnyFieldQuery( - "title", ["artist"], dbcore.query.SubstringQuery - ) - assert self.lib.items(q).get() is None - - def test_eq(self): - q1 = dbcore.query.AnyFieldQuery( - "foo", ["bar"], dbcore.query.SubstringQuery - ) - q2 = dbcore.query.AnyFieldQuery( - "foo", ["bar"], dbcore.query.SubstringQuery - ) - assert q1 == q2 - - q2.query_class = None - assert q1 != q2 - - # A test case class providing a library with some dummy data and some # assertions involving that data. class DummyDataTestCase(BeetsTestCase, AssertsMixin): @@ -954,14 +920,6 @@ def test_type_and(self): self.assert_items_matched(not_results, ["foo bar", "beets 4 eva"]) self.assertNegationProperties(q) - def test_type_anyfield(self): - q = dbcore.query.AnyFieldQuery( - "foo", ["title", "artist", "album"], dbcore.query.SubstringQuery - ) - not_results = self.lib.items(dbcore.query.NotQuery(q)) - self.assert_items_matched(not_results, ["baz qux"]) - self.assertNegationProperties(q) - def test_type_boolean(self): q = dbcore.query.BooleanQuery("comp", True) not_results = self.lib.items(dbcore.query.NotQuery(q)) @@ -1135,7 +1093,14 @@ def test_get_items_filter_by_album_field(self): results = self.lib.items(q) self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"]) - def test_filter_by_common_field(self): - q = "catalognum:ABC Album1" + def test_filter_albums_by_common_field(self): + # title:Album1 ensures that the items table is joined for the query + q = "title:Album1 Album1" results = self.lib.albums(q) self.assert_albums_matched(results, ["Album1"]) + + def test_filter_items_by_common_field(self): + # artpath::A ensures that the albums table is joined for the query + q = "artpath::A Album1" + results = self.lib.items(q) + self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])