diff --git a/examples/csp_nonce/main.py b/examples/csp_nonce/main.py index 0c948adcf..feb242cbb 100644 --- a/examples/csp_nonce/main.py +++ b/examples/csp_nonce/main.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from typing import Any from flask import Flask from flask_admin import Admin @@ -19,7 +20,7 @@ content_security_policy_nonce_in=["script-src", "style-src"], ) # Get the CSP nonce generator from jinja environment globals which is added by Talisman -csp_nonce_generator: Callable = app.jinja_env.globals["csp_nonce"] # type: ignore[assignment] +csp_nonce_generator: Callable[[], Any] = app.jinja_env.globals["csp_nonce"] # type: ignore[assignment] @app.route("/") diff --git a/examples/pymongo_simple/main.py b/examples/pymongo_simple/main.py index a6e8ff078..b1cdf46c0 100644 --- a/examples/pymongo_simple/main.py +++ b/examples/pymongo_simple/main.py @@ -1,3 +1,5 @@ +from typing import Any + from bson.objectid import ObjectId from flask import Flask from flask import url_for @@ -107,7 +109,7 @@ def index(): if __name__ == "__main__": with MongoDbContainer("mongo:7.0.7") as mongo: - conn: MongoClient = MongoClient(mongo.get_connection_url()) + conn: MongoClient[Any] = MongoClient(mongo.get_connection_url()) db = conn.test admin.add_view(UserView(db.user, "User")) diff --git a/flask_admin/_compat.py b/flask_admin/_compat.py index bcb5c8993..a2460b840 100644 --- a/flask_admin/_compat.py +++ b/flask_admin/_compat.py @@ -13,27 +13,27 @@ import typing as t from types import MappingProxyType -from flask_admin._types import T_TRANSLATABLE, T_ITER_CHOICES +from flask_admin._types import T_TRANSLATABLE, T_ITER_CHOICES, T_ORM_MODEL text_type = str string_types = (str,) -def itervalues(d: dict) -> t.Iterator[t.Any]: +def itervalues(d: dict[t.Any, t.Any]) -> t.Iterator[t.Any]: return iter(d.values()) def iteritems( - d: dict | MappingProxyType[str, t.Any] | t.Mapping[str, t.Any], + d: dict[t.Any, t.Any] | MappingProxyType[str, t.Any] | t.Mapping[t.Any, t.Any], ) -> t.Iterator[tuple[t.Any, t.Any]]: return iter(d.items()) -def filter_list(f: t.Callable, l: list) -> list[t.Any]: +def filter_list(f: t.Callable[[t.Any], t.Any], l: list[t.Any]) -> list[t.Any]: return list(filter(f, l)) -def as_unicode(s: str | bytes | int) -> str: +def as_unicode(s: t.Any) -> str: if isinstance(s, bytes): return s.decode("utf-8") diff --git a/flask_admin/_types.py b/flask_admin/_types.py index 2939fce94..7dd8afdb3 100644 --- a/flask_admin/_types.py +++ b/flask_admin/_types.py @@ -5,7 +5,6 @@ import wtforms.widgets from flask import Response -from jinja2.runtime import Context from markupsafe import Markup from werkzeug.wrappers import Response as Wkzg_Response from wtforms import Field @@ -14,9 +13,9 @@ from wtforms.widgets import Input if sys.version_info >= (3, 11): - from typing import NotRequired + from typing import NotRequired # noqa else: - from typing_extensions import NotRequired + from typing_extensions import NotRequired # noqa if t.TYPE_CHECKING: from flask_admin.base import BaseView as T_VIEW # noqa @@ -56,18 +55,24 @@ from flask_sqlalchemy import Model as T_SQLALCHEMY_MODEL from peewee import Model as T_PEEWEE_MODEL from peewee import Field as T_PEEWEE_FIELD # noqa - from pymongo import MongoClient as T_MONGO_CLIENT + from pymongo import MongoClient from mongoengine import Document as T_MONGO_ENGINE_CLIENT - import sqlalchemy # noqa - from sqlalchemy import Column as T_SQLALCHEMY_COLUMN + from sqlalchemy import Column from sqlalchemy import Table as T_TABLE # noqa - from sqlalchemy.orm import InstrumentedAttribute as T_INSTRUMENTED_ATTRIBUTE # noqa - from sqlalchemy.orm import scoped_session as T_SQLALCHEMY_SESSION # noqa - from sqlalchemy.orm.query import Query + from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy_utils import Choice as T_CHOICE # noqa from sqlalchemy_utils import ChoiceType as T_CHOICE_TYPE # noqa - T_SQLALCHEMY_QUERY = Query + try: + T_INSTRUMENTED_ATTRIBUTE = InstrumentedAttribute[t.Any] + except TypeError: # Fall back to non-generic types for older SQLAlchemy + T_INSTRUMENTED_ATTRIBUTE = InstrumentedAttribute # type: ignore[misc] + + try: + T_SQLALCHEMY_COLUMN = Column[t.Any] + except TypeError: # Fall back to non-generic types for older SQLAlchemy + T_SQLALCHEMY_COLUMN = Column # type: ignore[misc] + T_MONGO_CLIENT = MongoClient[t.Any] from redis import Redis as T_REDIS # noqa from flask_admin.contrib.peewee.ajax import ( QueryAjaxModelLoader as T_PEEWEE_QUERY_AJAX_MODEL_LOADER, @@ -101,19 +106,17 @@ # optional dependencies T_ARROW = "arrow.Arrow" T_LAZY_STRING = "flask_babel.LazyString" - T_SQLALCHEMY_COLUMN = "sqlalchemy.Column" - T_SQLALCHEMY_MODEL = "flask_sqlalchemy.Model" + T_SQLALCHEMY_COLUMN = "sqlalchemy.Column[t.Any]" + T_SQLALCHEMY_MODEL = t.TypeVar("T_SQLALCHEMY_MODEL", bound=t.Any) T_PEEWEE_FIELD = "peewee.Field" - T_PEEWEE_MODEL = "peewee.Model" - T_MONGO_CLIENT = "pymongo.MongoClient" + T_PEEWEE_MODEL = t.TypeVar("T_PEEWEE_MODEL", bound=t.Any) + T_MONGO_CLIENT = "pymongo.MongoClient[t.Any]" T_MONGO_ENGINE_CLIENT = "mongoengine.Document" T_TABLE = "sqlalchemy.Table" T_CHOICE_TYPE = "sqlalchemy_utils.ChoiceType" T_CHOICE = "sqlalchemy_utils.Choice" - T_SQLALCHEMY_QUERY = "sqlalchemy.orm.query.Query" - T_INSTRUMENTED_ATTRIBUTE = "sqlalchemy.orm.InstrumentedAttribute" - T_SQLALCHEMY_SESSION = "sqlalchemy.orm.scoped_session" + T_INSTRUMENTED_ATTRIBUTE = t.TypeVar("T_INSTRUMENTED_ATTRIBUTE", bound=t.Any) T_REDIS = "redis.Redis" T_PEEWEE_QUERY_AJAX_MODEL_LOADER = ( "flask_admin.contrib.peewee.ajax.QueryAjaxModelLoader" @@ -129,13 +132,6 @@ T_COLUMN_LIST = t.Sequence[ T_ORM_COLUMN | t.Iterable[T_ORM_COLUMN] | tuple[str, tuple[T_ORM_COLUMN, ...]] ] -T_CONTRAVARIANT_MODEL_VIEW = t.TypeVar( - "T_CONTRAVARIANT_MODEL_VIEW", bound=T_MODEL_VIEW, contravariant=True -) -T_FORMATTER = t.Callable[ - [T_CONTRAVARIANT_MODEL_VIEW, Context | None, t.Any, str], str | Markup -] -T_COLUMN_FORMATTERS = dict[str, T_FORMATTER] T_TYPE_FORMATTER = t.Callable[[T_MODEL_VIEW, t.Any, str], str | Markup] T_COLUMN_TYPE_FORMATTERS = dict[type, T_TYPE_FORMATTER] T_TRANSLATABLE = t.Union[str, T_LAZY_STRING] diff --git a/flask_admin/actions.py b/flask_admin/actions.py index 55f8b944f..fe65643b8 100644 --- a/flask_admin/actions.py +++ b/flask_admin/actions.py @@ -11,7 +11,9 @@ from flask_admin.helpers import get_redirect_target -def action(name: str, text: str, confirmation: str | None = None) -> t.Callable: +def action( + name: str, text: str, confirmation: str | None = None +) -> t.Callable[..., t.Any]: """ Use this decorator to expose actions that span more than one entity (model, file, etc) @@ -25,7 +27,7 @@ def action(name: str, text: str, confirmation: str | None = None) -> t.Callable: unconditionally. """ - def wrap(f: t.Callable) -> t.Callable: + def wrap(f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: f._action = (name, text, confirmation) # type: ignore[attr-defined] return f diff --git a/flask_admin/babel.py b/flask_admin/babel.py index 4e1dabdfc..54c109828 100644 --- a/flask_admin/babel.py +++ b/flask_admin/babel.py @@ -26,7 +26,7 @@ def ngettext(self, singular: str, plural: str, n: int) -> str: else: from flask_admin import translations - class CustomDomain(Domain): + class CustomDomain(Domain): # type: ignore[misc] def __init__(self) -> None: super().__init__(translations.__path__[0], domain="admin") diff --git a/flask_admin/base.py b/flask_admin/base.py index 9d87d53c2..d6dc49abd 100644 --- a/flask_admin/base.py +++ b/flask_admin/base.py @@ -16,6 +16,7 @@ from flask_admin import babel from flask_admin import helpers as h from flask_admin._compat import as_unicode +from flask_admin._types import T_VIEW # For compatibility reasons import MenuLink from flask_admin.blueprints import _BlueprintWithHostSupport as Blueprint @@ -29,7 +30,9 @@ from flask_admin.theme import Theme -def expose(url: str = "/", methods: t.Iterable[str] | None = ("GET",)) -> t.Callable: +def expose( + url: str = "/", methods: t.Iterable[str] | None = ("GET",) +) -> t.Callable[[t.Any], t.Any]: """ Use this decorator to expose views in your view classes. @@ -48,7 +51,7 @@ def wrap(f: AdminViewMeta) -> AdminViewMeta: return wrap -def expose_plugview(url: str = "/") -> t.Callable: +def expose_plugview(url: str = "/") -> t.Callable[[t.Any], t.Any]: """ Decorator to expose Flask's pluggable view classes (``flask.views.View`` or ``flask.views.MethodView``). @@ -71,7 +74,7 @@ def wrap(v: View | MethodView) -> t.Any: # Base views -def _wrap_view(f: t.Callable) -> t.Callable: +def _wrap_view(f: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: # Avoid wrapping view method twice if hasattr(f, "_wrapped"): return f @@ -161,7 +164,7 @@ def index(self): """Extra JavaScript files to include in the page""" @property - def _template_args(self) -> dict: + def _template_args(self) -> dict[str, str]: """ Extra template arguments. @@ -418,7 +421,7 @@ def _handle_view(self, name: str, **kwargs: dict[str, t.Any]) -> t.Any: return self.inaccessible_callback(name, **kwargs) def _run_view( - self, fn: t.Callable, *args: tuple[t.Any], **kwargs: dict[str, t.Any] + self, fn: t.Callable[..., t.Any], *args: t.Any, **kwargs: t.Any ) -> t.Any: """ This method will run actual view function. @@ -547,7 +550,7 @@ def __init__( theme: Theme | None = None, category_icon_classes: dict[str, str] | None = None, host: str | None = None, - csp_nonce_generator: t.Callable | None = None, + csp_nonce_generator: t.Callable[[], t.Any] | None = None, ) -> None: """ Constructor. @@ -587,10 +590,10 @@ def __init__( self.translations_path = translations_path - self._views = [] # type: ignore[var-annotated] - self._menu = [] # type: ignore[var-annotated] + self._views: list[T_VIEW] = [] + self._menu: list[MenuView | MenuCategory | BaseMenu] = [] self._menu_categories: dict[str, MenuCategory] = dict() - self._menu_links = [] # type: ignore[var-annotated] + self._menu_links: list[MenuLink] = [] if name is None: name = "Admin" @@ -891,13 +894,13 @@ def _init_extension(self) -> None: admins.append(self) self.app.extensions["admin"] = admins # type: ignore[union-attr] - def menu(self) -> list: + def menu(self) -> list[MenuView | MenuCategory | BaseMenu]: """ Return the menu hierarchy. """ return self._menu - def menu_links(self) -> list: + def menu_links(self) -> list[MenuLink]: """ Return menu links. """ diff --git a/flask_admin/contrib/fileadmin/__init__.py b/flask_admin/contrib/fileadmin/__init__.py index 804d8ac87..77744afeb 100644 --- a/flask_admin/contrib/fileadmin/__init__.py +++ b/flask_admin/contrib/fileadmin/__init__.py @@ -405,7 +405,7 @@ def get_upload_form(self) -> type[form.BaseForm]: Override to implement customized behavior. """ - class UploadForm(self.form_base_class): # type: ignore[name-defined] + class UploadForm(self.form_base_class): # type: ignore[name-defined, misc] """ File upload form. Works with FileAdmin instance to check if it is allowed to upload file with given extension. @@ -435,7 +435,7 @@ def get_edit_form(self) -> type[form.BaseForm]: Override to implement customized behavior. """ - class EditForm(self.form_base_class): # type: ignore[name-defined] + class EditForm(self.form_base_class): # type: ignore[name-defined, misc] content = fields.TextAreaField( lazy_gettext("Content"), (validators.InputRequired(),) ) @@ -456,7 +456,7 @@ def validate_name(self: type[form.BaseForm], field: Field) -> None: if not regexp.match(field.data): raise validators.ValidationError(gettext("Invalid name")) - class NameForm(self.form_base_class): # type: ignore[name-defined] + class NameForm(self.form_base_class): # type: ignore[name-defined, misc] """ Form with a filename input field. @@ -478,7 +478,7 @@ def get_delete_form(self) -> type[form.BaseForm]: Override to implement customized behavior. """ - class DeleteForm(self.form_base_class): # type: ignore[name-defined] + class DeleteForm(self.form_base_class): # type: ignore[name-defined, misc] path = fields.HiddenField(validators=[validators.InputRequired()]) return DeleteForm @@ -490,7 +490,7 @@ def get_action_form(self) -> type[form.BaseForm]: Override to implement customized behavior. """ - class ActionForm(self.form_base_class): # type: ignore[name-defined] + class ActionForm(self.form_base_class): # type: ignore[name-defined, misc] action = fields.HiddenField() url = fields.HiddenField() # rowid is retrieved using getlist, for backward compatibility diff --git a/flask_admin/contrib/fileadmin/s3.py b/flask_admin/contrib/fileadmin/s3.py index 1cfed97a6..2bb89ad17 100644 --- a/flask_admin/contrib/fileadmin/s3.py +++ b/flask_admin/contrib/fileadmin/s3.py @@ -22,11 +22,11 @@ def _strip_leading_slash_from( """ def decorator( - func: t.Callable, + func: t.Callable[..., t.Any], ) -> t.Callable[[tuple[t.Any, ...], dict[str, t.Any]], t.Any]: @functools.wraps(func) def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: - args: list = list(args) # type: ignore[no-redef] + args: list[t.Any] = list(args) # type: ignore[no-redef] arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] if arg_name in arg_names: @@ -82,7 +82,7 @@ def __init__(self, s3_client: BaseClient, bucket_name: str) -> None: self.separator = "/" @_strip_leading_slash_from("path") - def get_files(self, path: str, directory: str) -> list: + def get_files(self, path: str, directory: str) -> list[t.Any]: def _strip_path(name: str, path: str) -> str: if name.startswith(path): return name.replace(path, "", 1) diff --git a/flask_admin/contrib/mongoengine/view.py b/flask_admin/contrib/mongoengine/view.py index 073e50e65..319d63569 100644 --- a/flask_admin/contrib/mongoengine/view.py +++ b/flask_admin/contrib/mongoengine/view.py @@ -21,7 +21,6 @@ from flask_admin.model import BaseModelView from flask_admin.model.form import create_editable_list_form -from ...model.filters import BaseFilter from .ajax import create_ajax_loader from .ajax import process_ajax_references from .filters import BaseMongoEngineFilter @@ -60,7 +59,7 @@ class ModelView(BaseModelView): MongoEngine model scaffolding. """ - column_filters: t.Collection[str | BaseFilter] | None = None + column_filters: t.Collection[str | BaseMongoEngineFilter] | None = None """ Collection of the column filters. diff --git a/flask_admin/contrib/peewee/ajax.py b/flask_admin/contrib/peewee/ajax.py index 2e971ecff..c2b83e8fe 100644 --- a/flask_admin/contrib/peewee/ajax.py +++ b/flask_admin/contrib/peewee/ajax.py @@ -20,7 +20,7 @@ def __init__(self, name: str, model: t.Any, **options: t.Any) -> None: super().__init__(name, options) self.model = model - self.fields = t.cast(t.Iterable, options.get("fields")) + self.fields = t.cast(t.Iterable[t.Any], options.get("fields")) if not self.fields: raise ValueError( @@ -48,7 +48,7 @@ def _process_fields(self) -> list[t.Any]: return remote_fields - def format(self, model: None | str | bytes) -> tuple[t.Any, str] | None: + def format(self, model: T_PEEWEE_MODEL | None) -> tuple[t.Any, str] | None: # type: ignore[override] if not model: return None @@ -84,7 +84,7 @@ def create_ajax_loader( model: type[T_PEEWEE_MODEL], name: str, field_name: str, - options: dict[str, t.Any] | list | tuple, + options: dict[str, t.Any], ) -> QueryAjaxModelLoader: prop = getattr(model, field_name, None) @@ -93,4 +93,4 @@ def create_ajax_loader( # TODO: Check for field remote_model = prop.rel_model - return QueryAjaxModelLoader(name, remote_model, **options) # type: ignore[arg-type] + return QueryAjaxModelLoader(name, remote_model, **options) diff --git a/flask_admin/contrib/peewee/form.py b/flask_admin/contrib/peewee/form.py index 782297188..f4aa5f246 100644 --- a/flask_admin/contrib/peewee/form.py +++ b/flask_admin/contrib/peewee/form.py @@ -124,7 +124,7 @@ def save_related(self, obj: t.Any) -> None: f.save_related(model) -class CustomModelConverter(ModelConverter): +class CustomModelConverter(ModelConverter): # type: ignore[misc] def __init__(self, view: t.Any, additional: t.Any = None) -> None: super().__init__(additional) self.view = view @@ -274,7 +274,7 @@ def process_ajax_refs(self, info: InlineFormAdmin) -> dict[str, t.Any]: info.model, # type: ignore[arg-type] new_name, name, - opts, + opts, # type: ignore[arg-type] ) else: loader = opts diff --git a/flask_admin/contrib/peewee/view.py b/flask_admin/contrib/peewee/view.py index cd82f6fae..e4349f53e 100644 --- a/flask_admin/contrib/peewee/view.py +++ b/flask_admin/contrib/peewee/view.py @@ -27,11 +27,11 @@ from ..._types import T_FIELD_ARGS_VALIDATORS_FILES from ..._types import T_FILTER -from ..._types import T_PEEWEE_FIELD from ..._types import T_PEEWEE_MODEL from ..._types import T_WIDGET from .ajax import create_ajax_loader from .ajax import QueryAjaxModelLoader +from .filters import BasePeeweeFilter from .form import CustomModelConverter from .form import get_form from .form import InlineModelConverter @@ -45,12 +45,15 @@ class ModelView(BaseModelView): - column_filters: t.Collection[t.Union[str, T_PEEWEE_FIELD]] | None = None # type: ignore[assignment] + column_filters: t.Collection[str | Field | BasePeeweeFilter] | None = None # type: ignore[assignment] """ Collection of the column filters. - Can contain either field names or instances of - :class:`flask_admin.contrib.peewee.filters.BasePeeweeFilter` classes. + Can contain either: + - Field names (str) or Fields (instances of :class:`peewee.Field`): allow any + filter operation available for the field’s data type. + - Instances of :class:`flask_admin.contrib.peewee.filters.BasePeeweeFilter` + classes: restrict or customize which filters are available for a specific field. Filters will be grouped by name when displayed in the drop-down. @@ -201,7 +204,7 @@ def __init__( menu_icon_type: str | None = None, menu_icon_value: str | None = None, ) -> None: - self._search_fields: list = [] + self._search_fields: list[t.Any] = [] super().__init__( model, name, @@ -373,7 +376,7 @@ def scaffold_inline_form_models(self, form_class: type[Form]) -> type[Form]: # AJAX foreignkey support def _create_ajax_loader( - self, name: str, options: dict[str, t.Any] | list | tuple + self, name: str, options: dict[str, t.Any] ) -> QueryAjaxModelLoader: return create_ajax_loader(self.model, name, name, options) # type: ignore[arg-type] @@ -434,7 +437,7 @@ def get_list( # type: ignore[override] filters: t.Sequence[T_FILTER] | None, execute: bool = True, page_size: int | None = None, - ) -> tuple[int | None, list | ModelSelect]: + ) -> tuple[int | None, list[ModelBase] | ModelSelect]: """ Return records from the database. diff --git a/flask_admin/contrib/pymongo/view.py b/flask_admin/contrib/pymongo/view.py index 4aa50d374..b468b169e 100644 --- a/flask_admin/contrib/pymongo/view.py +++ b/flask_admin/contrib/pymongo/view.py @@ -29,7 +29,7 @@ class ModelView(BaseModelView): MongoEngine model scaffolding. """ - column_filters: t.Collection[str | BaseFilter] | None = None + column_filters: t.Collection[str | BasePyMongoFilter] | None = None """ Collection of the column filters. @@ -195,7 +195,7 @@ def _get_field_value(self, model, name): def _search(self, query, search_term: str): values = search_term.split(" ") - queries: list[dict] = [] + queries: list[dict[str, t.Any]] = [] # Construct inner querie for value in values: @@ -265,7 +265,7 @@ def get_list( # type: ignore[override] # Filters if self._filters: - data: list = [] + data: list[str] | str = [] for flt, _flt_name, value in filters: # type: ignore[union-attr] f = self._filters[flt] @@ -273,7 +273,7 @@ def get_list( # type: ignore[override] if data: if len(data) == 1: - query = data[0] + query = data[0] # type: ignore[assignment] else: query["$and"] = data diff --git a/flask_admin/contrib/sqla/_types.py b/flask_admin/contrib/sqla/_types.py new file mode 100644 index 000000000..4fd39afaa --- /dev/null +++ b/flask_admin/contrib/sqla/_types.py @@ -0,0 +1,11 @@ +import typing as t + +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm.query import Query + +if t.TYPE_CHECKING: # sqlalchemy 2.x types are subscriptable + T_SQLALCHEMY_QUERY = Query[t.Any] + T_SCOPED_SESSION = scoped_session[t.Any] +else: # sqlalchemy 1.x types are not subscriptable + T_SQLALCHEMY_QUERY = Query + T_SCOPED_SESSION = scoped_session diff --git a/flask_admin/contrib/sqla/ajax.py b/flask_admin/contrib/sqla/ajax.py index d2edb06b7..e6c4e3795 100644 --- a/flask_admin/contrib/sqla/ajax.py +++ b/flask_admin/contrib/sqla/ajax.py @@ -1,5 +1,6 @@ import typing as t +from flask_sqlalchemy.model import Model from sqlalchemy import and_ from sqlalchemy import cast from sqlalchemy import or_ @@ -11,9 +12,8 @@ from flask_admin.model.ajax import AjaxModelLoader from flask_admin.model.ajax import DEFAULT_PAGE_SIZE -from ..._types import T_SQLALCHEMY_MODEL -from ..._types import T_SQLALCHEMY_QUERY -from ..._types import T_SQLALCHEMY_SESSION +from ._types import T_SCOPED_SESSION +from ._types import T_SQLALCHEMY_QUERY from .tools import get_primary_key from .tools import has_multiple_pks from .tools import is_association_proxy @@ -24,8 +24,8 @@ class QueryAjaxModelLoader(AjaxModelLoader): def __init__( self, name: str, - session: T_SQLALCHEMY_SESSION, - model: type[T_SQLALCHEMY_MODEL], + session: T_SCOPED_SESSION, + model: type[Model], **options: t.Any, ) -> None: """ @@ -59,7 +59,7 @@ def __init__( self.pk: str = t.cast(str, get_primary_key(model)) - def _process_fields(self) -> list: + def _process_fields(self) -> list[t.Any]: remote_fields = [] for field in self.fields: # type: ignore[union-attr] @@ -76,7 +76,7 @@ def _process_fields(self) -> list: return remote_fields - def format(self, model: None | str | bytes) -> tuple[t.Any, str] | None: + def format(self, model: Model | None) -> tuple[t.Any, str] | None: # type: ignore[override] if not model: return None @@ -106,7 +106,7 @@ def get_list( if self.filters: filters = [ - text(f"{self.model.__tablename__.lower()}.{value}") + text(f"{self.model.__tablename__.lower()}.{value}") # type: ignore[attr-defined] for value in self.filters ] query = query.filter(and_(*filters)) @@ -119,7 +119,7 @@ def get_list( def create_ajax_loader( model: t.Any, - session: T_SQLALCHEMY_SESSION, + session: T_SCOPED_SESSION, name: str, field_name: str, options: dict[str, t.Any], diff --git a/flask_admin/contrib/sqla/fields.py b/flask_admin/contrib/sqla/fields.py index f92a9f2ca..6e635faee 100644 --- a/flask_admin/contrib/sqla/fields.py +++ b/flask_admin/contrib/sqla/fields.py @@ -28,9 +28,9 @@ from ..._types import T_ITER_CHOICES from ..._types import T_ORM_MODEL from ..._types import T_SQLALCHEMY_MODEL -from ..._types import T_SQLALCHEMY_SESSION from ..._types import T_VALIDATOR from ...model.form import InlineBaseFormAdmin +from ._types import T_SCOPED_SESSION from .tools import get_primary_key @@ -112,7 +112,7 @@ def _get_data(self) -> t.Any: def _set_data(self, data: t.Any) -> None: self._data = data - self._formdata: set | str | None = None + self._formdata: set[str] | str | None = None data = property(_get_data, _set_data) @@ -192,7 +192,7 @@ def _get_data(self) -> t.Any: def _set_data(self, data: list[t.Any]) -> None: self._data = data - self._formdata: set | None = None + self._formdata: set[str] | None = None data = property(_get_data, _set_data) @@ -202,7 +202,7 @@ def iter_choices(self) -> t.Iterator[T_ITER_CHOICES]: # type: ignore[override] pk, self.get_label(obj), obj in self.data ) - def process_formdata(self, valuelist: t.Iterable) -> None: + def process_formdata(self, valuelist: t.Iterable[str]) -> None: self._formdata = set(valuelist) def pre_validate(self, form: form.BaseForm) -> None: @@ -261,7 +261,7 @@ class InlineHstoreList(InlineFieldList): def process( self, - formdata: dict | None, # type: ignore[override] + formdata: dict[t.Any, t.Any] | None, # type: ignore[override] data: UnsetValue | list[KeyValue] = unset_value, extra_filters: t.Any = None, ) -> None: @@ -300,7 +300,7 @@ class InlineModelFormList(InlineFieldList): def __init__( self, form: type[form.BaseForm], - session: T_SQLALCHEMY_SESSION, + session: T_SCOPED_SESSION, model: type[T_SQLALCHEMY_MODEL], prop: str, inline_view: t.Any, @@ -378,7 +378,7 @@ class InlineModelOneToOneField(InlineModelFormField): def __init__( self, form: type[form.BaseForm], - session: T_SQLALCHEMY_SESSION, + session: T_SCOPED_SESSION, model: type[T_ORM_MODEL], prop: str, inline_view: InlineBaseFormAdmin, diff --git a/flask_admin/contrib/sqla/filters.py b/flask_admin/contrib/sqla/filters.py index c4f9add89..e9e35abed 100644 --- a/flask_admin/contrib/sqla/filters.py +++ b/flask_admin/contrib/sqla/filters.py @@ -3,14 +3,14 @@ from sqlalchemy.sql import not_ from sqlalchemy.sql import or_ -from sqlalchemy.sql.schema import Column from flask_admin._types import T_OPTIONS -from flask_admin._types import T_SQLALCHEMY_QUERY +from flask_admin._types import T_SQLALCHEMY_COLUMN from flask_admin._types import T_TRANSLATABLE from flask_admin._types import T_WIDGET_TYPE from flask_admin.babel import lazy_gettext from flask_admin.contrib.sqla import tools +from flask_admin.contrib.sqla._types import T_SQLALCHEMY_QUERY from flask_admin.model import filters @@ -21,7 +21,7 @@ class BaseSQLAFilter(filters.BaseFilter): def __init__( self, - column: Column, + column: T_SQLALCHEMY_COLUMN, name: str, options: T_OPTIONS = None, data_type: T_WIDGET_TYPE = None, @@ -42,7 +42,7 @@ def __init__( self.column = column - def get_column(self, alias: t.Any) -> Column: + def get_column(self, alias: t.Any) -> T_SQLALCHEMY_COLUMN: return self.column if alias is None else getattr(alias, self.column.key) def apply( @@ -130,7 +130,7 @@ def operation(self) -> T_TRANSLATABLE: class FilterInList(BaseSQLAFilter): def __init__( self, - column: Column, + column: T_SQLALCHEMY_COLUMN, name: str, options: T_OPTIONS = None, data_type: T_WIDGET_TYPE = None, @@ -247,7 +247,7 @@ class DateSmallerFilter(FilterSmaller, filters.BaseDateFilter): class DateBetweenFilter(BaseSQLAFilter, filters.BaseDateBetweenFilter): def __init__( self, - column: Column, + column: T_SQLALCHEMY_COLUMN, name: str, options: T_OPTIONS = None, data_type: T_WIDGET_TYPE = None, @@ -291,7 +291,7 @@ class DateTimeSmallerFilter(FilterSmaller, filters.BaseDateTimeFilter): class DateTimeBetweenFilter(BaseSQLAFilter, filters.BaseDateTimeBetweenFilter): def __init__( self, - column: Column, + column: T_SQLALCHEMY_COLUMN, name: str, options: T_OPTIONS = None, data_type: T_WIDGET_TYPE = None, @@ -335,7 +335,7 @@ class TimeSmallerFilter(FilterSmaller, filters.BaseTimeFilter): class TimeBetweenFilter(BaseSQLAFilter, filters.BaseTimeBetweenFilter): def __init__( self, - column: Column, + column: T_SQLALCHEMY_COLUMN, name: str, options: T_OPTIONS = None, data_type: T_WIDGET_TYPE = None, @@ -362,7 +362,11 @@ def operation(self) -> T_TRANSLATABLE: class EnumEqualFilter(FilterEqual): def __init__( - self, column: Column, name: str, options: T_OPTIONS = None, **kwargs: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + name: str, + options: T_OPTIONS = None, + **kwargs: t.Any, ) -> None: self.enum_class = column.type.enum_class # type: ignore[attr-defined] super().__init__(column, name, options, **kwargs) @@ -375,7 +379,11 @@ def clean(self, value: t.Any) -> t.Any: class EnumFilterNotEqual(FilterNotEqual): def __init__( - self, column: Column, name: str, options: T_OPTIONS = None, **kwargs: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + name: str, + options: T_OPTIONS = None, + **kwargs: t.Any, ) -> None: self.enum_class = column.type.enum_class # type: ignore[attr-defined] super().__init__(column, name, options, **kwargs) @@ -388,7 +396,11 @@ def clean(self, value: t.Any) -> t.Any: class EnumFilterEmpty(FilterEmpty): def __init__( - self, column: Column, name: str, options: T_OPTIONS = None, **kwargs: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + name: str, + options: T_OPTIONS = None, + **kwargs: t.Any, ) -> None: self.enum_class = column.type.enum_class # type: ignore[attr-defined] super().__init__(column, name, options, **kwargs) @@ -396,7 +408,11 @@ def __init__( class EnumFilterInList(FilterInList): def __init__( - self, column: Column, name: str, options: T_OPTIONS = None, **kwargs: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + name: str, + options: T_OPTIONS = None, + **kwargs: t.Any, ) -> None: self.enum_class = column.type.enum_class # type: ignore[attr-defined] super().__init__(column, name, options, **kwargs) @@ -413,7 +429,11 @@ def clean(self, value: t.Any) -> t.Any: class EnumFilterNotInList(FilterNotInList): def __init__( - self, column: Column, name: str, options: T_OPTIONS = None, **kwargs: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + name: str, + options: T_OPTIONS = None, + **kwargs: t.Any, ) -> None: self.enum_class = column.type.enum_class # type: ignore[attr-defined] super().__init__(column, name, options, **kwargs) @@ -430,7 +450,11 @@ def clean(self, value: t.Any) -> t.Any: class ChoiceTypeEqualFilter(FilterEqual): def __init__( - self, column: Column, name: str, options: T_OPTIONS = None, **kwargs: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + name: str, + options: T_OPTIONS = None, + **kwargs: t.Any, ) -> None: super().__init__(column, name, options, **kwargs) @@ -458,7 +482,11 @@ def apply( class ChoiceTypeNotEqualFilter(FilterNotEqual): def __init__( - self, column: Column, name: str, options: T_OPTIONS = None, **kwargs: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + name: str, + options: T_OPTIONS = None, + **kwargs: t.Any, ) -> None: super().__init__(column, name, options, **kwargs) @@ -487,7 +515,11 @@ def apply( class ChoiceTypeLikeFilter(FilterLike): def __init__( - self, column: Column, name: str, options: T_OPTIONS = None, **kwargs: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + name: str, + options: T_OPTIONS = None, + **kwargs: t.Any, ) -> None: super().__init__(column, name, options, **kwargs) @@ -514,7 +546,11 @@ def apply( class ChoiceTypeNotLikeFilter(FilterNotLike): def __init__( - self, column: Column, name: str, options: T_OPTIONS = None, **kwargs: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + name: str, + options: T_OPTIONS = None, + **kwargs: t.Any, ) -> None: super().__init__(column, name, options, **kwargs) @@ -644,7 +680,7 @@ class FilterConverter(filters.BaseFilterConverter): arrow_type_filters = (DateTimeGreaterFilter, DateTimeSmallerFilter, FilterEmpty) def convert( - self, type_name: str, column: Column, name: str, **kwargs: t.Any + self, type_name: str, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter] | None: filter_name = type_name.lower() @@ -672,19 +708,19 @@ def convert( "IPAddressType", ) def conv_string( - self, column: Column, name: str, **kwargs: t.Any + self, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter]: return [f(column, name, **kwargs) for f in self.strings] @filters.convert("UUIDType", "ColorType", "TimezoneType", "CurrencyType") def conv_string_keys( - self, column: Column, name: str, **kwargs: t.Any + self, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter]: return [f(column, name, **kwargs) for f in self.string_key_filters] @filters.convert("boolean", "tinyint") def conv_bool( - self, column: Column, name: str, **kwargs: t.Any + self, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter]: return [f(column, name, **kwargs) for f in self.bool_filters] @@ -698,7 +734,7 @@ def conv_bool( "mediumint", ) def conv_int( - self, column: Column, name: str, **kwargs: t.Any + self, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter]: return [f(column, name, **kwargs) for f in self.int_filters] @@ -706,43 +742,47 @@ def conv_int( "float", "real", "decimal", "numeric", "double_precision", "double" ) def conv_float( - self, column: Column, name: str, **kwargs: t.Any + self, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter]: return [f(column, name, **kwargs) for f in self.float_filters] @filters.convert("date") def conv_date( - self, column: Column, name: str, **kwargs: t.Any + self, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter]: return [f(column, name, **kwargs) for f in self.date_filters] @filters.convert("datetime", "datetime2", "timestamp", "smalldatetime") def conv_datetime( - self, column: Column, name: str, **kwargs: t.Any + self, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter]: return [f(column, name, **kwargs) for f in self.datetime_filters] @filters.convert("time") def conv_time( - self, column: Column, name: str, **kwargs: t.Any + self, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter]: return [f(column, name, **kwargs) for f in self.time_filters] @filters.convert("ChoiceType") def conv_sqla_utils_choice( - self, column: Column, name: str, **kwargs: t.Any + self, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter]: return [f(column, name, **kwargs) for f in self.choice_type_filters] @filters.convert("ArrowType") def conv_sqla_utils_arrow( - self, column: Column, name: str, **kwargs: t.Any + self, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter]: return [f(column, name, **kwargs) for f in self.arrow_type_filters] @filters.convert("enum") def conv_enum( - self, column: Column, name: str, options: T_OPTIONS = None, **kwargs: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + name: str, + options: T_OPTIONS = None, + **kwargs: t.Any, ) -> list[BaseSQLAFilter]: if not options: options = [(v, v) for v in column.type.enums] # type: ignore[attr-defined] @@ -751,6 +791,6 @@ def conv_enum( @filters.convert("uuid") def conv_uuid( - self, column: Column, name: str, **kwargs: t.Any + self, column: T_SQLALCHEMY_COLUMN, name: str, **kwargs: t.Any ) -> list[BaseSQLAFilter]: return [f(column, name, **kwargs) for f in self.uuid_filters] diff --git a/flask_admin/contrib/sqla/form.py b/flask_admin/contrib/sqla/form.py index daed509cf..15dc70344 100644 --- a/flask_admin/contrib/sqla/form.py +++ b/flask_admin/contrib/sqla/form.py @@ -1,12 +1,14 @@ +from __future__ import annotations + import typing as t import warnings +from collections.abc import Callable from enum import Enum from enum import EnumMeta from sqlalchemy import Boolean from sqlalchemy import Column from sqlalchemy.orm import ColumnProperty -from sqlalchemy.orm import InstrumentedAttribute from wtforms import Field from wtforms import fields from wtforms import Form @@ -34,12 +36,14 @@ from ..._types import T_FIELD_ARGS_VALIDATORS from ..._types import T_FIELD_ARGS_VALIDATORS_ALLOW_BLANK from ..._types import T_FIELD_ARGS_VALIDATORS_FILES +from ..._types import T_INSTRUMENTED_ATTRIBUTE from ..._types import T_MODEL_VIEW from ..._types import T_ORM_MODEL +from ..._types import T_SQLALCHEMY_COLUMN from ..._types import T_SQLALCHEMY_INLINE_MODELS from ..._types import T_SQLALCHEMY_MODEL -from ..._types import T_SQLALCHEMY_SESSION from ...form import Select2Field +from ._types import T_SCOPED_SESSION from .ajax import create_ajax_loader from .fields import HstoreForm from .fields import InlineHstoreList @@ -63,7 +67,7 @@ class AdminModelConverter(ModelConverterBase): SQLAlchemy model to form converter """ - def __init__(self, session: T_SQLALCHEMY_SESSION, view: T_MODEL_VIEW) -> None: + def __init__(self, session: T_SCOPED_SESSION, view: T_MODEL_VIEW) -> None: super().__init__() self.session = session @@ -103,7 +107,7 @@ def _get_description( return column_descriptions.get(name) return None - def _get_field_override(self, name: str) -> t.Callable | None: + def _get_field_override(self, name: str) -> Callable[..., t.Any] | None: form_overrides = getattr(self.view, "form_overrides", None) if form_overrides: @@ -113,7 +117,7 @@ def _get_field_override(self, name: str) -> t.Callable | None: def _model_select_field( self, - prop: ColumnProperty | InstrumentedAttribute, + prop: ColumnProperty[t.Any] | T_INSTRUMENTED_ATTRIBUTE, multiple: str | bool, remote_model: type[T_SQLALCHEMY_MODEL], **kwargs: t.Any, @@ -201,7 +205,7 @@ def convert( model: type[T_SQLALCHEMY_MODEL], mapper: t.Any, name: str, - prop: FieldPlaceholder | ColumnProperty | InstrumentedAttribute, + prop: FieldPlaceholder | ColumnProperty[t.Any] | T_INSTRUMENTED_ATTRIBUTE, field_args: T_FIELD_ARGS_VALIDATORS, hidden_pk: bool, ) -> ( @@ -262,13 +266,13 @@ def convert( column = columns[0] else: # Grab column - column = prop.columns[0] + column = prop.columns[0] # type: ignore[assignment] form_columns = getattr(self.view, "form_columns", None) or () # Do not display foreign keys - use relations, except when explicitly # instructed - if column.foreign_keys and prop.key not in form_columns: + if column.foreign_keys and prop.key not in form_columns: # type: ignore[union-attr] return None # Only display "real" columns @@ -364,7 +368,7 @@ def convert( @classmethod def _nullable_common( cls, - column: Column, + column: T_SQLALCHEMY_COLUMN, field_args: T_FIELD_ARGS_FILTERS | T_FIELD_ARGS_VALIDATORS, ) -> None: if column.nullable: @@ -375,7 +379,7 @@ def _nullable_common( @classmethod def _string_common( cls, - column: Column, + column: T_SQLALCHEMY_COLUMN, field_args: T_FIELD_ARGS_VALIDATORS, **extra: t.Any, ) -> None: @@ -390,7 +394,7 @@ def _string_common( @converts("String") # includes VARCHAR, CHAR, and Unicode def conv_String( self, - column: Column, + column: T_SQLALCHEMY_COLUMN, field_args: T_FIELD_ARGS_VALIDATORS, **extra: t.Any, ) -> fields.StringField: @@ -400,7 +404,7 @@ def conv_String( @converts("sqlalchemy.sql.sqltypes.Enum") def convert_enum( self, - column: Column, + column: T_SQLALCHEMY_COLUMN, field_args: T_FIELD_ARGS_FILTERS, **extra: t.Any, ) -> form.Select2Field: @@ -420,7 +424,10 @@ def convert_enum( @converts("sqlalchemy_utils.types.choice.ChoiceType") def convert_choice_type( - self, column: Column, field_args: T_FIELD_ARGS_FILTERS, **extra: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + field_args: T_FIELD_ARGS_FILTERS, + **extra: t.Any, ) -> form.Select2Field: available_choices: list[Enum] | list[tuple[int, str]] = [] # choices can either be specified as an enum, or as a list of tuples @@ -491,7 +498,7 @@ def convert_arrow_time( def convert_email( self, field_args: T_FIELD_ARGS_VALIDATORS, - column: Column | None = None, + column: T_SQLALCHEMY_COLUMN | None = None, **extra: t.Any, ) -> fields.StringField: self._nullable_common(column, field_args) # type: ignore[arg-type] @@ -537,7 +544,10 @@ def convert_currency( @converts("sqlalchemy_utils.types.timezone.TimezoneType") def convert_timezone( - self, column: Column, field_args: T_FIELD_ARGS_VALIDATORS, **extra: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + field_args: T_FIELD_ARGS_VALIDATORS, + **extra: t.Any, ) -> fields.StringField: field_args["validators"].append( TimeZoneValidator(coerce_function=column.type._coerce) # type: ignore[attr-defined] @@ -546,7 +556,10 @@ def convert_timezone( @converts("Integer") # includes BigInteger and SmallInteger def handle_integer_types( - self, column: Column, field_args: T_FIELD_ARGS_VALIDATORS, **extra: t.Any + self, + column: T_SQLALCHEMY_COLUMN, + field_args: T_FIELD_ARGS_VALIDATORS, + **extra: t.Any, ) -> fields.IntegerField: unsigned = getattr(column.type, "unsigned", False) if unsigned: @@ -649,7 +662,7 @@ def choice_coerce(value: t.Any) -> t.Any: return choice_coerce -def _resolve_prop(prop: ColumnProperty) -> ColumnProperty: +def _resolve_prop(prop: ColumnProperty[t.Any]) -> ColumnProperty[t.Any]: """ Resolve proxied property @@ -668,12 +681,12 @@ def get_form( model: type[T_SQLALCHEMY_MODEL], converter: AdminModelConverter, base_class: type[form.BaseForm] = form.BaseForm, - only: t.Collection[str | InstrumentedAttribute] | None = None, - exclude: t.Collection[str | InstrumentedAttribute] | None = None, + only: t.Collection[str | T_INSTRUMENTED_ATTRIBUTE] | None = None, + exclude: t.Collection[str | T_INSTRUMENTED_ATTRIBUTE] | None = None, field_args: dict[str, T_FIELD_ARGS_VALIDATORS_FILES] | None = None, hidden_pk: bool = False, ignore_hidden: bool = True, - extra_fields: dict[str | InstrumentedAttribute, Field] | None = None, + extra_fields: dict[str | T_INSTRUMENTED_ATTRIBUTE, Field] | None = None, ) -> type: """ Generate form from the model. @@ -707,7 +720,7 @@ def get_form( if only: def find( - name: str | InstrumentedAttribute, + name: str | T_INSTRUMENTED_ATTRIBUTE, ) -> ( tuple[str, FieldPlaceholder] | tuple[str, t.Any | None] @@ -720,7 +733,7 @@ def find( column, path = get_field_with_path( model, name, return_remote_proxy_attr=False ) - column = t.cast(InstrumentedAttribute, column) + column = t.cast(T_INSTRUMENTED_ATTRIBUTE, column) if path and not (is_relationship(column) or is_association_proxy(column)): raise Exception( "form column is located in another table and " @@ -784,9 +797,9 @@ class InlineModelConverter(InlineModelConverterBase): def __init__( self, - session: T_SQLALCHEMY_SESSION, + session: T_SCOPED_SESSION, view: T_MODEL_VIEW, - model_converter: t.Callable[[T_SQLALCHEMY_SESSION, t.Any], t.Any], + model_converter: t.Callable[[T_SCOPED_SESSION, t.Any], t.Any], ) -> None: """ Constructor. @@ -833,7 +846,7 @@ def get_info( return info - def process_ajax_refs(self, info: InlineFormAdmin) -> dict: + def process_ajax_refs(self, info: InlineFormAdmin) -> dict[t.Any, t.Any]: refs = getattr(info, "form_ajax_refs", None) result = {} @@ -951,7 +964,7 @@ def contribute( """ info = t.cast(T_MODEL_VIEW, self.get_info(inline_model)) # type: ignore[arg-type] - forward_reverse_props_keys: dict = self._calculate_mapping_key_pair( + forward_reverse_props_keys: dict[str, str] = self._calculate_mapping_key_pair( model, info, # type: ignore[arg-type] ) @@ -993,7 +1006,9 @@ def contribute( kwargs["label"] = label if self.view.form_args: - field_args = t.cast(dict, self.view.form_args.get(forward_prop_key, {})) + field_args = t.cast( + dict[t.Any, t.Any], self.view.form_args.get(forward_prop_key, {}) + ) kwargs.update(**field_args) # Contribute field @@ -1109,7 +1124,7 @@ def contribute( # Post-process form child_form = info.postprocess_form(child_form) # type: ignore[attr-defined] - kwargs: dict = dict() + kwargs: dict[t.Any, t.Any] = dict() # Contribute field for key in inline_relationships.keys(): diff --git a/flask_admin/contrib/sqla/tools.py b/flask_admin/contrib/sqla/tools.py index c2f5ad190..ea61cbfa1 100644 --- a/flask_admin/contrib/sqla/tools.py +++ b/flask_admin/contrib/sqla/tools.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import types import typing as t @@ -9,10 +11,12 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.clsregistry import _class_resolver from sqlalchemy.orm.properties import ColumnProperty -from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Table +from flask_admin._types import T_COLUMN +from flask_admin._types import T_INSTRUMENTED_ATTRIBUTE from flask_admin._types import T_ORM_MODEL +from flask_admin._types import T_SQLALCHEMY_COLUMN from flask_admin._types import T_SQLALCHEMY_MODEL try: @@ -46,7 +50,9 @@ def parse_like_term(term: str) -> str: return stmt -def filter_foreign_columns(base_table: Table, columns: list) -> list: +def filter_foreign_columns( + base_table: Table, columns: list[T_COLUMN] +) -> list[T_COLUMN]: """ Return list of columns that belong to passed table. @@ -119,7 +125,7 @@ def tuple_operator_in( def get_query_for_ids( - modelquery: t.Any, model: type[T_SQLALCHEMY_MODEL], ids: tuple + modelquery: t.Any, model: type[T_SQLALCHEMY_MODEL], ids: tuple[str, ...] ) -> t.Any: """ Return a query object filtered by primary key values passed in `ids` argument. @@ -157,8 +163,8 @@ def get_query_for_ids( def get_columns_for_field( - field: InstrumentedAttribute | ColumnProperty, -) -> list[Column]: + field: T_INSTRUMENTED_ATTRIBUTE | ColumnProperty[t.Any], +) -> list[T_SQLALCHEMY_COLUMN]: if ( not field or not hasattr(field, "property") @@ -179,9 +185,9 @@ def need_join(model: type[T_SQLALCHEMY_MODEL], table: Table) -> bool: def get_field_with_path( model: type[T_SQLALCHEMY_MODEL], - name: str | InstrumentedAttribute | ColumnProperty, + name: str | T_INSTRUMENTED_ATTRIBUTE | ColumnProperty[t.Any], return_remote_proxy_attr: bool = True, -) -> tuple[t.Any | None, list]: +) -> tuple[t.Any, list[t.Any]]: """ Resolve property by name and figure out its join path. @@ -236,7 +242,7 @@ def get_field_with_path( # copied from sqlalchemy-utils def get_hybrid_properties( model: type[T_SQLALCHEMY_MODEL], -) -> dict[str, hybrid_property]: +) -> dict[str, hybrid_property[t.Any]]: return dict( (key, prop) for key, prop in inspect(model).all_orm_descriptors.items() @@ -265,11 +271,13 @@ def is_hybrid_property(model: type[T_SQLALCHEMY_MODEL], attr_name: str) -> bool: return attr_name.name in get_hybrid_properties(model) -def is_relationship(attr: InstrumentedAttribute) -> bool: +def is_relationship(attr: T_INSTRUMENTED_ATTRIBUTE) -> bool: return hasattr(attr, "property") and hasattr(attr.property, "direction") -def is_association_proxy(attr: ColumnProperty | InstrumentedAttribute) -> bool: +def is_association_proxy( + attr: ColumnProperty[t.Any] | T_INSTRUMENTED_ATTRIBUTE, +) -> bool: if hasattr(attr, "parent"): attr = attr.parent # type: ignore[assignment] return hasattr(attr, "extension_type") and attr.extension_type == ASSOCIATION_PROXY diff --git a/flask_admin/contrib/sqla/validators.py b/flask_admin/contrib/sqla/validators.py index 89dab53aa..fbb5eebdb 100644 --- a/flask_admin/contrib/sqla/validators.py +++ b/flask_admin/contrib/sqla/validators.py @@ -10,9 +10,9 @@ from flask_admin._compat import filter_list from flask_admin._types import T_COLUMN from flask_admin._types import T_SQLALCHEMY_MODEL -from flask_admin._types import T_SQLALCHEMY_SESSION from flask_admin._types import T_TRANSLATABLE from flask_admin.babel import lazy_gettext +from flask_admin.contrib.sqla._types import T_SCOPED_SESSION class Unique: @@ -32,7 +32,7 @@ class Unique: def __init__( self, - db_session: T_SQLALCHEMY_SESSION, + db_session: T_SCOPED_SESSION, model: type[T_SQLALCHEMY_MODEL], column: T_COLUMN, message: T_TRANSLATABLE | None = None, diff --git a/flask_admin/contrib/sqla/view.py b/flask_admin/contrib/sqla/view.py index d5e90f62b..5d488560f 100644 --- a/flask_admin/contrib/sqla/view.py +++ b/flask_admin/contrib/sqla/view.py @@ -7,7 +7,6 @@ from flask import current_app from flask import flash from sqlalchemy import Boolean -from sqlalchemy import Column from sqlalchemy import func from sqlalchemy import or_ from sqlalchemy import Table @@ -18,6 +17,7 @@ from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.base import instance_state from sqlalchemy.orm.base import manager_of_class +from sqlalchemy.sql.elements import BinaryExpression from sqlalchemy.sql.expression import cast as sql_cast from sqlalchemy.sql.expression import desc from wtforms import Form @@ -40,12 +40,13 @@ from ..._types import T_COLUMN_LIST from ..._types import T_FIELD_ARGS_VALIDATORS_FILES from ..._types import T_FILTER +from ..._types import T_INSTRUMENTED_ATTRIBUTE +from ..._types import T_SQLALCHEMY_COLUMN from ..._types import T_SQLALCHEMY_INLINE_MODELS from ..._types import T_SQLALCHEMY_MODEL -from ..._types import T_SQLALCHEMY_QUERY -from ..._types import T_SQLALCHEMY_SESSION from ..._types import T_WIDGET -from ...model.filters import BaseFilter +from ._types import T_SCOPED_SESSION +from ._types import T_SQLALCHEMY_QUERY from .ajax import create_ajax_loader from .ajax import QueryAjaxModelLoader from .filters import BaseSQLAFilter @@ -145,7 +146,7 @@ class MyModelView(ModelView): used. """ - column_filters: t.Collection[str | BaseFilter] | None = None + column_filters: t.Collection[str | BaseSQLAFilter] | None = None """ Collection of the column filters. @@ -339,7 +340,7 @@ class MyModelView(BaseModelView): def __init__( self, model: type[T_SQLALCHEMY_MODEL], - session: T_SQLALCHEMY_SESSION, + session: T_SCOPED_SESSION, name: str | None = None, category: str | None = None, endpoint: str | None = None, @@ -379,11 +380,11 @@ def __init__( """ self.session = session - self._search_fields: list[tuple[Column, t.Any]] | None = None + self._search_fields: list[tuple[T_SQLALCHEMY_COLUMN, t.Any]] | None = None - self._filter_joins: dict = dict() + self._filter_joins: dict[tuple[bool, t.Any], t.Any] = dict() - self._sortable_joins: dict = dict() + self._sortable_joins: dict[T_COLUMN, list[T_INSTRUMENTED_ATTRIBUTE]] = dict() if self.form_choices is None: self.form_choices = {} @@ -408,7 +409,7 @@ def __init__( raise Exception(f"Model {self.model.__name__} does not have primary key.") # Configuration - self._auto_joins: t.Iterable + self._auto_joins: t.Iterable[t.Any] if not self.column_select_related_list: self._auto_joins = self.scaffold_auto_joins() else: @@ -417,7 +418,7 @@ def __init__( # Internal API def _get_model_iterator( self, model: type[T_SQLALCHEMY_MODEL] | None = None - ) -> t.Iterable: + ) -> t.Iterable[t.Any]: """ Return property iterator for the model """ @@ -429,10 +430,10 @@ def _get_model_iterator( def _apply_path_joins( self, query: T_SQLALCHEMY_QUERY, - joins: dict, - path: t.Iterable | None, + joins: dict[tuple[bool, t.Any], t.Any], + path: t.Iterable[t.Any] | None, inner_join: bool = True, - ) -> tuple[T_SQLALCHEMY_QUERY, dict, t.Any | None]: + ) -> tuple[T_SQLALCHEMY_QUERY, dict[tuple[bool, t.Any], t.Any], t.Any | None]: """ Apply join path to the query. @@ -492,7 +493,7 @@ def get_pk_value( else: return tools.escape(getattr(model, self._primary_key)) - def scaffold_list_columns(self) -> list: + def scaffold_list_columns(self) -> list[t.Any]: """ Return a list of columns from the model. """ @@ -525,10 +526,10 @@ def scaffold_list_columns(self) -> list: else: column = p.columns[0] - if column.foreign_keys: + if column.foreign_keys: # type: ignore[union-attr] continue - if not self.column_display_pk and column.primary_key: + if not self.column_display_pk and column.primary_key: # type: ignore[union-attr] continue columns.append(p.key) @@ -575,31 +576,32 @@ def get_sortable_columns(self) -> dict[T_COLUMN, T_COLUMN]: if self.column_sortable_list is None: return self.scaffold_sortable_columns() else: - result = dict() + result: dict[T_COLUMN, T_COLUMN] = dict() self.model = t.cast(type[T_SQLALCHEMY_MODEL], self.model) for c in self.column_sortable_list: if isinstance(c, tuple): if isinstance(c[1], tuple): - column, path = [], [] + column: list[T_COLUMN] = [] + path: list[T_COLUMN] = [] for item in c[1]: column_item, path_item = tools.get_field_with_path( self.model, item ) column.append(column_item) - path.append(path_item) + path.append(path_item) # type: ignore[arg-type] column_name = c[0] else: - column, path = tools.get_field_with_path(self.model, c[1]) # type: ignore[assignment] + column, path = tools.get_field_with_path(self.model, c[1]) column_name = c[0] else: - column, path = tools.get_field_with_path( # type: ignore[assignment] + column, path = tools.get_field_with_path( self.model, c, # type: ignore[arg-type] ) column_name = text_type(c) if path and (hasattr(path[0], "property") or isinstance(path[0], list)): - self._sortable_joins[column_name] = path + self._sortable_joins[column_name] = path # type: ignore[assignment] elif path: raise Exception( "For sorting columns in a related table, " @@ -613,9 +615,8 @@ def get_sortable_columns(self) -> dict[T_COLUMN, T_COLUMN]: column_name = column.key # type: ignore[attr-defined] # column_name must match column_name used in `get_list_columns` - result[column_name] = column - - return result # type: ignore[return-value] + result[column_name] = column # type: ignore[assignment] + return result def get_column_names( self, @@ -651,7 +652,7 @@ def get_column_names( else: # column is in same table, use only model attribute name if getattr(column, "key", None) is not None: - column_name = column.key # type: ignore[union-attr] + column_name = column.key else: column_name = text_type(c) except AttributeError: @@ -933,7 +934,7 @@ def scaffold_inline_form_models(self, form_class: type[Form]) -> type[Form]: form_class = custom_converter.contribute(self.model, form_class, m) return form_class - def scaffold_auto_joins(self) -> list: + def scaffold_auto_joins(self) -> list[t.Any]: """ Return a list of joined tables by going through the displayed columns. @@ -1008,11 +1009,11 @@ def get_count_query(self) -> T_SQLALCHEMY_QUERY: def _order_by( self, query: T_SQLALCHEMY_QUERY, - joins: dict, - sort_joins: dict, - sort_field: InstrumentedAttribute | None, + joins: dict[tuple[bool, t.Any], t.Any], + sort_joins: list[T_INSTRUMENTED_ATTRIBUTE], + sort_field: T_INSTRUMENTED_ATTRIBUTE | None, sort_desc: bool, - ) -> tuple[T_SQLALCHEMY_QUERY, dict]: + ) -> tuple[T_SQLALCHEMY_QUERY, dict[tuple[bool, t.Any], t.Any]]: """ Apply order_by to the query @@ -1044,7 +1045,7 @@ def _order_by( def _get_default_order( # type: ignore[override] self, - ) -> t.Generator[tuple[t.Any | None, list, bool], None, None]: + ) -> t.Generator[tuple[t.Any | None, list[t.Any], bool], None, None]: order = super()._get_default_order() for field, direction in order or []: attr, joins = tools.get_field_with_path( @@ -1056,16 +1057,19 @@ def _get_default_order( # type: ignore[override] def _apply_sorting( self, query: T_SQLALCHEMY_QUERY, - joins: dict, + joins: dict[tuple[bool, t.Any], t.Any], sort_column: T_COLUMN | None, sort_desc: bool, - ) -> tuple[T_SQLALCHEMY_QUERY, dict]: + ) -> tuple[T_SQLALCHEMY_QUERY, dict[tuple[bool, t.Any], t.Any]]: if sort_column is not None: if sort_column in self._sortable_columns: sort_field = t.cast( - InstrumentedAttribute, self._sortable_columns[sort_column] + T_INSTRUMENTED_ATTRIBUTE, self._sortable_columns[sort_column] + ) + sort_joins = t.cast( + list[T_INSTRUMENTED_ATTRIBUTE], + self._sortable_joins.get(sort_column), ) - sort_joins = t.cast(dict, self._sortable_joins.get(sort_column)) if isinstance(sort_field, list): for field_item, join_item in zip( @@ -1091,10 +1095,15 @@ def _apply_search( self, query: T_SQLALCHEMY_QUERY, count_query: t.Optional[T_SQLALCHEMY_QUERY], - joins: dict, - count_joins: dict, + joins: dict[tuple[bool, t.Any], t.Any], + count_joins: dict[tuple[bool, t.Any], t.Any], search: str, - ) -> tuple[T_SQLALCHEMY_QUERY, t.Optional[T_SQLALCHEMY_QUERY], dict, dict]: + ) -> tuple[ + T_SQLALCHEMY_QUERY, + t.Optional[T_SQLALCHEMY_QUERY], + dict[tuple[bool, t.Any], t.Any], + dict[tuple[bool, t.Any], t.Any], + ]: """ Apply search to a query. """ @@ -1107,7 +1116,7 @@ def _apply_search( stmt = tools.parse_like_term(term) filter_stmt = [] - count_filter_stmt: list = [] + count_filter_stmt: list[BinaryExpression[t.Any]] = [] for field, path in self._search_fields: # type: ignore[union-attr] query, joins, alias = self._apply_path_joins( @@ -1143,10 +1152,15 @@ def _apply_filters( self, query: T_SQLALCHEMY_QUERY, count_query: t.Optional[T_SQLALCHEMY_QUERY], - joins: dict, - count_joins: dict, + joins: dict[tuple[bool, t.Any], t.Any], + count_joins: dict[tuple[bool, t.Any], t.Any], filters: t.Sequence[T_FILTER], - ) -> tuple[T_SQLALCHEMY_QUERY, t.Optional[T_SQLALCHEMY_QUERY], dict, dict]: + ) -> tuple[ + T_SQLALCHEMY_QUERY, + t.Optional[T_SQLALCHEMY_QUERY], + dict[tuple[bool, t.Any], t.Any], + dict[tuple[bool, t.Any], t.Any], + ]: for idx, _flt_name, value in filters: flt = self._filters[idx] # type: ignore[index] @@ -1157,7 +1171,7 @@ def _apply_filters( if isinstance(flt, sqla_filters.BaseSQLAFilter): # If no key_name is specified, use filter column as filter key filter_key = flt.key_name or flt.column - path = self._filter_joins.get(filter_key, []) + path = self._filter_joins.get(filter_key, []) # type: ignore[call-overload] query, joins, alias = self._apply_path_joins( query, joins, path, inner_join=False @@ -1246,8 +1260,8 @@ def get_list( # type: ignore[override] """ # Will contain join paths with optional aliased object - joins: dict = {} - count_joins: dict = {} + joins: dict[tuple[bool, t.Any], t.Any] = {} + count_joins: dict[tuple[bool, t.Any], t.Any] = {} query = self.get_query() count_query = self.get_count_query() if not self.simple_list_pager else None @@ -1439,7 +1453,7 @@ def is_action_allowed(self, name: str) -> bool: lazy_gettext("Delete"), lazy_gettext("Are you sure you want to delete selected records?"), ) - def action_delete(self, ids: tuple) -> None: + def action_delete(self, ids: tuple[str, ...]) -> None: try: query = tools.get_query_for_ids( self.get_query(), diff --git a/flask_admin/form/__init__.py b/flask_admin/form/__init__.py index e700968b7..70784f443 100644 --- a/flask_admin/form/__init__.py +++ b/flask_admin/form/__init__.py @@ -25,7 +25,7 @@ def get_translations(self, form: "BaseForm") -> Translations: def __init__( self, - formdata: dict | None = None, + formdata: dict[str, t.Any] | None = None, obj: t.Any = None, prefix: str = "", **kwargs: t.Any, @@ -44,13 +44,13 @@ class FormOpts: __slots__ = ["widget_args", "form_rules"] def __init__( - self, widget_args: dict | None = None, form_rules: t.Any = None + self, widget_args: dict[t.Any, t.Any] | None = None, form_rules: t.Any = None ) -> None: self.widget_args = widget_args or {} self.form_rules = form_rules -def recreate_field(unbound: UnboundField | Field) -> t.Any: +def recreate_field(unbound: "UnboundField[t.Any] | Field") -> t.Any: """ Create new instance of the unbound field, resetting wtforms creation counter. diff --git a/flask_admin/form/fields.py b/flask_admin/form/fields.py index 849937f61..e7818a80d 100644 --- a/flask_admin/form/fields.py +++ b/flask_admin/form/fields.py @@ -73,7 +73,7 @@ def __init__( self, label: T_TRANSLATABLE | None = None, validators: list[T_VALIDATOR] | None = None, - formats: t.Iterable | None = None, + formats: t.Iterable[str] | None = None, default_format: str | None = None, widget_format: t.Any = None, **kwargs: t.Any, diff --git a/flask_admin/form/rules.py b/flask_admin/form/rules.py index 6fba4077b..0ca056415 100644 --- a/flask_admin/form/rules.py +++ b/flask_admin/form/rules.py @@ -595,7 +595,7 @@ class RuleSet: def __init__( self, view: t.Union[T_MODEL_VIEW, T_INLINE_BASE_FORM_ADMIN], - rules: t.Sequence[str | tuple | list | BaseRule | FieldSet], + rules: t.Sequence[str | tuple[BaseRule] | list[BaseRule] | BaseRule | FieldSet], ) -> None: """ Constructor. @@ -626,7 +626,7 @@ def convert_string(self, value: str) -> BaseRule: def configure_rules( self, - rules: t.Sequence[str | tuple | list | BaseRule | FieldSet], + rules: t.Sequence[str | tuple[BaseRule] | list[BaseRule] | BaseRule | FieldSet], parent: BaseRule | None = None, ) -> list[BaseRule]: """ diff --git a/flask_admin/form/upload.py b/flask_admin/form/upload.py index 42d1fc88f..e43a10b8b 100644 --- a/flask_admin/form/upload.py +++ b/flask_admin/form/upload.py @@ -257,9 +257,9 @@ def pre_validate(self, form: BaseForm) -> None: def process( self, - formdata: dict, # type:ignore[override] + formdata: dict[str, str], # type:ignore[override] data: UnsetValue = unset_value, - extra_filters: t.Sequence | None = None, + extra_filters: t.Sequence[t.Any] | None = None, ) -> None: if formdata: marker = f"_{self.name}-delete" diff --git a/flask_admin/form/validators.py b/flask_admin/form/validators.py index 33c3b118c..a227b9121 100644 --- a/flask_admin/form/validators.py +++ b/flask_admin/form/validators.py @@ -12,7 +12,7 @@ class FieldListInputRequired: field_flags = {"required": True} - def __call__(self, form: Form, field: FieldList) -> None: + def __call__(self, form: Form, field: FieldList) -> None: # type: ignore[type-arg] if len(field.entries) == 0: field.errors[:] = [] # type:ignore[index] raise StopValidation(gettext("This field requires at least one item.")) diff --git a/flask_admin/helpers.py b/flask_admin/helpers.py index 8601f71c8..4e5dd8a4a 100644 --- a/flask_admin/helpers.py +++ b/flask_admin/helpers.py @@ -87,7 +87,7 @@ def validate_form_on_submit(form: Form) -> bool: return is_form_submitted() and form.validate() -def get_form_data() -> ImmutableMultiDict | None: +def get_form_data() -> ImmutableMultiDict[str, str] | None: """ If current method is PUT or POST, return concatenated `request.form` with `request.files` or `None` otherwise. @@ -102,7 +102,7 @@ def get_form_data() -> ImmutableMultiDict | None: return None -def is_field_error(errors: list | tuple | None) -> bool: +def is_field_error(errors: list[t.Any] | tuple[t.Any, ...] | None) -> bool: """ Check if wtforms field has error without checking its children. diff --git a/flask_admin/model/ajax.py b/flask_admin/model/ajax.py index 25699114b..d943ffb20 100644 --- a/flask_admin/model/ajax.py +++ b/flask_admin/model/ajax.py @@ -1,5 +1,7 @@ import typing as t +from flask_admin._types import T_ORM_MODEL + DEFAULT_PAGE_SIZE = 10 @@ -8,7 +10,7 @@ class AjaxModelLoader: Ajax related model loader. Override this to implement custom loading behavior. """ - def __init__(self, name: str, options: dict) -> None: + def __init__(self, name: str, options: dict[t.Any, t.Any]) -> None: """ Constructor. @@ -18,7 +20,7 @@ def __init__(self, name: str, options: dict) -> None: self.name = name self.options = options - def format(self, model: None | str | bytes) -> tuple[t.Any, str] | None: + def format(self, model: T_ORM_MODEL | None) -> tuple[t.Any, str] | None: """ Return (id, name) tuple from the model. """ @@ -35,7 +37,7 @@ def get_one(self, pk: t.Any) -> t.Any: def get_list( self, query: str, offset: int = 0, limit: int = DEFAULT_PAGE_SIZE - ) -> list: + ) -> list[T_ORM_MODEL]: """ Return models that match `query`. diff --git a/flask_admin/model/base.py b/flask_admin/model/base.py index ad5f55651..74d2c166f 100644 --- a/flask_admin/model/base.py +++ b/flask_admin/model/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import csv import inspect import mimetypes @@ -19,11 +21,11 @@ from flask import stream_with_context from jinja2 import pass_context from jinja2.runtime import Context +from markupsafe import Markup from werkzeug import Response from werkzeug.utils import secure_filename from .._types import T_COLUMN -from .._types import T_COLUMN_FORMATTERS from .._types import T_COLUMN_LIST from .._types import T_COLUMN_TYPE_FORMATTERS from .._types import T_FIELD_ARGS_VALIDATORS_FILES @@ -111,7 +113,7 @@ def __init__( self.extra_args = extra_args or dict() - def clone(self, **kwargs: t.Any) -> "ViewArgs": + def clone(self, **kwargs: t.Any) -> ViewArgs: if self.filters: flt = list(self.filters) else: @@ -131,12 +133,12 @@ def clone(self, **kwargs: t.Any) -> "ViewArgs": class FilterGroup: def __init__(self, label: str) -> None: self.label = label - self.filters: list[dict] = [] + self.filters: list[dict[t.Any, t.Any]] = [] - def append(self, filter: dict) -> None: + def append(self, filter: dict[t.Any, t.Any]) -> None: self.filters.append(filter) - def non_lazy(self) -> tuple[str, list[dict]]: + def non_lazy(self) -> tuple[str, list[dict[t.Any, t.Any]]]: filters = [] for item in self.filters: copy = dict(item) @@ -148,10 +150,26 @@ def non_lazy(self) -> tuple[str, list[dict]]: filters.append(copy) return as_unicode(self.label), filters - def __iter__(self) -> t.Iterator[dict]: + def __iter__(self) -> t.Iterator[dict[t.Any, t.Any]]: return iter(self.filters) +T_COLUMN_FORMATTERS = dict[ + str, + t.Callable[ + [ + # First arg inherits from BaseModelView. + # Cannot type hint and allow users to override without type errors + t.Any, + Context | None, + t.Any, + str, + ], + str | Markup, + ], +] + + class BaseModelView(BaseView, ActionsMixin): """ Base model view. @@ -286,7 +304,8 @@ class MyModelView(BaseModelView): """ column_formatters: T_COLUMN_FORMATTERS = cast( - dict, ObsoleteAttr("column_formatters", "list_formatters", dict()) + T_COLUMN_FORMATTERS, + ObsoleteAttr("column_formatters", "list_formatters", dict()), ) """ Dictionary of list view column formatters. @@ -672,7 +691,7 @@ class MyModelView(BaseModelView): """ form_excluded_columns: t.Collection[str] = cast( - t.Collection, + t.Collection[str], ObsoleteAttr("form_excluded_columns", "excluded_form_columns", None), ) """ @@ -979,7 +998,7 @@ def _refresh_filters_cache(self) -> None: self._filters = self.get_filters() if self._filters: - self._filter_groups: t.OrderedDict | None = OrderedDict() + self._filter_groups: OrderedDict[str, FilterGroup] | None = OrderedDict() self._filter_args: dict[str, tuple[int, BaseFilter]] | None = {} for i, flt in enumerate(self._filters): @@ -1348,7 +1367,7 @@ def get_filter_arg(self, index: int, flt: BaseFilter) -> str: else: return str(index) - def _get_filter_groups(self) -> OrderedDict | None: + def _get_filter_groups(self) -> OrderedDict[str, FilterGroup] | None: """ Returns non-lazy version of filter strings """ @@ -1462,7 +1481,7 @@ def get_delete_form(self) -> type[BaseForm]: Override to implement customized behavior. """ - class DeleteForm(self.form_base_class): # type: ignore[name-defined] + class DeleteForm(self.form_base_class): # type: ignore[name-defined, misc] id = HiddenField(validators=[InputRequired()]) url = HiddenField() @@ -1475,7 +1494,7 @@ def get_action_form(self) -> type[BaseForm]: Override to implement customized behavior. """ - class ActionForm(self.form_base_class): # type: ignore[name-defined] + class ActionForm(self.form_base_class): # type: ignore[name-defined, misc] action = HiddenField() url = ( HiddenField() @@ -1725,7 +1744,7 @@ def get_one(self, id: t.Any) -> T_ORM_MODEL | None: # Exception handler def handle_view_exception(self, exc: Exception) -> bool: if isinstance(exc, ValidationError): - flash(as_unicode(exc), "error") # type: ignore[arg-type] + flash(as_unicode(exc), "error") return True if current_app.config.get("FLASK_ADMIN_RAISE_ON_VIEW_EXCEPTION"): @@ -2180,7 +2199,9 @@ def _process_ajax_references(self) -> dict[str, AjaxModelLoader]: return result - def _create_ajax_loader(self, name: str, options: dict) -> AjaxModelLoader: + def _create_ajax_loader( + self, name: str, options: dict[str, t.Any] + ) -> AjaxModelLoader: """ Model backend will override this to implement AJAX model loading. """ @@ -2211,7 +2232,7 @@ def index_view(self) -> str: page_size = self.get_safe_page_size(view_args.page_size) # Get count and data - data: list + data: list[T_ORM_MODEL] count, data = self.get_list( view_args.page, sort_column, @@ -2492,7 +2513,7 @@ def action_view(self) -> T_RESPONSE: """ return self.handle_action() - def _export_data(self) -> tuple[int, list]: + def _export_data(self) -> tuple[int, list[T_ORM_MODEL]]: # Macros in column_formatters are not supported. # Macros will have a function name 'inner' # This causes non-macro functions named 'inner' not work. @@ -2520,7 +2541,7 @@ def _export_data(self) -> tuple[int, list]: else: sort_column = None # Get count and data - data: list + data: list[T_ORM_MODEL] count, data = self.get_list( 0, sort_column, diff --git a/flask_admin/model/fields.py b/flask_admin/model/fields.py index 12f58e045..3d293d34b 100644 --- a/flask_admin/model/fields.py +++ b/flask_admin/model/fields.py @@ -20,7 +20,7 @@ from .widgets import InlineFormWidget -class InlineFieldList(FieldList): +class InlineFieldList(FieldList): # type: ignore[type-arg] widget: RenderTemplateWidget = InlineFieldListWidget() # type: ignore[assignment] def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: @@ -55,7 +55,7 @@ def display_row_controls(self, field: "InlineModelFormField") -> bool: def process( self, - formdata: dict | None, # type: ignore[override] + formdata: dict[str, str] | None, # type: ignore[override] data: UnsetValue | list[t.Any] = unset_value, extra_filters: t.Any = None, ) -> None: @@ -75,7 +75,7 @@ def process( def validate( self, form: BaseForm, - extra_validators: tuple = tuple(), # type: ignore[override] + extra_validators: tuple[t.Any] = tuple(), # type: ignore[override, assignment] ) -> bool: """ Validate this FieldList. @@ -84,7 +84,7 @@ def validate( that FieldList validates all its enclosed fields first before running any of its own validators. """ - self.errors: list = [] + self.errors: list[t.Any] = [] # Run validators on all entries within for subfield in self.entries: @@ -120,7 +120,7 @@ def populate_obj(self, obj: t.Any, name: str) -> None: setattr(obj, name, output) -class InlineFormField(FormField): +class InlineFormField(FormField): # type: ignore[type-arg] """ Inline version of the ``FormField`` widget. """ @@ -128,7 +128,7 @@ class InlineFormField(FormField): widget = InlineFormWidget() # type: ignore[assignment] -class InlineModelFormField(FormField): +class InlineModelFormField(FormField): # type: ignore[type-arg] """ Customized ``FormField``. @@ -201,7 +201,7 @@ def _get_data(self) -> t.Any: def _set_data(self, data: t.Any) -> None: self._data = data - self._formdata: str | None = None + self._formdata: str | set[str] | None = None data = property(_get_data, _set_data) @@ -263,7 +263,7 @@ def _get_data(self) -> t.Any: def _set_data(self, data: t.Any) -> None: self._data = data - self._formdata = None # type: ignore[assignment] + self._formdata = None data = property(_get_data, _set_data) @@ -271,7 +271,7 @@ def process_formdata( self, valuelist: t.Sequence[str], # type: ignore[override] ) -> None: - self._formdata: set = set() # type: ignore[assignment] + self._formdata = set() for field in valuelist: for n in field.split(self.separator): diff --git a/flask_admin/model/filters.py b/flask_admin/model/filters.py index f866984b3..cb9758f15 100644 --- a/flask_admin/model/filters.py +++ b/flask_admin/model/filters.py @@ -328,14 +328,14 @@ def clean(self, value: str) -> list[str]: return [str(uuid.UUID(v.strip())) for v in value.split(",") if v.strip()] -def convert(*args: t.Any) -> t.Callable: +def convert(*args: t.Any) -> t.Callable[..., t.Any]: """ Decorator for field to filter conversion routine. See :mod:`flask_admin.contrib.sqla.filters` for usage example. """ - def _inner(func: t.Callable) -> t.Callable: + def _inner(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: func._converter_for = list(map(lambda x: x.lower(), args)) # type: ignore[attr-defined] return func diff --git a/flask_admin/model/form.py b/flask_admin/model/form.py index 3e57a69ad..bc24e67ce 100644 --- a/flask_admin/model/form.py +++ b/flask_admin/model/form.py @@ -34,7 +34,7 @@ class BaseListForm(BaseForm): def create_editable_list_form( form_base_class: type[Form], form_class: type[Form], - widget: t.Callable | None = None, + widget: t.Callable[..., t.Any] | None = None, ) -> type[BaseListForm]: """ Create a form class with all the fields wrapped in a FieldList. @@ -189,7 +189,9 @@ def __init__( class ModelConverterBase: - def __init__(self, converters: dict | None = None, use_mro: bool = True) -> None: + def __init__( + self, converters: dict[t.Any, t.Any] | None = None, use_mro: bool = True + ) -> None: self.use_mro = use_mro if not converters: @@ -203,7 +205,9 @@ def __init__(self, converters: dict | None = None, use_mro: bool = True) -> None self.converters = converters - def get_converter(self, column: T_SQLALCHEMY_COLUMN) -> t.Callable | None: + def get_converter( + self, column: T_SQLALCHEMY_COLUMN + ) -> t.Callable[..., t.Any] | None: types: list[type] | tuple[type, ...] if self.use_mro: types = inspect.getmro(type(column.type)) diff --git a/flask_admin/model/helpers.py b/flask_admin/model/helpers.py index edf489af1..a4772a150 100644 --- a/flask_admin/model/helpers.py +++ b/flask_admin/model/helpers.py @@ -16,7 +16,7 @@ def prettify_name(name: str) -> str: def get_mdict_item_or_list( - mdict: werkzeug.datastructures.MultiDict, key: str + mdict: werkzeug.datastructures.MultiDict[str, str], key: str ) -> t.Any | None: """ Return the value for the given key of the multidict. @@ -33,7 +33,7 @@ def get_mdict_item_or_list( if hasattr(mdict, "getlist"): v = mdict.getlist(key) if len(v) == 1: - value = v[0] + value: str | None = v[0] # Special case for empty strings, treat them as "no-value" if value == "": diff --git a/flask_admin/tests/geoa/test_basic.py b/flask_admin/tests/geoa/test_basic.py index b2a57a22d..b4fc455de 100644 --- a/flask_admin/tests/geoa/test_basic.py +++ b/flask_admin/tests/geoa/test_basic.py @@ -9,7 +9,7 @@ def create_models(db): - class GeoModel(db.Model): # type: ignore[name-defined] + class GeoModel(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String(20)) point = db.Column(Geometry("POINT")) diff --git a/flask_admin/tests/mongoengine/test_basic.py b/flask_admin/tests/mongoengine/test_basic.py index 0aa6c2c16..fe3ab8763 100644 --- a/flask_admin/tests/mongoengine/test_basic.py +++ b/flask_admin/tests/mongoengine/test_basic.py @@ -8,7 +8,7 @@ from flask_admin.contrib.mongoengine import ModelView -class Test(Document): +class Test(Document): # type: ignore[misc] __test__ = False test1 = StringField() test2 = StringField() diff --git a/flask_admin/tests/peeweemodel/test_basic.py b/flask_admin/tests/peeweemodel/test_basic.py index fde473fb9..c2ca7cb0d 100644 --- a/flask_admin/tests/peeweemodel/test_basic.py +++ b/flask_admin/tests/peeweemodel/test_basic.py @@ -889,9 +889,9 @@ def test_default_sort(app, db, admin): _, data = view.get_list(0, None, None, None, None) - assert data[0].test1 == "a" - assert data[1].test1 == "b" - assert data[2].test1 == "c" + assert data[0].test1 == "a" # type: ignore[union-attr] + assert data[1].test1 == "b" # type: ignore[union-attr] + assert data[2].test1 == "c" # type: ignore[union-attr] # test default sort with multiple columns order = [("test2", False), ("test1", False)] @@ -901,9 +901,9 @@ def test_default_sort(app, db, admin): _, data = view2.get_list(0, None, None, None, None) assert len(data) == 3 - assert data[0].test1 == "b" - assert data[1].test1 == "c" - assert data[2].test1 == "a" + assert data[0].test1 == "b" # type: ignore[union-attr] + assert data[1].test1 == "c" # type: ignore[union-attr] + assert data[2].test1 == "a" # type: ignore[union-attr] def test_extra_fields(app, db, admin): @@ -1003,11 +1003,11 @@ class Model2(BaseModel): items = loader.get_list("fir") assert len(items) == 1 - assert items[0].id == model.id # type: ignore[attr-defined] + assert items[0].id == model.id # type: ignore[attr-defined, union-attr] items = loader.get_list("bar") assert len(items) == 1 - assert items[0].test1 == "foo" + assert items[0].test1 == "foo" # type: ignore[union-attr] # Check form generation form = view.create_form() diff --git a/flask_admin/tests/pymongo/conftest.py b/flask_admin/tests/pymongo/conftest.py index 966541dff..bc2a3db54 100644 --- a/flask_admin/tests/pymongo/conftest.py +++ b/flask_admin/tests/pymongo/conftest.py @@ -1,4 +1,5 @@ import os +import typing as t import pytest from pymongo import MongoClient @@ -8,7 +9,9 @@ @pytest.fixture def db(): - client: MongoClient = MongoClient(host=os.getenv("MONGOCLIENT_HOST", "localhost")) + client: MongoClient[t.Any] = MongoClient( + host=os.getenv("MONGOCLIENT_HOST", "localhost") + ) db = client.tests yield db client.close() diff --git a/flask_admin/tests/sqla/test_basic.py b/flask_admin/tests/sqla/test_basic.py index 1f167c49a..67148bef8 100644 --- a/flask_admin/tests/sqla/test_basic.py +++ b/flask_admin/tests/sqla/test_basic.py @@ -52,7 +52,7 @@ def __init__( def create_models(db): - class Model1(db.Model): # type: ignore[name-defined] + class Model1(db.Model): # type: ignore[name-defined, misc] def __init__( self, test1=None, @@ -115,7 +115,7 @@ def __unicode__(self): def __str__(self): return self.test1 - class Model2(db.Model): # type: ignore[name-defined] + class Model2(db.Model): # type: ignore[name-defined, misc] def __init__( self, string_field=None, @@ -382,7 +382,7 @@ def test_model(app, db, admin): @pytest.mark.xfail(raises=Exception) def test_no_pk(app, db, admin): - class Model(db.Model): # type: ignore[name-defined] + class Model(db.Model): # type: ignore[name-defined, misc] test = db.Column(db.Integer) view = CustomModelView(Model, db.session) @@ -784,7 +784,7 @@ def test_editable_list_special_pks(app, db, admin): """Tests editable list view + a primary key with special characters""" with app.app_context(): - class Model1(db.Model): # type: ignore[name-defined] + class Model1(db.Model): # type: ignore[name-defined, misc] def __init__(self, id=None, val1=None): self.id = id self.val1 = val1 @@ -1784,7 +1784,7 @@ def test_column_filters_sqla_obj(app, db, admin): def test_hybrid_property(app, db, admin): with app.app_context(): - class Model1(db.Model): # type: ignore[name-defined] + class Model1(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String) width = db.Column(db.Integer) @@ -1856,7 +1856,7 @@ def number_of_pixels_str(cls): def test_hybrid_property_nested(app, db, admin): with app.app_context(): - class Model1(db.Model): # type: ignore[name-defined] + class Model1(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) firstname = db.Column(db.String) lastname = db.Column(db.String) @@ -1865,7 +1865,7 @@ class Model1(db.Model): # type: ignore[name-defined] def fullname(self): return f"{self.firstname} {self.lastname}" - class Model2(db.Model): # type: ignore[name-defined] + class Model2(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String) owner_id = db.Column( @@ -1967,7 +1967,7 @@ def test_url_args(app, db, admin): def test_non_int_pk(app, db, admin): with app.app_context(): - class Model(db.Model): # type: ignore[name-defined] + class Model(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.String, primary_key=True) test = db.Column(db.String) @@ -1998,14 +1998,14 @@ class Model(db.Model): # type: ignore[name-defined] def test_form_columns(app, db, admin): with app.app_context(): - class Model(db.Model): # type: ignore[name-defined] + class Model(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.String, primary_key=True) int_field = db.Column(db.Integer) datetime_field = db.Column(db.DateTime) text_field = db.Column(db.UnicodeText) excluded_column = db.Column(db.String) - class ChildModel(db.Model): # type: ignore[name-defined] + class ChildModel(db.Model): # type: ignore[name-defined, misc] class EnumChoices(enum.Enum): first = 1 second = 2 @@ -2081,7 +2081,7 @@ def test_complex_form_columns(app, db, admin): def test_form_args(app, db, admin): with app.app_context(): - class Model(db.Model): # type: ignore[name-defined] + class Model(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.String, primary_key=True) test = db.Column(db.String, nullable=False) @@ -2103,7 +2103,7 @@ class Model(db.Model): # type: ignore[name-defined] def test_form_override(app, db, admin): with app.app_context(): - class Model(db.Model): # type: ignore[name-defined] + class Model(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.String, primary_key=True) test = db.Column(db.String) @@ -2126,11 +2126,11 @@ class Model(db.Model): # type: ignore[name-defined] def test_form_onetoone(app, db, admin): with app.app_context(): - class Model1(db.Model): # type: ignore[name-defined] + class Model1(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) test = db.Column(db.String) - class Model2(db.Model): # type: ignore[name-defined] + class Model2(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) model1_id = db.Column(db.Integer, db.ForeignKey(Model1.id)) @@ -2573,11 +2573,11 @@ def test_ajax_fk(app, db, admin): items = loader.get_list("fir") assert len(items) == 1 - assert items[0].id == model.id + assert items[0].id == model.id # type: ignore[union-attr] items = loader.get_list("bar") assert len(items) == 1 - assert items[0].test1 == "foo" + assert items[0].test1 == "foo" # type: ignore[union-attr] # Check form generation form = view.create_form() @@ -2612,7 +2612,7 @@ def test_ajax_fk(app, db, admin): def test_ajax_fk_multi(app, db, admin): with app.app_context(): - class Model1(db.Model): # type: ignore[name-defined] + class Model1(db.Model): # type: ignore[name-defined, misc] __tablename__ = "model1" id = db.Column(db.Integer, primary_key=True) @@ -2628,7 +2628,7 @@ def __str__(self): db.Column("model2_id", db.Integer, db.ForeignKey("model2.id")), ) - class Model2(db.Model): # type: ignore[name-defined] + class Model2(db.Model): # type: ignore[name-defined, misc] __tablename__ = "model2" id = db.Column(db.Integer, primary_key=True) @@ -2873,19 +2873,19 @@ def test_unlimited_page_size(app, db, admin): def test_advanced_joins(app, db, admin): with app.app_context(): - class Model1(db.Model): # type: ignore[name-defined] + class Model1(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) val1 = db.Column(db.String(20)) test = db.Column(db.String(20)) - class Model2(db.Model): # type: ignore[name-defined] + class Model2(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) val2 = db.Column(db.String(20)) model1_id = db.Column(db.Integer, db.ForeignKey(Model1.id)) model1 = db.relationship(Model1, backref="model2") - class Model3(db.Model): # type: ignore[name-defined] + class Model3(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) val2 = db.Column(db.String(20)) @@ -2948,12 +2948,12 @@ class Model3(db.Model): # type: ignore[name-defined] def test_multipath_joins(app, db, admin): with app.app_context(): - class Model1(db.Model): # type: ignore[name-defined] + class Model1(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) val1 = db.Column(db.String(20)) test = db.Column(db.String(20)) - class Model2(db.Model): # type: ignore[name-defined] + class Model2(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) val2 = db.Column(db.String(20)) @@ -2982,11 +2982,11 @@ def test_different_bind_joins(request, app): with app.app_context(): - class Model1(db.Model): # type: ignore[name-defined] + class Model1(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) val1 = db.Column(db.String(20)) - class Model2(db.Model): # type: ignore[name-defined] + class Model2(db.Model): # type: ignore[name-defined, misc] __bind_key__ = "other" id = db.Column(db.Integer, primary_key=True) val1 = db.Column(db.String(20)) @@ -3070,7 +3070,7 @@ def test_export_csv(app, db, admin): def test_string_null_behavior(app, db, admin): with app.app_context(): - class StringTestModel(db.Model): # type: ignore[name-defined] + class StringTestModel(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) test_no = db.Column(db.Integer, nullable=False) string_field = db.Column(db.String) @@ -3155,7 +3155,7 @@ class StringTestModel(db.Model): # type: ignore[name-defined] def test_form_overrides(app, db, admin): with app.app_context(): - class UserModel(db.Model): # type: ignore[name-defined] + class UserModel(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) text = db.Column(db.String) diff --git a/flask_admin/tests/sqla/test_inlineform.py b/flask_admin/tests/sqla/test_inlineform.py index b8a2808e5..eab8d386c 100644 --- a/flask_admin/tests/sqla/test_inlineform.py +++ b/flask_admin/tests/sqla/test_inlineform.py @@ -11,7 +11,7 @@ def test_inline_form(app, db, admin): client = app.test_client() # Set up models and database - class User(db.Model): # type: ignore[name-defined] + class User(db.Model): # type: ignore[name-defined, misc] __tablename__ = "users" id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String, unique=True) @@ -19,7 +19,7 @@ class User(db.Model): # type: ignore[name-defined] def __init__(self, name=None): self.name = name - class UserInfo(db.Model): # type: ignore[name-defined] + class UserInfo(db.Model): # type: ignore[name-defined, misc] __tablename__ = "user_info" id = db.Column(db.Integer, primary_key=True) key = db.Column(db.String, nullable=False) @@ -119,7 +119,7 @@ def test_inline_form_required(app, db, admin): client = app.test_client() # Set up models and database - class User(db.Model): # type: ignore[name-defined] + class User(db.Model): # type: ignore[name-defined, misc] __tablename__ = "users" id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String, unique=True) @@ -127,7 +127,7 @@ class User(db.Model): # type: ignore[name-defined] def __init__(self, name=None): self.name = name - class UserEmail(db.Model): # type: ignore[name-defined] + class UserEmail(db.Model): # type: ignore[name-defined, misc] __tablename__ = "user_info" id = db.Column(db.Integer, primary_key=True) email = db.Column(db.String, nullable=False, unique=True) @@ -179,7 +179,7 @@ class UserModelView(ModelView): def test_inline_form_ajax_fk(app, db, admin): with app.app_context(): # Set up models and database - class User(db.Model): # type: ignore[name-defined] + class User(db.Model): # type: ignore[name-defined, misc] __tablename__ = "users" id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String, unique=True) @@ -187,13 +187,13 @@ class User(db.Model): # type: ignore[name-defined] def __init__(self, name=None): self.name = name - class Tag(db.Model): # type: ignore[name-defined] + class Tag(db.Model): # type: ignore[name-defined, misc] __tablename__ = "tags" id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String, unique=True) - class UserInfo(db.Model): # type: ignore[name-defined] + class UserInfo(db.Model): # type: ignore[name-defined, misc] __tablename__ = "user_info" id = db.Column(db.Integer, primary_key=True) key = db.Column(db.String, nullable=False) @@ -233,7 +233,7 @@ class UserModelView(ModelView): def test_inline_form_self(app, db, admin): with app.app_context(): - class Tree(db.Model): # type: ignore[name-defined] + class Tree(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) parent_id = db.Column(db.Integer, db.ForeignKey("tree.id")) parent = db.relationship("Tree", remote_side=[id], backref="children") @@ -256,7 +256,7 @@ def test_inline_form_base_class(app, db, admin): with app.app_context(): # Set up models and database - class User(db.Model): # type: ignore[name-defined] + class User(db.Model): # type: ignore[name-defined, misc] __tablename__ = "users" id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String, unique=True) @@ -264,7 +264,7 @@ class User(db.Model): # type: ignore[name-defined] def __init__(self, name=None): self.name = name - class UserEmail(db.Model): # type: ignore[name-defined] + class UserEmail(db.Model): # type: ignore[name-defined, misc] __tablename__ = "user_info" id = db.Column(db.Integer, primary_key=True) email = db.Column(db.String, nullable=False, unique=True) diff --git a/flask_admin/tests/sqla/test_multi_pk.py b/flask_admin/tests/sqla/test_multi_pk.py index 7895b9dd2..9cefc2634 100644 --- a/flask_admin/tests/sqla/test_multi_pk.py +++ b/flask_admin/tests/sqla/test_multi_pk.py @@ -8,7 +8,7 @@ def test_multiple_pk(app, db, admin): # Test multiple primary keys - mix int and string together with app.app_context(): - class Model(db.Model): # type: ignore[name-defined] + class Model(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) id2 = db.Column(db.String(20), primary_key=True) test = db.Column(db.String) @@ -45,7 +45,7 @@ def test_joined_inheritance(app, db, admin): # Test multiple primary keys - mix int and string together with app.app_context(): - class Parent(db.Model): # type: ignore[name-defined] + class Parent(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) test = db.Column(db.String) @@ -121,7 +121,7 @@ def test_concrete_table_inheritance(app, db, admin): # Test multiple primary keys - mix int and string together with app.app_context(): - class Parent(db.Model): # type: ignore[name-defined] + class Parent(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) test = db.Column(db.String) @@ -155,7 +155,7 @@ def test_concrete_multipk_inheritance(app, db, admin): # Test multiple primary keys - mix int and string together with app.app_context(): - class Parent(db.Model): # type: ignore[name-defined] + class Parent(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) test = db.Column(db.String) diff --git a/flask_admin/tests/sqla/test_postgres.py b/flask_admin/tests/sqla/test_postgres.py index c0483a338..97c948e4d 100644 --- a/flask_admin/tests/sqla/test_postgres.py +++ b/flask_admin/tests/sqla/test_postgres.py @@ -9,7 +9,7 @@ def test_hstore(app, postgres_db, postgres_admin): with app.app_context(): - class Model(postgres_db.Model): # type: ignore[name-defined] + class Model(postgres_db.Model): # type: ignore[name-defined, misc] id = postgres_db.Column( postgres_db.Integer, primary_key=True, autoincrement=True ) @@ -47,7 +47,7 @@ class Model(postgres_db.Model): # type: ignore[name-defined] def test_json(app, postgres_db, postgres_admin): with app.app_context(): - class JSONModel(postgres_db.Model): # type: ignore[name-defined] + class JSONModel(postgres_db.Model): # type: ignore[name-defined, misc] id = postgres_db.Column( postgres_db.Integer, primary_key=True, autoincrement=True ) @@ -90,7 +90,7 @@ class JSONModel(postgres_db.Model): # type: ignore[name-defined] def test_citext(app, postgres_db, postgres_admin): with app.app_context(): - class CITextModel(postgres_db.Model): # type: ignore[name-defined] + class CITextModel(postgres_db.Model): # type: ignore[name-defined, misc] id = postgres_db.Column( postgres_db.Integer, primary_key=True, autoincrement=True ) @@ -137,7 +137,7 @@ def test_boolean_filters(app, postgres_db, postgres_admin): """ with app.app_context(): - class BoolModel(postgres_db.Model): # type: ignore[name-defined] + class BoolModel(postgres_db.Model): # type: ignore[name-defined, misc] id = postgres_db.Column( postgres_db.Integer, primary_key=True, autoincrement=True ) diff --git a/flask_admin/tests/sqla/test_translation.py b/flask_admin/tests/sqla/test_translation.py index 332ac301c..acf7cb8f3 100644 --- a/flask_admin/tests/sqla/test_translation.py +++ b/flask_admin/tests/sqla/test_translation.py @@ -41,7 +41,7 @@ def test_column_label_translation(request, app): def test_unique_validator_translation_is_dynamic(app, db, admin): with app.app_context(): - class UniqueTable(db.Model): # type: ignore[name-defined] + class UniqueTable(db.Model): # type: ignore[name-defined, misc] id = db.Column(db.Integer, primary_key=True) value = db.Column(db.String, unique=True) diff --git a/pyproject.toml b/pyproject.toml index dd1382fa5..5d828f0a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,9 +193,9 @@ strict_concatenate = true # These shouldn't be too much additional work, but may be tricky to # get passing if you use a lot of untyped libraries -disallow_subclassing_any = false -disallow_untyped_decorators = false -disallow_any_generics = false +disallow_subclassing_any = true +disallow_untyped_decorators = true +disallow_any_generics = true # These next few are various gradations of forcing use of type annotations disallow_untyped_calls = false @@ -379,6 +379,7 @@ commands = [ "--ignore-missing-imports", "--disable-error-code", "name-defined", # allow flask_sqlalchemy's db.Model in examples "--no-warn-unused-ignores", # prevent conflict with mypy 1st run on src folder + "--allow-subclassing-any", # flask-slqalchemy db.Model is Any "examples", ] # TODO: reenable after pyright test pass # ["pyright"],