Skip to content

Commit

Permalink
Use PyRight
Browse files Browse the repository at this point in the history
  • Loading branch information
dnknth committed Nov 23, 2024
1 parent 646e0b9 commit d3a536a
Show file tree
Hide file tree
Showing 11 changed files with 1,784 additions and 4,867 deletions.
7 changes: 4 additions & 3 deletions .vscode/extensions.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
{
"recommendations": [
"Vue.volar",
"bradlc.vscode-tailwindcss",
"dbaeumer.vscode-eslint",
"editorconfig.editorconfig",
"esbenp.prettier-vscode",
"bradlc.vscode-tailwindcss",
"jtavin.ldif"
"jtavin.ldif",
"ms-pyright.pyright",
"Vue.volar"
]
}
32 changes: 19 additions & 13 deletions backend/ldap_ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@
import logging
import sys
from http import HTTPStatus
from typing import AsyncGenerator, Optional

import ldap
from typing import Optional

from ldap import (
INSUFFICIENT_ACCESS, # pyright: ignore[reportAttributeAccessIssue]
INVALID_CREDENTIALS, # pyright: ignore[reportAttributeAccessIssue]
SCOPE_SUBTREE, # pyright: ignore[reportAttributeAccessIssue]
UNWILLING_TO_PERFORM, # pyright: ignore[reportAttributeAccessIssue]
LDAPError, # pyright: ignore[reportAttributeAccessIssue]
)
from ldap.ldapobject import LDAPObject
from pydantic import ValidationError
from starlette.applications import Starlette
Expand All @@ -28,7 +34,7 @@
from starlette.exceptions import HTTPException
from starlette.middleware import Middleware
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.middleware.gzip import GZipMiddleware
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
Expand All @@ -55,7 +61,7 @@ async def anonymous_user_search(connection: LDAPObject, username: str) -> Option
connection,
connection.search(
settings.BASE_DN,
ldap.SCOPE_SUBTREE,
SCOPE_SUBTREE,
settings.GET_BIND_DN_FILTER(username),
),
)
Expand All @@ -67,7 +73,7 @@ async def anonymous_user_search(connection: LDAPObject, username: str) -> Option

class LdapConnectionMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: AsyncGenerator[Request, Response]
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
"Add an authenticated LDAP connection to the request"

