3030
3131# stdlib
3232import ast
33- from typing import Iterator , Union
33+ from typing import Iterator , List , Optional , Union
3434
3535# 3rd party
3636import flake8_helper
3737
38- __all__ = ("Plugin" , "Visitor" , "get_decorator_names" )
38+ __all__ = ("Plugin" , "Visitor" , "get_decorator_names" , "check_params" )
3939
4040__author__ = "Dominic Davis-Foster"
4141__copyright__ = "2025 Dominic Davis-Foster"
4646PRM001 = "PRM001 Docstring parameters in wrong order."
4747PRM002 = "PRM002 Missing parameters in docstring."
4848PRM003 = "PRM003 Extra parameters in docstring."
49+ # TODO: class-specific codes?
4950
5051deco_allowed_attr_names = {
5152 ".setter" , # Property setter
@@ -73,7 +74,7 @@ def _get_deco_name(decorator: ast.expr) -> Iterator[str]:
7374 raise NotImplementedError (decorator )
7475
7576
76- def get_decorator_names (function : Union [ast .AsyncFunctionDef , ast .FunctionDef ]) -> Iterator [str ]:
77+ def get_decorator_names (function : Union [ast .AsyncFunctionDef , ast .FunctionDef , ast . ClassDef ]) -> Iterator [str ]:
7778 """
7879 Returns an iterator of the dotted names of decorators for the given function.
7980
@@ -84,81 +85,136 @@ def get_decorator_names(function: Union[ast.AsyncFunctionDef, ast.FunctionDef])
8485 yield from _get_deco_name (decorator )
8586
8687
88+ def check_params (
89+ signature_args : List [str ],
90+ docstring_args : List [str ],
91+ decorators : List [str ],
92+ ) -> Optional [str ]:
93+ """
94+ Check if signature and docstring parameters match, and return the flake8 error code if not.
95+
96+ :param signature_args:
97+ :param docstring_args:
98+ :param decorators: List of dotted names (e.g. ``foo.bar``, for ``@foo.bar()``) of decorators for the function or class.
99+
100+ :returns: Either a flake8 error code and description, or :py:obj:`None` if no errors were detected.
101+ """
102+
103+ if "self" in signature_args :
104+ signature_args .remove ("self" )
105+
106+ if "classmethod" in decorators and signature_args :
107+ signature_args .pop (0 )
108+ for deco in decorators :
109+ if any (deco .endswith (name ) for name in deco_allowed_attr_names ):
110+ signature_args = []
111+ break
112+
113+ if not signature_args and not docstring_args :
114+ # No args either way
115+ return None
116+
117+ if signature_args == docstring_args :
118+ # All match
119+ return None
120+
121+ # Either wrong order, extra in signature, extra in doc
122+ signature_set = set (signature_args )
123+ docstring_set = set (docstring_args )
124+ if signature_set == docstring_set :
125+ # Wrong order
126+ return PRM001
127+ elif signature_set - docstring_set :
128+ # Extras in signature
129+ return PRM002
130+ elif docstring_set - signature_set :
131+ # Extras in docstrings
132+ return PRM003
133+
134+ return None # pragma: no cover
135+
136+
87137class Visitor (flake8_helper .Visitor ):
88138 """
89139 AST node visitor for identifying mismatches between function signatures and docstring params.
90140 """
91141
142+ # TODO: async functions
143+
92144 def visit_FunctionDef (self , node : ast .FunctionDef ) -> None : # noqa: D102
93- docstring = ast .get_docstring (node , clean = False )
94145 if node .name == "__init__" :
95- # TODO: special case; parameters go on class
96146 self .generic_visit (node )
97147 return
98148
149+ docstring = ast .get_docstring (node , clean = False )
150+
99151 if not docstring :
100152 self .generic_visit (node )
101153 return
102154
103- seen_args = []
155+ docstring_args = []
104156 for line in docstring .split ('\n ' ):
105157 line = line .strip ()
106158 if line .startswith (":param" ):
107- seen_args .append (line [6 :].split (':' , 1 )[0 ].strip ())
159+ docstring_args .append (line [6 :].split (':' , 1 )[0 ].strip ())
108160
109161 signature_args = [a .arg for a in node .args .args ]
110- if "self" in signature_args :
111- signature_args .remove ("self" )
112162
113- # decorators = [n.id for n in node.decorator_list if isinstance(n, ast.Name)]
114163 decorators = list (get_decorator_names (node ))
115- if "classmethod" in decorators and signature_args :
116- signature_args .pop (0 )
117- for deco in decorators :
118- if any (deco .endswith (name ) for name in deco_allowed_attr_names ):
119- signature_args = []
120- break
121164
122- if not signature_args and not seen_args :
123- # No args either way
165+ error = check_params (signature_args , docstring_args , decorators )
166+ if not error :
167+ self .generic_visit (node )
168+ return
169+
170+ self .errors .append ((
171+ node .lineno ,
172+ node .col_offset ,
173+ error ,
174+ ))
175+
176+ self .generic_visit (node )
177+
178+ def visit_ClassDef (self , node : ast .ClassDef ) -> None : # noqa: D102
179+ docstring = ast .get_docstring (node , clean = False )
180+
181+ if not docstring :
124182 self .generic_visit (node )
125183 return
126184
127- if signature_args == seen_args :
128- # All match
185+ docstring_args = []
186+ for line in docstring .split ('\n ' ):
187+ line = line .strip ()
188+ if line .startswith (":param" ):
189+ docstring_args .append (line [6 :].split (':' , 1 )[0 ].strip ())
190+
191+ decorators = list (get_decorator_names (node ))
192+
193+ signature_args = []
194+ functions_in_body : List [ast .FunctionDef ] = [n for n in node .body if isinstance (n , ast .FunctionDef )]
195+
196+ for function in functions_in_body :
197+ if function .name == "__init__" :
198+ signature_args = [a .arg for a in function .args .args ]
199+ break
200+ else :
201+ # No __init__; maybe it comes from a base class.
202+ # TODO: check for base classes and still error if non exist
203+ return None
204+
205+ error = check_params (signature_args , docstring_args , decorators )
206+ if not error :
129207 self .generic_visit (node )
130208 return
131209
132- # Either wrong order, extra in signature, extra in doc
133- signature_set = set (signature_args )
134- seen_set = set (seen_args )
135- if signature_set == seen_set :
136- # Wrong order
137- self .errors .append ((
138- node .lineno ,
139- node .col_offset ,
140- PRM001 ,
141- ))
142- elif signature_set - seen_set :
143- # Extras in signature
144- self .errors .append ((
145- node .lineno ,
146- node .col_offset ,
147- PRM002 ,
148- ))
149- elif seen_set - signature_set :
150- # Extras in docstrings
151- self .errors .append ((
152- node .lineno ,
153- node .col_offset ,
154- PRM003 ,
155- ))
210+ self .errors .append ((
211+ node .lineno ,
212+ node .col_offset ,
213+ error ,
214+ ))
156215
157216 self .generic_visit (node )
158217
159- # def visit_ClassDef(self, node: ast.ClassDef):
160- # breakpoint()
161-
162218
163219class Plugin (flake8_helper .Plugin [Visitor ]):
164220 """
0 commit comments