Expand All @@ -93,23 +99,23 @@ async def dispatch(
request.state.ldap = connection
return await call_next(request)

except ldap.INVALID_CREDENTIALS:
except INVALID_CREDENTIALS:
pass

except ldap.INSUFFICIENT_ACCESS as err:
except INSUFFICIENT_ACCESS as err:
return Response(
ldap_exception_message(err),
status_code=HTTPStatus.FORBIDDEN.value,
)

except ldap.UNWILLING_TO_PERFORM:
except UNWILLING_TO_PERFORM:
LOG.warning("Need BIND_DN or BIND_PATTERN to authenticate")
return Response(
HTTPStatus.FORBIDDEN.phrase,
status_code=HTTPStatus.FORBIDDEN.value,
)

except ldap.LDAPError as err:
except LDAPError as err:
LOG.error(ldap_exception_message(err), exc_info=err)
return Response(
ldap_exception_message(err),
Expand All @@ -126,7 +132,7 @@ async def dispatch(
)


def ldap_exception_message(exc: ldap.LDAPError) -> str:
def ldap_exception_message(exc: LDAPError) -> str:
args = exc.args[0]
if "info" in args:
return args.get("info", "") + ": " + args.get("desc", "")
Expand Down Expand Up @@ -166,7 +172,7 @@ class CacheBustingMiddleware(BaseHTTPMiddleware):
"Forbid caching of API responses"

async def dispatch(
self, request: Request, call_next: AsyncGenerator[Request, Response]
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
if request.url.path.startswith("/api"):
Expand Down Expand Up @@ -195,7 +201,7 @@ async def http_422(_request: Request, e: ValidationError) -> Response:
# Main ASGI entry
app = Starlette(
debug=settings.DEBUG,
exception_handlers={
exception_handlers={ # pyright: ignore[reportArgumentType]
HTTPException: http_exception,
ValidationError: http_422,
},
Expand Down
81 changes: 46 additions & 35 deletions backend/ldap_ui/ldap_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@
import base64
import io
from http import HTTPStatus
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union, cast

import ldap
import ldif
from ldap import (
INVALID_CREDENTIALS, # pyright: ignore[reportAttributeAccessIssue]
SCOPE_BASE, # pyright: ignore[reportAttributeAccessIssue]
SCOPE_ONELEVEL, # pyright: ignore[reportAttributeAccessIssue]
SCOPE_SUBTREE, # pyright: ignore[reportAttributeAccessIssue]
)
from ldap.ldapobject import LDAPObject
from ldap.modlist import addModlist, modifyModlist
from ldap.schema import SubSchema
Expand Down Expand Up @@ -65,12 +70,12 @@ async def tree(request: Request) -> JSONResponse:
"List directory entries"

basedn = request.path_params["basedn"]
scope = ldap.SCOPE_ONELEVEL
scope = SCOPE_ONELEVEL
if basedn == "base":
scope = ldap.SCOPE_BASE
scope = SCOPE_BASE
basedn = settings.BASE_DN

return JSONResponse(await _tree(request, basedn, scope))
return JSONResponse(await _tree(request, str(basedn), scope))


async def _tree(request: Request, basedn: str, scope: int) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -110,12 +115,12 @@ def _entry(schema: SubSchema, res: Tuple[str, Any]) -> Entry:
ocs = set([oc.decode() for oc in attrs["objectClass"]])
must_attrs, _may_attrs = schema.attribute_types(ocs)
soc = [
oc.names[0]
oc.names[0] # pyright: ignore[reportOptionalMemberAccess]
for oc in map(lambda o: schema.get_obj(ObjectClass, o), ocs)
if oc.kind == OC.Kind.structural.value
if oc.kind == OC.Kind.structural.value # pyright: ignore[reportOptionalMemberAccess]
]
aux = set(
schema.get_obj(ObjectClass, a).names[0]
schema.get_obj(ObjectClass, a).names[0] # pyright: ignore[reportOptionalMemberAccess]
for a in schema.get_applicable_aux_classes(soc[0])
)

Expand All @@ -130,26 +135,28 @@ def _entry(schema: SubSchema, res: Tuple[str, Any]) -> Entry:

# Octet strings are not used consistently.
# Try to decode as text and treat as binary on failure
if not obj.syntax or obj.syntax == OCTET_STRING:
if not obj.syntax or obj.syntax == OCTET_STRING: # pyright: ignore[reportOptionalMemberAccess]
try:
for val in attrs[attr]:
assert val.decode().isprintable()
except: # noqa: E722
binary.add(attr)

else: # Check human-readable flag in schema
syntax = schema.get_obj(LDAPSyntax, obj.syntax)
if syntax.not_human_readable:
syntax = schema.get_obj(LDAPSyntax, obj.syntax) # pyright: ignore[reportOptionalMemberAccess]
if syntax.not_human_readable: # pyright: ignore[reportOptionalMemberAccess]
binary.add(attr)

return Entry(
attrs={
k: [base64.b64encode(val) if k in binary else val for val in values]
k: [
base64.b64encode(val).decode() if k in binary else val for val in values
]
for k, values in attrs.items()
},
meta=Meta(
dn=dn,
required=[schema.get_obj(AttributeType, a).names[0] for a in must_attrs],
required=[schema.get_obj(AttributeType, a).names[0] for a in must_attrs], # pyright: ignore[reportOptionalMemberAccess]
aux=sorted(aux - ocs),
binary=sorted(binary),
autoFilled=[],
Expand All @@ -160,7 +167,7 @@ def _entry(schema: SubSchema, res: Tuple[str, Any]) -> Entry:
Attributes = TypeAdapter(dict[str, list[bytes]])


@api.route("/entry/{dn}", methods=("GET", "POST", "DELETE", "PUT"))
@api.route("/entry/{dn}", methods=["GET", "POST", "DELETE", "PUT"])
async def entry(request: Request) -> Response:
"Edit directory entries"

Expand All @@ -176,7 +183,7 @@ async def entry(request: Request) -> Response:

if request.method == "DELETE":
for entry in reversed(
sorted(await _tree(request, dn, ldap.SCOPE_SUBTREE), key=_dn_order)
sorted(await _tree(request, dn, SCOPE_SUBTREE), key=_dn_order)
):
await empty(connection, connection.delete(entry["dn"]))
return NO_CONTENT
Expand Down Expand Up @@ -206,8 +213,10 @@ async def entry(request: Request) -> Response:
await empty(connection, connection.add(dn, modlist))
return JSONResponse({"changed": ["dn"]}) # Dummy

raise HTTPException(HTTPStatus.METHOD_NOT_ALLOWED)


@api.route("/blob/{attr}/{index:int}/{dn}", methods=("GET", "DELETE", "PUT"))
@api.route("/blob/{attr}/{index:int}/{dn}", methods=["GET", "DELETE", "PUT"])
async def blob(request: Request) -> Response:
"Handle binary attributes"

Expand Down Expand Up @@ -236,16 +245,16 @@ async def blob(request: Request) -> Response:
async with request.form() as form_data:
blob = form_data["blob"]
if type(blob) is UploadFile:
data = await blob.read(blob.size)
data = await blob.read(cast(int, blob.size))
if attr in attrs:
await empty(
connection,
connection.modify(
dn, [(1, attr, None), (0, attr, data + attrs[attr])]
dn, [(1, attr, None), (0, attr, attrs[attr] + [data])]
),
)
else:
await empty(connection, connection.modify(dn, [(0, attr, data)]))
await empty(connection, connection.modify(dn, [(0, attr, [data])]))
return NO_CONTENT

if request.method == "DELETE":
Expand All @@ -259,6 +268,8 @@ async def blob(request: Request) -> Response:
await empty(connection, connection.modify(dn, [(0, attr, data)]))
return NO_CONTENT

raise HTTPException(HTTPStatus.METHOD_NOT_ALLOWED)


@api.route("/ldif/{dn}")
async def ldifDump(request: Request) -> PlainTextResponse:
Expand All @@ -269,9 +280,7 @@ async def ldifDump(request: Request) -> PlainTextResponse:
writer = ldif.LDIFWriter(out)
connection = request.state.ldap

async for dn, attrs in result(
connection, connection.search(dn, ldap.SCOPE_SUBTREE)
):
async for dn, attrs in result(connection, connection.search(dn, SCOPE_SUBTREE)):
writer.unparse(dn, attrs)

file_name = dn.split(",")[0].split("=")[1]
Expand All @@ -282,7 +291,7 @@ async def ldifDump(request: Request) -> PlainTextResponse:


class LDIFReader(ldif.LDIFParser):
def __init__(self, input: str, con: LDAPObject):
def __init__(self, input: bytes, con: LDAPObject):
ldif.LDIFParser.__init__(self, io.BytesIO(input))
self.count = 0
self.con = con
Expand All @@ -292,7 +301,7 @@ def handle(self, dn: str, entry: dict[str, Any]):
self.count += 1


@api.route("/ldif", methods=("POST",))
@api.route("/ldif", methods=["POST"])
async def ldifUpload(
request: Request,
) -> Response:
Expand All @@ -309,8 +318,8 @@ async def ldifUpload(
Rdn = TypeAdapter(str)


@api.route("/rename/{dn}", methods=("POST",))
async def rename(request: Request) -> JSONResponse:
@api.route("/rename/{dn}", methods=["POST"])
async def rename(request: Request) -> Response:
"Rename an entry"

dn = request.path_params["dn"]
Expand All @@ -332,8 +341,8 @@ class CheckPasswordRequest(BaseModel):
PasswordRequest = TypeAdapter(Union[ChangePasswordRequest, CheckPasswordRequest])


@api.route("/entry/password/{dn}", methods=("POST",))
async def passwd(request: Request) -> JSONResponse:
@api.route("/entry/password/{dn}", methods=["POST"])
async def passwd(request: Request) -> Response:
"Update passwords"

dn = request.path_params["dn"]
Expand All @@ -344,10 +353,10 @@ async def passwd(request: Request) -> JSONResponse:
try:
con.simple_bind_s(dn, args.check)
return JSONResponse(True)
except ldap.INVALID_CREDENTIALS:
except INVALID_CREDENTIALS:
return JSONResponse(False)

else:
elif type(args) is ChangePasswordRequest:
connection = request.state.ldap
if args.new1:
await empty(
Expand All @@ -359,7 +368,9 @@ async def passwd(request: Request) -> JSONResponse:

else:
await empty(connection, connection.modify(dn, [(1, "userPassword", None)]))
return JSONResponse(None)
return NO_CONTENT

raise HTTPException(HTTPStatus.UNPROCESSABLE_ENTITY)


def _cn(entry: dict) -> Optional[str]:
Expand All @@ -386,7 +397,7 @@ async def search(request: Request) -> JSONResponse:
res = []
connection = request.state.ldap
async for dn, attrs in result(
connection, connection.search(settings.BASE_DN, ldap.SCOPE_SUBTREE, query)
connection, connection.search(settings.BASE_DN, SCOPE_SUBTREE, query)
):
res.append({"dn": dn, "name": _cn(attrs) or dn})
if len(res) >= settings.SEARCH_MAX:
Expand All @@ -405,7 +416,7 @@ async def subtree(request: Request) -> JSONResponse:

dn = request.path_params["dn"]
result, start = [], len(dn.split(","))
for node in sorted(await _tree(request, dn, ldap.SCOPE_SUBTREE), key=_dn_order):
for node in sorted(await _tree(request, dn, SCOPE_SUBTREE), key=_dn_order):
if node["dn"] == dn:
continue
node["level"] = len(node["dn"].split(",")) - start
Expand All @@ -428,7 +439,7 @@ async def attribute_range(request: Request) -> JSONResponse:
connection,
connection.search(
settings.BASE_DN,
ldap.SCOPE_SUBTREE,
SCOPE_SUBTREE,
f"({attribute}=*)",
attrlist=(attribute,),
),
Expand Down Expand Up @@ -462,7 +473,7 @@ async def json_schema(request: Request) -> JSONResponse:
connection,
connection.search(
settings.SCHEMA_DN,
ldap.SCOPE_BASE,
SCOPE_BASE,
attrlist=WITH_OPERATIONAL_ATTRS,
),
)
Expand Down
Loading

0 comments on commit d3a536a

Please sign in to comment.