diff --git a/apricot/__init__.py b/apricot/__init__.py index 0a47852..67a4c7d 100644 --- a/apricot/__init__.py +++ b/apricot/__init__.py @@ -3,7 +3,7 @@ from .patches import LDAPString # noqa: F401 __all__ = [ + "ApricotServer", "__version__", "__version_info__", - "ApricotServer", ] diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 8776335..577ff88 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import inspect import sys -from typing import Any, cast +from typing import Any, Self, cast from twisted.internet import reactor, task from twisted.internet.endpoints import quoteStringArgument, serverFromString @@ -13,8 +15,10 @@ class ApricotServer: + """The Apricot server running via Twisted.""" + def __init__( - self, + self: Self, backend: OAuthBackend, client_id: str, client_secret: str, @@ -32,6 +36,23 @@ def __init__( tls_private_key: str | None = None, **kwargs: Any, ) -> None: + """Initialise an ApricotServer. + + @param backend: An OAuth backend, + @param client_id: An OAuth client ID + @param client_secret: An OAuth client secret + @param domain: The OAuth domain + @param port: Port to expose LDAP on + @param background_refresh: Whether to refresh the LDAP tree in the background + @param debug: Enable debug output + @param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users + @param redis_host: Host for a Redis cache (if used) + @param redis_port: Port for a Redis cache (if used) + @param refresh_interval: Interval after which the LDAP information is stale + @param tls_port: Port to expose LDAPS on + @param tls_certificate: TLS certificate for LDAPS + @param tls_private_key: TLS private key for LDAPS + """ self.debug = debug # Log to stdout @@ -41,7 +62,7 @@ def __init__( uid_cache: UidCache if redis_host and redis_port: log.msg( - f"Using a Redis user-id cache at host '{redis_host}' on port '{redis_port}'." + f"Using a Redis user-id cache at host '{redis_host}' on port '{redis_port}'.", ) uid_cache = RedisCache(redis_host=redis_host, redis_port=redis_port) else: @@ -54,7 +75,7 @@ def __init__( log.msg(f"Creating an OAuthClient for {backend}.") oauth_backend = OAuthClientMap[backend] oauth_backend_args = inspect.getfullargspec( - oauth_backend.__init__ # type: ignore + oauth_backend.__init__, # type: ignore[misc] ).args oauth_client = oauth_backend( client_id=client_id, @@ -81,7 +102,7 @@ def __init__( if background_refresh: if self.debug: log.msg( - f"Starting background refresh (interval={factory.adaptor.refresh_interval})" + f"Starting background refresh (interval={factory.adaptor.refresh_interval})", ) loop = task.LoopingCall(factory.adaptor.refresh) loop.start(factory.adaptor.refresh_interval) @@ -111,8 +132,8 @@ def __init__( # Load the Twisted reactor self.reactor = cast(IReactorCore, reactor) - def run(self) -> None: - """Start the Twisted reactor""" + def run(self: Self) -> None: + """Start the Twisted reactor.""" if self.debug: log.msg("Starting the Twisted reactor.") self.reactor.run() diff --git a/apricot/cache/local_cache.py b/apricot/cache/local_cache.py index 958217b..65b4cc5 100644 --- a/apricot/cache/local_cache.py +++ b/apricot/cache/local_cache.py @@ -1,18 +1,25 @@ +from __future__ import annotations + +from typing import Self + from .uid_cache import UidCache class LocalCache(UidCache): - def __init__(self) -> None: + """Implementation of UidCache using an in-memory dictionary.""" + + def __init__(self: Self) -> None: + """Initialise a RedisCache.""" self.cache: dict[str, int] = {} - def get(self, identifier: str) -> int | None: + def get(self: Self, identifier: str) -> int | None: return self.cache.get(identifier, None) - def keys(self) -> list[str]: - return [str(k) for k in self.cache.keys()] + def keys(self: Self) -> list[str]: + return [str(k) for k in self.cache] - def set(self, identifier: str, uid_value: int) -> None: + def set(self: Self, identifier: str, uid_value: int) -> None: self.cache[identifier] = uid_value - def values(self, keys: list[str]) -> list[int]: + def values(self: Self, keys: list[str]) -> list[int]: return [v for k, v in self.cache.items() if k in keys] diff --git a/apricot/cache/redis_cache.py b/apricot/cache/redis_cache.py index 24ac506..47ccc97 100644 --- a/apricot/cache/redis_cache.py +++ b/apricot/cache/redis_cache.py @@ -1,4 +1,6 @@ -from typing import cast +from __future__ import annotations + +from typing import Self, cast import redis @@ -6,31 +8,38 @@ class RedisCache(UidCache): - def __init__(self, redis_host: str, redis_port: int) -> None: + """Implementation of UidCache using a Redis backend.""" + + def __init__(self: Self, redis_host: str, redis_port: int) -> None: + """Initialise a RedisCache. + + @param redis_host: Host for the Redis cache + @param redis_port: Port for the Redis cache + """ self.redis_host = redis_host self.redis_port = redis_port - self.cache_: "redis.Redis[str]" | None = None # noqa: UP037 + self.cache_: redis.Redis[str] | None = None @property - def cache(self) -> "redis.Redis[str]": - """ - Lazy-load the cache on request - """ + def cache(self: Self) -> redis.Redis[str]: + """Lazy-load the cache on request.""" if not self.cache_: self.cache_ = redis.Redis( - host=self.redis_host, port=self.redis_port, decode_responses=True + host=self.redis_host, + port=self.redis_port, + decode_responses=True, ) return self.cache_ - def get(self, identifier: str) -> int | None: + def get(self: Self, identifier: str) -> int | None: value = self.cache.get(identifier) return None if value is None else int(value) - def keys(self) -> list[str]: - return [str(k) for k in self.cache.keys()] + def keys(self: Self) -> list[str]: + return [str(k) for k in self.cache.keys()] # noqa: SIM118 - def set(self, identifier: str, uid_value: int) -> None: + def set(self: Self, identifier: str, uid_value: int) -> None: self.cache.set(identifier, uid_value) - def values(self, keys: list[str]) -> list[int]: + def values(self: Self, keys: list[str]) -> list[int]: return [int(cast(str, v)) for v in self.cache.mget(keys)] diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py index eb9c729..82621f3 100644 --- a/apricot/cache/uid_cache.py +++ b/apricot/cache/uid_cache.py @@ -1,57 +1,49 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import cast +from typing import Self, cast class UidCache(ABC): + """Abstract cache for storing UIDs.""" + @abstractmethod - def get(self, identifier: str) -> int | None: - """ - Get the UID for a given identifier, returning None if it does not exist - """ - pass + def get(self: Self, identifier: str) -> int | None: + """Get the UID for a given identifier, returning None if it does not exist.""" @abstractmethod - def keys(self) -> list[str]: - """ - Get list of cached keys - """ - pass + def keys(self: Self) -> list[str]: + """Get list of cached keys.""" @abstractmethod - def set(self, identifier: str, uid_value: int) -> None: - """ - Set the UID for a given identifier - """ - pass + def set(self: Self, identifier: str, uid_value: int) -> None: + """Set the UID for a given identifier.""" @abstractmethod - def values(self, keys: list[str]) -> list[int]: - """ - Get list of cached values corresponding to requested keys - """ - pass + def values(self: Self, keys: list[str]) -> list[int]: + """Get list of cached values corresponding to requested keys.""" - def get_group_uid(self, identifier: str) -> int: - """ - Get UID for a group, constructing one if necessary + def get_group_uid(self: Self, identifier: str) -> int: + """Get UID for a group, constructing one if necessary. @param identifier: Identifier for group needing a UID """ return self.get_uid(identifier, category="group", min_value=3000) - def get_user_uid(self, identifier: str) -> int: - """ - Get UID for a user, constructing one if necessary + def get_user_uid(self: Self, identifier: str) -> int: + """Get UID for a user, constructing one if necessary. @param identifier: Identifier for user needing a UID """ return self.get_uid(identifier, category="user", min_value=2000) def get_uid( - self, identifier: str, category: str, min_value: int | None = None + self: Self, + identifier: str, + category: str, + min_value: int | None = None, ) -> int: - """ - Get UID, constructing one if necessary. + """Get UID, constructing one if necessary. @param identifier: Identifier for object needing a UID @param category: Category the object belongs to @@ -60,14 +52,13 @@ def get_uid( identifier_ = f"{category}-{identifier}" uid = self.get(identifier_) if not uid: - min_value = min_value if min_value else 0 + min_value = min_value or 0 next_uid = max(self._get_max_uid(category) + 1, min_value) self.set(identifier_, next_uid) return cast(int, self.get(identifier_)) - def _get_max_uid(self, category: str | None) -> int: - """ - Get maximum UID for a given category + def _get_max_uid(self: Self, category: str | None) -> int: + """Get maximum UID for a given category. @param category: Category to check UIDs for """ @@ -78,27 +69,24 @@ def _get_max_uid(self, category: str | None) -> int: values = [*self.values(keys), -999] return max(values) - def overwrite_group_uid(self, identifier: str, uid: int) -> None: - """ - Set UID for a group, overwriting the existing value if there is one + def overwrite_group_uid(self: Self, identifier: str, uid: int) -> None: + """Set UID for a group, overwriting the existing value if there is one. @param identifier: Identifier for group @param uid: Desired UID """ return self.overwrite_uid(identifier, category="group", uid=uid) - def overwrite_user_uid(self, identifier: str, uid: int) -> None: - """ - Get UID for a user, constructing one if necessary + def overwrite_user_uid(self: Self, identifier: str, uid: int) -> None: + """Get UID for a user, constructing one if necessary. @param identifier: Identifier for user @param uid: Desired UID """ return self.overwrite_uid(identifier, category="user", uid=uid) - def overwrite_uid(self, identifier: str, category: str, uid: int) -> None: - """ - Set UID, overwriting the existing one if necessary. + def overwrite_uid(self: Self, identifier: str, category: str, uid: int) -> None: + """Set UID, overwriting the existing one if necessary. @param identifier: Identifier for object @param category: Category the object belongs to diff --git a/apricot/ldap/oauth_ldap_entry.py b/apricot/ldap/oauth_ldap_entry.py index 6845a33..0058bf7 100644 --- a/apricot/ldap/oauth_ldap_entry.py +++ b/apricot/ldap/oauth_ldap_entry.py @@ -1,4 +1,6 @@ -from typing import cast +from __future__ import annotations + +from typing import Self, cast from ldaptor.inmemory import ReadOnlyInMemoryLDAPEntry from ldaptor.protocols.ldap.distinguishedname import ( @@ -16,17 +18,18 @@ class OAuthLDAPEntry(ReadOnlyInMemoryLDAPEntry): + """An LDAP entry that represents a view of an OAuth object.""" + dn: DistinguishedName attributes: LDAPAttributeDict def __init__( - self, + self: Self, dn: DistinguishedName | str, attributes: LDAPAttributeDict, oauth_client: OAuthClient | None = None, ) -> None: - """ - Initialize the object. + """Initialize the object. @param dn: Distinguished Name of the object @param attributes: Attributes of the object. @@ -37,7 +40,7 @@ def __init__( dn = DistinguishedName(stringValue=dn) super().__init__(dn, attributes) - def __str__(self) -> str: + def __str__(self: Self) -> str: output = bytes(self.toWire()).decode("utf-8") for child in self._children.values(): try: @@ -52,18 +55,19 @@ def __str__(self) -> str: return output @property - def oauth_client(self) -> OAuthClient: - if not self.oauth_client_: - if hasattr(self._parent, "oauth_client"): - self.oauth_client_ = self._parent.oauth_client + def oauth_client(self: Self) -> OAuthClient: + if not self.oauth_client_ and hasattr(self._parent, "oauth_client"): + self.oauth_client_ = self._parent.oauth_client if not isinstance(self.oauth_client_, OAuthClient): msg = f"OAuthClient is of incorrect type {type(self.oauth_client_)}" raise TypeError(msg) return self.oauth_client_ def add_child( - self, rdn: RelativeDistinguishedName | str, attributes: LDAPAttributeDict - ) -> "OAuthLDAPEntry": + self: Self, + rdn: RelativeDistinguishedName | str, + attributes: LDAPAttributeDict, + ) -> OAuthLDAPEntry: if isinstance(rdn, str): rdn = RelativeDistinguishedName(stringValue=rdn) try: @@ -73,8 +77,8 @@ def add_child( output = self._children[rdn.getText()] return cast(OAuthLDAPEntry, output) - def bind(self, password: bytes) -> defer.Deferred["OAuthLDAPEntry"]: - def _bind(password: bytes) -> "OAuthLDAPEntry": + def bind(self: Self, password: bytes) -> defer.Deferred[OAuthLDAPEntry]: + def _bind(password: bytes) -> OAuthLDAPEntry: oauth_username = next(iter(self.get("oauth_username", "unknown"))) s_password = password.decode("utf-8") if self.oauth_client.verify(username=oauth_username, password=s_password): @@ -84,5 +88,5 @@ def _bind(password: bytes) -> "OAuthLDAPEntry": return defer.maybeDeferred(_bind, password) - def list_children(self) -> "list[OAuthLDAPEntry]": + def list_children(self: Self) -> list[OAuthLDAPEntry]: return [cast(OAuthLDAPEntry, entry) for entry in self._children.values()] diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index 303d9e4..0744c22 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -1,3 +1,5 @@ +from typing import Self + from twisted.internet.interfaces import IAddress from twisted.internet.protocol import Protocol, ServerFactory @@ -8,17 +10,18 @@ class OAuthLDAPServerFactory(ServerFactory): + """A Twisted ServerFactory that provides an LDAP tree.""" + def __init__( - self, + self: Self, domain: str, oauth_client: OAuthClient, *, background_refresh: bool, enable_mirrored_groups: bool, refresh_interval: int, - ): - """ - Initialise an OAuthLDAPServerFactory + ) -> None: + """Initialise an OAuthLDAPServerFactory. @param background_refresh: Whether to refresh the LDAP tree in the background rather than on access @param domain: The root domain of the LDAP tree @@ -35,12 +38,11 @@ def __init__( refresh_interval=refresh_interval, ) - def __repr__(self) -> str: + def __repr__(self: Self) -> str: return f"{self.__class__.__name__} using adaptor {self.adaptor}" - def buildProtocol(self, addr: IAddress) -> Protocol: # noqa: N802 - """ - Create an LDAPServer instance. + def buildProtocol(self: Self, addr: IAddress) -> Protocol: # noqa: N802 + """Create an LDAPServer instance. This instance will use self.adaptor to produce LDAP entries. diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 66e649f..f27cb53 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -1,20 +1,26 @@ +from __future__ import annotations + import time +from typing import TYPE_CHECKING, Self from ldaptor.interfaces import IConnectedLDAPEntry, ILDAPEntry from ldaptor.protocols.ldap.distinguishedname import DistinguishedName -from twisted.internet import defer from twisted.python import log from zope.interface import implementer from apricot.ldap.oauth_ldap_entry import OAuthLDAPEntry from apricot.oauth import OAuthClient, OAuthDataAdaptor +if TYPE_CHECKING: + from twisted.internet import defer + @implementer(IConnectedLDAPEntry) class OAuthLDAPTree: + """An LDAP tree that represents a view of an OAuth directory.""" def __init__( - self, + self: Self, domain: str, oauth_client: OAuthClient, *, @@ -22,8 +28,7 @@ def __init__( enable_mirrored_groups: bool, refresh_interval: int, ) -> None: - """ - Initialise an OAuthLDAPTree + """Initialise an OAuthLDAPTree. @param background_refresh: Whether to refresh the LDAP tree in the background rather than on access @param domain: The root domain of the LDAP tree @@ -41,13 +46,12 @@ def __init__( self.root_: OAuthLDAPEntry | None = None @property - def dn(self) -> DistinguishedName: + def dn(self: Self) -> DistinguishedName: return self.root.dn @property - def root(self) -> OAuthLDAPEntry: - """ - Lazy-load the LDAP tree on request + def root(self: Self) -> OAuthLDAPEntry: + """Lazy-load the LDAP tree on request. @return: An OAuthLDAPEntry for the tree @@ -60,7 +64,8 @@ def root(self) -> OAuthLDAPEntry: raise ValueError(msg) return self.root_ - def refresh(self) -> None: + def refresh(self: Self) -> None: + """Refresh the LDAP tree.""" if ( not self.root_ or (time.monotonic() - self.last_update) > self.refresh_interval @@ -83,16 +88,18 @@ def refresh(self) -> None: # Add OUs for users and groups groups_ou = self.root_.add_child( - "OU=groups", {"ou": ["groups"], "objectClass": ["organizationalUnit"]} + "OU=groups", + {"ou": ["groups"], "objectClass": ["organizationalUnit"]}, ) users_ou = self.root_.add_child( - "OU=users", {"ou": ["users"], "objectClass": ["organizationalUnit"]} + "OU=users", + {"ou": ["users"], "objectClass": ["organizationalUnit"]}, ) # Add groups to the groups OU if self.debug: log.msg( - f"Attempting to add {len(oauth_adaptor.groups)} groups to the LDAP tree." + f"Attempting to add {len(oauth_adaptor.groups)} groups to the LDAP tree.", ) for group_attrs in oauth_adaptor.groups: groups_ou.add_child(f"CN={group_attrs.cn}", group_attrs.to_dict()) @@ -105,7 +112,7 @@ def refresh(self) -> None: # Add users to the users OU if self.debug: log.msg( - f"Attempting to add {len(oauth_adaptor.users)} users to the LDAP tree." + f"Attempting to add {len(oauth_adaptor.users)} users to the LDAP tree.", ) for user_attrs in oauth_adaptor.users: users_ou.add_child(f"CN={user_attrs.cn}", user_attrs.to_dict()) @@ -119,12 +126,11 @@ def refresh(self) -> None: log.msg("Finished building LDAP tree.") self.last_update = time.monotonic() - def __repr__(self) -> str: + def __repr__(self: Self) -> str: return f"{self.__class__.__name__} with backend {self.oauth_client.__class__.__name__}" - def lookup(self, dn: DistinguishedName | str) -> defer.Deferred[ILDAPEntry]: - """ - Lookup the referred to by dn. + def lookup(self: Self, dn: DistinguishedName | str) -> defer.Deferred[ILDAPEntry]: + """Lookup the referred to by dn. @return: A Deferred returning an ILDAPEntry. diff --git a/apricot/ldap/read_only_ldap_server.py b/apricot/ldap/read_only_ldap_server.py index b23ba88..48b965c 100644 --- a/apricot/ldap/read_only_ldap_server.py +++ b/apricot/ldap/read_only_ldap_server.py @@ -1,41 +1,49 @@ -from typing import Callable +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Self -from ldaptor.interfaces import ILDAPEntry from ldaptor.protocols.ldap.ldaperrors import LDAPProtocolError from ldaptor.protocols.ldap.ldapserver import LDAPServer -from ldaptor.protocols.pureldap import ( - LDAPAddRequest, - LDAPBindRequest, - LDAPCompareRequest, - LDAPDelRequest, - LDAPExtendedRequest, - LDAPModifyDNRequest, - LDAPModifyRequest, - LDAPProtocolRequest, - LDAPSearchRequest, - LDAPSearchResultDone, - LDAPSearchResultEntry, - LDAPUnbindRequest, -) -from twisted.internet import defer from twisted.python import log -from apricot.oauth import LDAPControlTuple +if TYPE_CHECKING: + from ldaptor.interfaces import ILDAPEntry + from ldaptor.protocols.pureldap import ( + LDAPAddRequest, + LDAPBindRequest, + LDAPCompareRequest, + LDAPDelRequest, + LDAPExtendedRequest, + LDAPModifyDNRequest, + LDAPModifyRequest, + LDAPProtocolRequest, + LDAPSearchRequest, + LDAPSearchResultDone, + LDAPSearchResultEntry, + LDAPUnbindRequest, + ) + from twisted.internet import defer + + from apricot.oauth import LDAPControlTuple class ReadOnlyLDAPServer(LDAPServer): - def __init__(self, *, debug: bool = False) -> None: + """A read-only LDAP server.""" + + def __init__(self: Self, *, debug: bool = False) -> None: + """Initialise a ReadOnlyLDAPServer. + + @param debug: Enable debug output + """ super().__init__() self.debug = debug def getRootDSE( # noqa: N802 - self, + self: Self, request: LDAPProtocolRequest, reply: Callable[[LDAPSearchResultEntry], None] | None, ) -> LDAPSearchResultDone: - """ - Handle an LDAP Root DSE request - """ + """Handle an LDAP Root DSE request.""" if self.debug: log.msg("Handling an LDAP Root DSE request.") try: @@ -45,14 +53,12 @@ def getRootDSE( # noqa: N802 raise LDAPProtocolError(msg) from exc def handle_LDAPAddRequest( # noqa: N802 - self, + self: Self, request: LDAPAddRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Refuse to handle an LDAP add request - """ + """Refuse to handle an LDAP add request.""" if self.debug: log.msg("Handling an LDAP add request.") id((request, controls, reply)) # ignore unused arguments @@ -60,14 +66,12 @@ def handle_LDAPAddRequest( # noqa: N802 raise LDAPProtocolError(msg) def handle_LDAPBindRequest( # noqa: N802 - self, + self: Self, request: LDAPBindRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Handle an LDAP bind request - """ + """Handle an LDAP bind request.""" if self.debug: log.msg("Handling an LDAP bind request.") try: @@ -77,14 +81,12 @@ def handle_LDAPBindRequest( # noqa: N802 raise LDAPProtocolError(msg) from exc def handle_LDAPCompareRequest( # noqa: N802 - self, + self: Self, request: LDAPCompareRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Handle an LDAP compare request - """ + """Handle an LDAP compare request.""" if self.debug: log.msg("Handling an LDAP compare request.") try: @@ -94,14 +96,12 @@ def handle_LDAPCompareRequest( # noqa: N802 raise LDAPProtocolError(msg) from exc def handle_LDAPDelRequest( # noqa: N802 - self, + self: Self, request: LDAPDelRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Refuse to handle an LDAP delete request - """ + """Refuse to handle an LDAP delete request.""" if self.debug: log.msg("Handling an LDAP delete request.") id((request, controls, reply)) # ignore unused arguments @@ -109,14 +109,12 @@ def handle_LDAPDelRequest( # noqa: N802 raise LDAPProtocolError(msg) def handle_LDAPExtendedRequest( # noqa: N802 - self, + self: Self, request: LDAPExtendedRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Handle an LDAP extended request - """ + """Handle an LDAP extended request.""" if self.debug: log.msg("Handling an LDAP extended request.") try: @@ -126,14 +124,12 @@ def handle_LDAPExtendedRequest( # noqa: N802 raise LDAPProtocolError(msg) from exc def handle_LDAPModifyDNRequest( # noqa: N802 - self, + self: Self, request: LDAPModifyDNRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Refuse to handle an LDAP modify DN request - """ + """Refuse to handle an LDAP modify DN request.""" if self.debug: log.msg("Handling an LDAP modify DN request.") id((request, controls, reply)) # ignore unused arguments @@ -141,14 +137,12 @@ def handle_LDAPModifyDNRequest( # noqa: N802 raise LDAPProtocolError(msg) def handle_LDAPModifyRequest( # noqa: N802 - self, + self: Self, request: LDAPModifyRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Refuse to handle an LDAP modify request - """ + """Refuse to handle an LDAP modify request.""" if self.debug: log.msg("Handling an LDAP modify request.") id((request, controls, reply)) # ignore unused arguments @@ -156,14 +150,12 @@ def handle_LDAPModifyRequest( # noqa: N802 raise LDAPProtocolError(msg) def handle_LDAPSearchRequest( # noqa: N802 - self, + self: Self, request: LDAPSearchRequest, controls: list[LDAPControlTuple] | None, reply: Callable[[LDAPSearchResultEntry], None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Handle an LDAP search request - """ + """Handle an LDAP search request.""" if self.debug: log.msg("Handling an LDAP search request.") try: @@ -173,14 +165,12 @@ def handle_LDAPSearchRequest( # noqa: N802 raise LDAPProtocolError(msg) from exc def handle_LDAPUnbindRequest( # noqa: N802 - self, + self: Self, request: LDAPUnbindRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> None: - """ - Handle an LDAP unbind request - """ + """Handle an LDAP unbind request.""" if self.debug: log.msg("Handling an LDAP unbind request.") try: diff --git a/apricot/models/__init__.py b/apricot/models/__init__.py index 615f99b..fd84e3c 100644 --- a/apricot/models/__init__.py +++ b/apricot/models/__init__.py @@ -1,9 +1,9 @@ from .ldap_attribute_adaptor import LDAPAttributeAdaptor from .ldap_group_of_names import LDAPGroupOfNames from .ldap_inetorgperson import LDAPInetOrgPerson +from .ldap_object_class import LDAPObjectClass from .ldap_posix_account import LDAPPosixAccount from .ldap_posix_group import LDAPPosixGroup -from .named_ldap_class import NamedLDAPClass from .overlay_memberof import OverlayMemberOf from .overlay_oauthentry import OverlayOAuthEntry @@ -11,9 +11,9 @@ "LDAPAttributeAdaptor", "LDAPGroupOfNames", "LDAPInetOrgPerson", + "LDAPObjectClass", "LDAPPosixAccount", "LDAPPosixGroup", - "NamedLDAPClass", "OverlayMemberOf", "OverlayOAuthEntry", ] diff --git a/apricot/models/ldap_attribute_adaptor.py b/apricot/models/ldap_attribute_adaptor.py index 40f986d..a559f9d 100644 --- a/apricot/models/ldap_attribute_adaptor.py +++ b/apricot/models/ldap_attribute_adaptor.py @@ -1,10 +1,20 @@ -from typing import Any +from __future__ import annotations -from apricot.types import LDAPAttributeDict +from typing import TYPE_CHECKING, Any, Self, Sequence + +if TYPE_CHECKING: + from apricot.models import LDAPObjectClass + from apricot.types import JSONDict, LDAPAttributeDict class LDAPAttributeAdaptor: - def __init__(self, attributes: dict[Any, Any]) -> None: + """A class to convert attributes into LDAP format.""" + + def __init__(self: Self, attributes: dict[Any, Any]) -> None: + """Initialise an LDAPAttributeAdaptor. + + @param attributes: A dictionary of attributes to be converted into str: list[str] + """ self.attributes = { str(k): list(map(str, v)) if isinstance(v, list) else [str(v)] for k, v in attributes.items() @@ -12,10 +22,25 @@ def __init__(self, attributes: dict[Any, Any]) -> None: } @property - def cn(self) -> str: - """Return CN for this set of LDAP attributes""" + def cn(self: Self) -> str: + """Return CN for this set of LDAP attributes.""" return self.attributes["cn"][0] - def to_dict(self) -> LDAPAttributeDict: - """Convert the attributes to an LDAPAttributeDict""" + @classmethod + def from_attributes( + cls: type[Self], + input_dict: JSONDict, + *, + required_classes: Sequence[type[LDAPObjectClass]], + ) -> LDAPAttributeAdaptor: + """Construct an LDAPAttributeAdaptor from attributes and required classes.""" + attributes = {"objectclass": ["top"]} + for ldap_class in required_classes: + model = ldap_class(**input_dict) + attributes.update(model.model_dump()) + attributes["objectclass"] += model.names() + return cls(attributes) + + def to_dict(self: Self) -> LDAPAttributeDict: + """Convert the attributes to an LDAPAttributeDict.""" return self.attributes diff --git a/apricot/models/ldap_group_of_names.py b/apricot/models/ldap_group_of_names.py index b1e077b..b5e35db 100644 --- a/apricot/models/ldap_group_of_names.py +++ b/apricot/models/ldap_group_of_names.py @@ -1,9 +1,10 @@ -from .named_ldap_class import NamedLDAPClass +from __future__ import annotations +from .ldap_object_class import LDAPObjectClass -class LDAPGroupOfNames(NamedLDAPClass): - """ - A group with named members + +class LDAPGroupOfNames(LDAPObjectClass): + """A group with named members. OID: 2.5.6.9 Object class: Structural @@ -11,9 +12,8 @@ class LDAPGroupOfNames(NamedLDAPClass): Schema: rfc4519 """ + _ldap_object_class_name: str = "groupOfNames" + cn: str description: str member: list[str] - - def names(self) -> list[str]: - return ["groupOfNames"] diff --git a/apricot/models/ldap_inetorgperson.py b/apricot/models/ldap_inetorgperson.py index 51e5cb5..e66cc5d 100644 --- a/apricot/models/ldap_inetorgperson.py +++ b/apricot/models/ldap_inetorgperson.py @@ -1,9 +1,10 @@ +from __future__ import annotations + from .ldap_organizational_person import LDAPOrganizationalPerson class LDAPInetOrgPerson(LDAPOrganizationalPerson): - """ - A person belonging to an internet/intranet directory service + """A person belonging to an internet/intranet directory service. OID: 2.16.840.1.113730.3.2.2 Object class: Structural @@ -11,6 +12,8 @@ class LDAPInetOrgPerson(LDAPOrganizationalPerson): Schema: rfc2798 """ + _ldap_object_class_name: str = "inetOrgPerson" + cn: str displayName: str | None = None # noqa: N815 employeeNumber: str | None = None # noqa: N815 @@ -18,6 +21,3 @@ class LDAPInetOrgPerson(LDAPOrganizationalPerson): sn: str mail: str | None = None telephoneNumber: str | None = None # noqa: N815 - - def names(self) -> list[str]: - return [*super().names(), "inetOrgPerson"] diff --git a/apricot/models/ldap_object_class.py b/apricot/models/ldap_object_class.py new file mode 100644 index 0000000..6b9eeed --- /dev/null +++ b/apricot/models/ldap_object_class.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import Self + +from pydantic import BaseModel + + +class LDAPObjectClass(BaseModel): + """An LDAP object-class that may have a name.""" + + @classmethod + def names(cls: type[Self]) -> list[str]: + """List of object-class names for this LDAP object-class. + + We iterate through the parent classes in MRO order, getting an + `_ldap_object_class_name` from each class that has one. We then sort these + before returning a list of names. + """ + return sorted( + [ + cls_._ldap_object_class_name.default + for cls_ in cls.__mro__ + if hasattr(cls_, "_ldap_object_class_name") + ], + ) diff --git a/apricot/models/ldap_organizational_person.py b/apricot/models/ldap_organizational_person.py index 064ba5a..1d05382 100644 --- a/apricot/models/ldap_organizational_person.py +++ b/apricot/models/ldap_organizational_person.py @@ -1,9 +1,10 @@ +from __future__ import annotations + from .ldap_person import LDAPPerson class LDAPOrganizationalPerson(LDAPPerson): - """ - A person belonging to an organisation + """A person belonging to an organisation. OID: 2.5.6.7 Object class: Structural @@ -11,7 +12,6 @@ class LDAPOrganizationalPerson(LDAPPerson): Schema: rfc4519 """ - description: str + _ldap_object_class_name: str = "organizationalPerson" - def names(self) -> list[str]: - return [*super().names(), "organizationalPerson"] + description: str diff --git a/apricot/models/ldap_person.py b/apricot/models/ldap_person.py index 0656897..4affe81 100644 --- a/apricot/models/ldap_person.py +++ b/apricot/models/ldap_person.py @@ -1,9 +1,10 @@ -from .named_ldap_class import NamedLDAPClass +from __future__ import annotations +from .ldap_object_class import LDAPObjectClass -class LDAPPerson(NamedLDAPClass): - """ - A named person + +class LDAPPerson(LDAPObjectClass): + """A named person. OID: 2.5.6.6 Object class: Structural @@ -11,8 +12,7 @@ class LDAPPerson(NamedLDAPClass): Schema: rfc4519 """ + _ldap_object_class_name: str = "person" + cn: str sn: str - - def names(self) -> list[str]: - return ["person"] diff --git a/apricot/models/ldap_posix_account.py b/apricot/models/ldap_posix_account.py index 5bdd738..0062468 100644 --- a/apricot/models/ldap_posix_account.py +++ b/apricot/models/ldap_posix_account.py @@ -1,17 +1,19 @@ +from __future__ import annotations + import re +from typing import Self from pydantic import StringConstraints, validator from typing_extensions import Annotated -from .named_ldap_class import NamedLDAPClass +from .ldap_object_class import LDAPObjectClass ID_MIN = 2000 ID_MAX = 60000 -class LDAPPosixAccount(NamedLDAPClass): - """ - Abstraction of an account with POSIX attributes +class LDAPPosixAccount(LDAPObjectClass): + """Abstraction of an account with POSIX attributes. OID: 1.3.6.1.1.1.2.0 Object class: Auxiliary @@ -19,18 +21,21 @@ class LDAPPosixAccount(NamedLDAPClass): Schema: rfc2307bis """ + _ldap_object_class_name: str = "posixAccount" + cn: str gidNumber: int # noqa: N815 homeDirectory: Annotated[ # noqa: N815 - str, StringConstraints(strip_whitespace=True, to_lower=True) + str, + StringConstraints(strip_whitespace=True, to_lower=True), ] uid: str uidNumber: int # noqa: N815 @validator("gidNumber") # type: ignore[misc] @classmethod - def validate_gid_number(cls, gid_number: int) -> int: - """Avoid conflicts with existing users""" + def validate_gid_number(cls: type[Self], gid_number: int) -> int: + """Avoid conflicts with existing users.""" if not ID_MIN <= gid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." raise ValueError(msg) @@ -38,17 +43,14 @@ def validate_gid_number(cls, gid_number: int) -> int: @validator("homeDirectory") # type: ignore[misc] @classmethod - def validate_home_directory(cls, home_directory: str) -> str: + def validate_home_directory(cls: type[Self], home_directory: str) -> str: return re.sub(r"\s+", "-", home_directory) @validator("uidNumber") # type: ignore[misc] @classmethod - def validate_uid_number(cls, uid_number: int) -> int: - """Avoid conflicts with existing users""" + def validate_uid_number(cls: type[Self], uid_number: int) -> int: + """Avoid conflicts with existing users.""" if not ID_MIN <= uid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." raise ValueError(msg) return uid_number - - def names(self) -> list[str]: - return ["posixAccount"] diff --git a/apricot/models/ldap_posix_group.py b/apricot/models/ldap_posix_group.py index e926b49..7d3db08 100644 --- a/apricot/models/ldap_posix_group.py +++ b/apricot/models/ldap_posix_group.py @@ -1,14 +1,17 @@ +from __future__ import annotations + +from typing import Self + from pydantic import validator -from .named_ldap_class import NamedLDAPClass +from .ldap_object_class import LDAPObjectClass ID_MIN = 2000 ID_MAX = 4294967295 -class LDAPPosixGroup(NamedLDAPClass): - """ - Abstraction of a group of accounts +class LDAPPosixGroup(LDAPObjectClass): + """Abstraction of a group of accounts. OID: 1.3.6.1.1.1.2.2 Object class: Auxiliary @@ -16,18 +19,17 @@ class LDAPPosixGroup(NamedLDAPClass): Schema: rfc2307bis """ + _ldap_object_class_name: str = "posixGroup" + description: str gidNumber: int # noqa: N815 memberUid: list[str] # noqa: N815 @validator("gidNumber") # type: ignore[misc] @classmethod - def validate_gid_number(cls, gid_number: int) -> int: - """Avoid conflicts with existing groups""" + def validate_gid_number(cls: type[Self], gid_number: int) -> int: + """Avoid conflicts with existing groups.""" if not ID_MIN <= gid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." raise ValueError(msg) return gid_number - - def names(self) -> list[str]: - return ["posixGroup"] diff --git a/apricot/models/named_ldap_class.py b/apricot/models/named_ldap_class.py deleted file mode 100644 index 329e771..0000000 --- a/apricot/models/named_ldap_class.py +++ /dev/null @@ -1,7 +0,0 @@ -from pydantic import BaseModel - - -class NamedLDAPClass(BaseModel): - def names(self) -> list[str]: - """List of names for this LDAP object class""" - return [] diff --git a/apricot/models/overlay_memberof.py b/apricot/models/overlay_memberof.py index 3e78f71..8731a06 100644 --- a/apricot/models/overlay_memberof.py +++ b/apricot/models/overlay_memberof.py @@ -1,9 +1,10 @@ -from .named_ldap_class import NamedLDAPClass +from __future__ import annotations +from .ldap_object_class import LDAPObjectClass -class OverlayMemberOf(NamedLDAPClass): - """ - Abstraction for tracking the groups that an individual belongs to + +class OverlayMemberOf(LDAPObjectClass): + """Abstraction for tracking the groups that an individual belongs to. OID: n/a Object class: Auxiliary diff --git a/apricot/models/overlay_oauthentry.py b/apricot/models/overlay_oauthentry.py index 3eabc37..fa24e9d 100644 --- a/apricot/models/overlay_oauthentry.py +++ b/apricot/models/overlay_oauthentry.py @@ -1,9 +1,10 @@ -from .named_ldap_class import NamedLDAPClass +from __future__ import annotations +from .ldap_object_class import LDAPObjectClass -class OverlayOAuthEntry(NamedLDAPClass): - """ - Abstraction for tracking an OAuth entry + +class OverlayOAuthEntry(LDAPObjectClass): + """Abstraction for tracking an OAuth entry. OID: n/a Object class: Auxiliary diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index 5b584c7..d55fad4 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -1,4 +1,7 @@ -from typing import Any, cast +from __future__ import annotations + +import operator +from typing import Any, Self, cast from apricot.types import JSONDict @@ -11,11 +14,16 @@ class KeycloakClient(OAuthClient): max_rows = 100 def __init__( - self, + self: Self, keycloak_base_url: str, keycloak_realm: str, **kwargs: Any, - ): + ) -> None: + """Initialise a KeycloakClient. + + @param keycloak_base_url: Base URL for Keycloak server + @param keycloak_realm: Realm for Keycloak server + """ self.base_url = keycloak_base_url self.realm = keycloak_realm @@ -30,10 +38,11 @@ def __init__( **kwargs, ) - def extract_token(self, json_response: JSONDict) -> str: + @staticmethod + def extract_token(json_response: JSONDict) -> str: return str(json_response["access_token"]) - def groups(self) -> list[JSONDict]: + def groups(self: Self) -> list[JSONDict]: output = [] try: group_data: list[JSONDict] = [] @@ -54,17 +63,18 @@ def groups(self) -> list[JSONDict]: # This ensures that any groups without a `gid` attribute will receive a # UID that does not overlap with existing groups if (group_gid := group_dict["attributes"]["gid"]) and len( - group_dict["attributes"]["gid"] + group_dict["attributes"]["gid"], ) == 1: self.uid_cache.overwrite_group_uid( - group_dict["id"], int(group_gid[0], 10) + group_dict["id"], + int(group_gid[0], 10), ) # Read group attributes for group_dict in group_data: if not group_dict["attributes"]["gid"]: group_dict["attributes"]["gid"] = [ - str(self.uid_cache.get_group_uid(group_dict["id"])) + str(self.uid_cache.get_group_uid(group_dict["id"])), ] self.request( f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}", @@ -89,7 +99,7 @@ def groups(self) -> list[JSONDict]: pass return output - def users(self) -> list[JSONDict]: + def users(self: Self) -> list[JSONDict]: output = [] try: user_data: list[JSONDict] = [] @@ -110,19 +120,21 @@ def users(self) -> list[JSONDict]: # This ensures that any groups without a `gid` attribute will receive a # UID that does not overlap with existing groups if (user_uid := user_dict["attributes"]["uid"]) and len( - user_dict["attributes"]["uid"] + user_dict["attributes"]["uid"], ) == 1: self.uid_cache.overwrite_user_uid( - user_dict["id"], int(user_uid[0], 10) + user_dict["id"], + int(user_uid[0], 10), ) # Read user attributes for user_dict in sorted( - user_data, key=lambda user: user["createdTimestamp"] + user_data, + key=operator.itemgetter("createdTimestamp"), ): if not user_dict["attributes"]["uid"]: user_dict["attributes"]["uid"] = [ - str(self.uid_cache.get_user_uid(user_dict["id"])) + str(self.uid_cache.get_user_uid(user_dict["id"])), ] self.request( f"{self.base_url}/admin/realms/{self.realm}/users/{user_dict['id']}", @@ -144,10 +156,10 @@ def users(self) -> list[JSONDict]: attributes["mail"] = user_dict.get("email") attributes["description"] = "" attributes["gidNumber"] = user_dict["attributes"]["uid"][0] - attributes["givenName"] = first_name if first_name else "" + attributes["givenName"] = first_name or "" attributes["homeDirectory"] = f"/home/{username}" if username else None attributes["oauth_id"] = user_dict.get("id", None) - attributes["sn"] = last_name if last_name else "" + attributes["sn"] = last_name or "" attributes["uidNumber"] = user_dict["attributes"]["uid"][0] output.append(attributes) except KeyError: diff --git a/apricot/oauth/microsoft_entra_client.py b/apricot/oauth/microsoft_entra_client.py index eecfa41..681abb0 100644 --- a/apricot/oauth/microsoft_entra_client.py +++ b/apricot/oauth/microsoft_entra_client.py @@ -1,4 +1,7 @@ -from typing import Any, cast +from __future__ import annotations + +import operator +from typing import Any, Self, cast from twisted.python import log @@ -11,10 +14,14 @@ class MicrosoftEntraClient(OAuthClient): """OAuth client for the Microsoft Entra backend.""" def __init__( - self, + self: Self, entra_tenant_id: str, **kwargs: Any, - ): + ) -> None: + """Initialise a MicrosoftEntraClient. + + @param entra_tenant_id: Tenant ID for the Entra ID + """ redirect_uri = "urn:ietf:wg:oauth:2.0:oob" # this is the "no redirect" URL scopes = ["https://graph.microsoft.com/.default"] # this is the default scope token_url = ( @@ -22,13 +29,17 @@ def __init__( ) self.tenant_id = entra_tenant_id super().__init__( - redirect_uri=redirect_uri, scopes=scopes, token_url=token_url, **kwargs + redirect_uri=redirect_uri, + scopes=scopes, + token_url=token_url, + **kwargs, ) - def extract_token(self, json_response: JSONDict) -> str: + @staticmethod + def extract_token(json_response: JSONDict) -> str: return str(json_response["access_token"]) - def groups(self) -> list[JSONDict]: + def groups(self: Self) -> list[JSONDict]: output = [] queries = [ "createdDateTime", @@ -36,11 +47,11 @@ def groups(self) -> list[JSONDict]: "id", ] group_data = self.query( - f"https://graph.microsoft.com/v1.0/groups?$select={','.join(queries)}" + f"https://graph.microsoft.com/v1.0/groups?$select={','.join(queries)}", ) for group_dict in cast( list[JSONDict], - sorted(group_data["value"], key=lambda group: group["createdDateTime"]), + sorted(group_data["value"], key=operator.itemgetter("createdDateTime")), ): try: group_uid = self.uid_cache.get_group_uid(group_dict["id"]) @@ -51,7 +62,7 @@ def groups(self) -> list[JSONDict]: attributes["oauth_id"] = group_dict.get("id", None) # Add membership attributes members = self.query( - f"https://graph.microsoft.com/v1.0/groups/{group_dict['id']}/members" + f"https://graph.microsoft.com/v1.0/groups/{group_dict['id']}/members", ) attributes["memberUid"] = [ str(user["userPrincipalName"]).split("@")[0] @@ -66,7 +77,7 @@ def groups(self) -> list[JSONDict]: log.msg(msg) return output - def users(self) -> list[JSONDict]: + def users(self: Self) -> list[JSONDict]: output = [] try: queries = [ @@ -78,11 +89,11 @@ def users(self) -> list[JSONDict]: "userPrincipalName", ] user_data = self.query( - f"https://graph.microsoft.com/v1.0/users?$select={','.join(queries)}" + f"https://graph.microsoft.com/v1.0/users?$select={','.join(queries)}", ) for user_dict in cast( list[JSONDict], - sorted(user_data["value"], key=lambda user: user["createdDateTime"]), + sorted(user_data["value"], key=operator.itemgetter("createdDateTime")), ): # Get user attributes given_name = user_dict.get("givenName", None) @@ -90,17 +101,17 @@ def users(self) -> list[JSONDict]: uid, domain = str(user_dict.get("userPrincipalName", "@")).split("@") user_uid = self.uid_cache.get_user_uid(user_dict["id"]) attributes: JSONDict = {} - attributes["cn"] = uid if uid else None + attributes["cn"] = uid or None attributes["description"] = user_dict.get("displayName", None) attributes["displayName"] = user_dict.get("displayName", None) attributes["domain"] = domain attributes["gidNumber"] = user_uid - attributes["givenName"] = given_name if given_name else "" + attributes["givenName"] = given_name or "" attributes["homeDirectory"] = f"/home/{uid}" if uid else None attributes["oauth_id"] = user_dict.get("id", None) attributes["oauth_username"] = user_dict.get("userPrincipalName", None) - attributes["sn"] = surname if surname else "" - attributes["uid"] = uid if uid else None + attributes["sn"] = surname or "" + attributes["uid"] = uid or None attributes["uidNumber"] = user_uid output.append(attributes) except KeyError: diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index b47f98c..49797b3 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import os from abc import ABC, abstractmethod from http import HTTPStatus -from typing import Any +from typing import TYPE_CHECKING, Any, Self import requests from oauthlib.oauth2 import ( @@ -13,15 +15,16 @@ from requests_oauthlib import OAuth2Session from twisted.python import log -from apricot.cache import UidCache -from apricot.types import JSONDict +if TYPE_CHECKING: + from apricot.cache import UidCache + from apricot.types import JSONDict class OAuthClient(ABC): """Base class for OAuth client talking to a generic backend.""" def __init__( - self, + self: Self, client_id: str, client_secret: str, debug: bool, # noqa: FBT001 @@ -30,6 +33,16 @@ def __init__( token_url: str, uid_cache: UidCache, ) -> None: + """Initialise an OAuthClient. + + @param client_id: OAuth client ID + @param client_secret: OAuth client secret + @param debug: Enable debug output + @param redirect_uri: OAuth redirect URI + @param scopes: OAuth scopes + @param token_url: OAuth token URL + @param uid_cache: Cache for UIDs + """ # Set attributes self.bearer_token_: str | None = None self.client_secret = client_secret @@ -47,8 +60,10 @@ def __init__( log.msg("Initialising application credential client.") self.session_application = OAuth2Session( client=BackendApplicationClient( - client_id=client_id, scope=scopes, redirect_uri=redirect_uri - ) + client_id=client_id, + scope=scopes, + redirect_uri=redirect_uri, + ), ) except Exception as exc: msg = f"Failed to initialise application credential client.\n{exc!s}" @@ -60,18 +75,18 @@ def __init__( log.msg("Initialising delegated credential client.") self.session_interactive = OAuth2Session( client=LegacyApplicationClient( - client_id=client_id, scope=scopes, redirect_uri=redirect_uri - ) + client_id=client_id, + scope=scopes, + redirect_uri=redirect_uri, + ), ) except Exception as exc: msg = f"Failed to initialise delegated credential client.\n{exc!s}" raise RuntimeError(msg) from exc @property - def bearer_token(self) -> str: - """ - Return a bearer token, requesting a new one if necessary - """ + def bearer_token(self: Self) -> str: + """Return a bearer token, requesting a new one if necessary.""" try: if not self.bearer_token_: log.msg("Requesting a new authentication token from the OAuth backend.") @@ -81,38 +96,38 @@ def bearer_token(self) -> str: client_secret=self.client_secret, ) self.bearer_token_ = self.extract_token(json_response) - return self.bearer_token_ except Exception as exc: msg = f"Failed to fetch bearer token from OAuth endpoint.\n{exc!s}" raise RuntimeError(msg) from exc + else: + return self.bearer_token_ + @staticmethod @abstractmethod - def extract_token(self, json_response: JSONDict) -> str: - """ - Extract the bearer token from an OAuth2Session JSON response - """ - pass + def extract_token(json_response: JSONDict) -> str: + """Extract the bearer token from an OAuth2Session JSON response.""" @abstractmethod - def groups(self) -> list[JSONDict]: - """ - Return JSON data about groups from the OAuth backend. + def groups(self: Self) -> list[JSONDict]: + """Return JSON data about groups from the OAuth backend. + This should be a list of JSON dictionaries where 'None' is used to signify missing values. """ - pass @abstractmethod - def users(self) -> list[JSONDict]: - """ - Return JSON data about users from the OAuth backend. + def users(self: Self) -> list[JSONDict]: + """Return JSON data about users from the OAuth backend. + This should be a list of JSON dictionaries where 'None' is used to signify missing values. """ - pass - def query(self, url: str, *, use_client_secret: bool = True) -> dict[str, Any]: - """ - Make a query against the OAuth backend - """ + def query( + self: Self, + url: str, + *, + use_client_secret: bool = True, + ) -> dict[str, Any]: + """Make a query against the OAuth backend.""" kwargs = ( { "client_id": self.session_application._client.client_id, @@ -127,10 +142,13 @@ def query(self, url: str, *, use_client_secret: bool = True) -> dict[str, Any]: **kwargs, ) - def request(self, *args: Any, method: str = "GET", **kwargs: Any) -> dict[str, Any]: - """ - Make a request to the OAuth backend - """ + def request( + self: Self, + *args: Any, + method: str = "GET", + **kwargs: Any, + ) -> dict[str, Any]: + """Make a request to the OAuth backend.""" def query_(*args: Any, **kwargs: Any) -> requests.Response: return self.session_application.request( # type: ignore[no-any-return] @@ -149,12 +167,10 @@ def query_(*args: Any, **kwargs: Any) -> requests.Response: result = query_(*args, **kwargs) if result.status_code == HTTPStatus.NO_CONTENT: return {} - return result.json() # type: ignore + return result.json() # type: ignore[no-any-return] - def verify(self, username: str, password: str) -> bool: - """ - Verify username and password by attempting to authenticate against the OAuth backend. - """ + def verify(self: Self, username: str, password: str) -> bool: + """Verify username and password by attempting to authenticate against the OAuth backend.""" try: self.session_interactive.fetch_token( token_url=self.token_url, @@ -163,7 +179,8 @@ def verify(self, username: str, password: str) -> bool: client_id=self.session_interactive._client.client_id, client_secret=self.client_secret, ) - return True except InvalidGrantError as exc: log.msg(f"Authentication failed.\n{exc}") - return False + return False + else: + return True diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index 58aaf8d..31aecb5 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -1,4 +1,6 @@ -from collections.abc import Sequence +from __future__ import annotations + +from typing import TYPE_CHECKING, Self from pydantic import ValidationError from twisted.python import log @@ -7,25 +9,31 @@ LDAPAttributeAdaptor, LDAPGroupOfNames, LDAPInetOrgPerson, + LDAPObjectClass, LDAPPosixAccount, LDAPPosixGroup, - NamedLDAPClass, OverlayMemberOf, OverlayOAuthEntry, ) -from apricot.types import JSONDict -from .oauth_client import OAuthClient +if TYPE_CHECKING: + + from apricot.types import JSONDict + + from .oauth_client import OAuthClient class OAuthDataAdaptor: """Adaptor for converting raw user and group data into LDAP format.""" def __init__( - self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool - ): - """ - Initialise an OAuthDataAdaptor + self: Self, + domain: str, + oauth_client: OAuthClient, + *, + enable_mirrored_groups: bool, + ) -> None: + """Initialise an OAuthDataAdaptor. @param domain: The root domain of the LDAP tree @param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users @@ -42,57 +50,38 @@ def __init__( self.validated_users = self._validate_users(annotated_users) if self.debug: log.msg( - f"Validated {len(self.validated_groups)} groups and {len(self.validated_users)} users." + f"Validated {len(self.validated_groups)} groups and {len(self.validated_users)} users.", ) @property - def groups(self) -> list[LDAPAttributeAdaptor]: - """ - Return a list of LDAPAttributeAdaptors representing validated group data. - """ + def groups(self: Self) -> list[LDAPAttributeAdaptor]: + """Return a list of LDAPAttributeAdaptors representing validated group data.""" return self.validated_groups @property - def users(self) -> list[LDAPAttributeAdaptor]: - """ - Return a list of LDAPAttributeAdaptors representing validated user data. - """ + def users(self: Self) -> list[LDAPAttributeAdaptor]: + """Return a list of LDAPAttributeAdaptors representing validated user data.""" return self.validated_users - def _dn_from_group_cn(self, group_cn: str) -> str: + def _dn_from_group_cn(self: Self, group_cn: str) -> str: return f"CN={group_cn},OU=groups,{self.root_dn}" - def _dn_from_user_cn(self, user_cn: str) -> str: + def _dn_from_user_cn(self: Self, user_cn: str) -> str: return f"CN={user_cn},OU=users,{self.root_dn}" - def _extract_attributes( - self, - input_dict: JSONDict, - required_classes: Sequence[type[NamedLDAPClass]], - ) -> LDAPAttributeAdaptor: - """Add appropriate LDAP class attributes""" - attributes = {"objectclass": ["top"]} - for ldap_class in required_classes: - model = ldap_class(**input_dict) - attributes.update(model.model_dump()) - attributes["objectclass"] += model.names() - return LDAPAttributeAdaptor(attributes) - def _retrieve_entries( - self, + self: Self, ) -> tuple[ - list[tuple[JSONDict, list[type[NamedLDAPClass]]]], - list[tuple[JSONDict, list[type[NamedLDAPClass]]]], + list[tuple[JSONDict, list[type[LDAPObjectClass]]]], + list[tuple[JSONDict, list[type[LDAPObjectClass]]]], ]: - """ - Obtain lists of users and groups, and construct necessary meta-entries. - """ + """Obtain lists of users and groups, and construct necessary meta-entries.""" # Get the initial set of users and groups oauth_groups = self.oauth_client.groups() oauth_users = self.oauth_client.users() if self.debug: log.msg( - f"Loaded {len(oauth_groups)} groups and {len(oauth_users)} users from OAuth client." + f"Loaded {len(oauth_groups)} groups and {len(oauth_users)} users from OAuth client.", ) # Ensure member is set for groups @@ -142,7 +131,7 @@ def _retrieve_entries( if self.debug: for group_name in child_dict["memberOf"]: log.msg( - f"... user '{child_dict['cn']}' is a member of '{group_name}'" + f"... user '{child_dict['cn']}' is a member of '{group_name}'", ) # Ensure memberOf is set correctly for groups @@ -156,7 +145,7 @@ def _retrieve_entries( if self.debug: for group_name in child_dict["memberOf"]: log.msg( - f"... group '{child_dict['cn']}' is a member of '{group_name}'" + f"... group '{child_dict['cn']}' is a member of '{group_name}'", ) # Annotate group and user dicts with the appropriate LDAP classes @@ -189,52 +178,51 @@ def _retrieve_entries( return (annotated_groups, annotated_users) def _validate_groups( - self, annotated_groups: list[tuple[JSONDict, list[type[NamedLDAPClass]]]] + self: Self, + annotated_groups: list[tuple[JSONDict, list[type[LDAPObjectClass]]]], ) -> list[LDAPAttributeAdaptor]: - """ - Return a list of LDAPAttributeAdaptors representing validated group data. - """ + """Return a list of LDAPAttributeAdaptors representing validated group data.""" if self.debug: log.msg(f"Attempting to validate {len(annotated_groups)} groups.") output = [] for group_dict, required_classes in annotated_groups: try: output.append( - self._extract_attributes( + LDAPAttributeAdaptor.from_attributes( group_dict, required_classes=required_classes, - ) + ), ) except ValidationError as exc: - name = group_dict["cn"] if "cn" in group_dict else "unknown" + name = group_dict.get("cn", "unknown") log.msg(f"Validation failed for group '{name}'.") for error in exc.errors(): log.msg( - f"... '{error['loc'][0]}': {error['msg']} but '{error['input']}' was provided." + f"... '{error['loc'][0]}': {error['msg']} but '{error['input']}' was provided.", ) return output def _validate_users( - self, annotated_users: list[tuple[JSONDict, list[type[NamedLDAPClass]]]] + self: Self, + annotated_users: list[tuple[JSONDict, list[type[LDAPObjectClass]]]], ) -> list[LDAPAttributeAdaptor]: - """ - Return a list of LDAPAttributeAdaptors representing validated user data. - """ + """Return a list of LDAPAttributeAdaptors representing validated user data.""" if self.debug: log.msg(f"Attempting to validate {len(annotated_users)} users.") output = [] for user_dict, required_classes in annotated_users: try: output.append( - self._extract_attributes( - user_dict, required_classes=required_classes - ) + LDAPAttributeAdaptor.from_attributes( + user_dict, + required_classes=required_classes, + ), ) except ValidationError as exc: - name = user_dict["cn"] if "cn" in user_dict else "unknown" + name = user_dict.get("cn", "unknown") log.msg(f"Validation failed for user '{name}'.") for error in exc.errors(): log.msg( - f"... '{error['loc'][0]}': {error['msg']} but '{error['input']}' was provided." + f"... '{error['loc'][0]}': {error['msg']} but '{error['input']}' was provided.", ) return output diff --git a/apricot/patches/ldap_string.py b/apricot/patches/ldap_string.py index 41bfc45..8d0f204 100644 --- a/apricot/patches/ldap_string.py +++ b/apricot/patches/ldap_string.py @@ -1,14 +1,14 @@ -"""Patch LDAPString to avoid TypeError when parsing LDAP filter strings""" +"""Patch LDAPString to avoid TypeError when parsing LDAP filter strings.""" -from typing import Any +from typing import Any, Self from ldaptor.protocols.pureldap import LDAPString old_init = LDAPString.__init__ -def patched_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore[no-untyped-def] - """Patch LDAPString init to store its value as 'str' not 'bytes'""" +def patched_init(self: Self, *args: Any, **kwargs: Any) -> None: # type: ignore[misc] + """Patch LDAPString init to store its value as 'str' not 'bytes'.""" old_init(self, *args, **kwargs) if isinstance(self.value, bytes): self.value = self.value.decode() diff --git a/pyproject.toml b/pyproject.toml index 0ff02da..39e69ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,15 +58,15 @@ dependencies = [ ] [tool.hatch.envs.lint.scripts] -typing = "mypy {args:apricot}" +typing = "mypy {args:apricot} run.py" style = [ - "ruff check {args:apricot}", - "black --check --diff {args:apricot}", + "black --check --diff {args:apricot} run.py", + "ruff check --preview {args:apricot} run.py", ] fmt = [ - "black {args:apricot}", - "ruff check --fix {args:apricot}", + "black {args:apricot} run.py", + "ruff check --preview --fix {args:apricot} run.py", "style", ] all = [ @@ -80,43 +80,80 @@ target-version = ["py310", "py311"] [tool.ruff.lint] select = [ # See https://beta.ruff.rs/docs/rules/ - "A", # flake8-builtins - "ARG", # flake8-unused-arguments - "B", # flake8-bugbear - "C", # complexity, mcabe and flake8-comprehensions - "DTZ", # flake8-datetimez - "E", # pycodestyle errors - "EM", # flake8-errmsg - "F", # pyflakes - "FBT", # flake8-boolean-trap - "I", # isort - "ICN", # flake8-import-conventions - "ISC", # flake8-implicit-str-concat - "N", # pep8-naming - "PLC", # pylint convention - "PLE", # pylint error - "PLR", # pylint refactor - "PLW", # pylint warning - "Q", # flake8-quotes - "RUF", # ruff rules - "S", # flake8-bandit - "T", # flake8-debugger and flake8-print - "TID", # flake8-tidy-imports - "UP", # pyupgrade - "W", # pycodestyle warnings - "YTT", # flake8-2020 + "A", # flake8-builtins + "AIR", # Airflow + "ANN", # flake8-annotations + "ARG", # flake8-unused-arguments + "ASYNC", # flake8-async + "B", # flake8-bugbear + "BLE", # flake8-blind-except + "C", # complexity, mcabe and flake8-comprehensions + "COM", # flake8-commas + "D", # pydocstyle + "DTZ", # flake8-datetimez + "E", # pycodestyle errors + "EM", # flake8-errmsg + "ERA", # eradicate + "EXE", # flake8-executable + "F", # pyflakes + "FA", # flake8-future-annotations + "FBT", # flake8-boolean-trap + "FIX", # flake8-fixme + "FLY", # flynt + "FURB", # refurb + "G", # flake8-logging-format + "I", # isort + "ICN", # flake8-import-conventions + "INP", # flake8-no-pep420 + "INT", # flake8-gettext + "ISC", # flake8-implicit-str-concat + "LOG", # flake8-logging + "N", # pep8-naming + "NPY", # numpy-specific-rules + "PD", # pandas-vet + "PGH", # pygrep-hooks + "PIE", # flake8-pie + "PLC", # pylint convention + "PLE", # pylint error + "PLR", # pylint refactor + "PLW", # pylint warning + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "PYI", # flake8-pyi + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-raise + "RUF", # ruff rules + "S", # flake8-bandit + "SIM", # flake8-simplify + "SLOT", # flake8-slot + "T", # flake8-debugger and flake8-print + "TCH", # flake8-type-checking + "TD", # flake8-todos + "TID", # flake8-tidy-imports + "TRIO", # flake8-trio + "TRY", # tryceratops + "UP", # pyupgrade + "W", # pycodestyle warnings + "YTT", # flake8-2020 ] ignore = [ - "E501", # ignore line length - "S106", # ignore check for possible passwords - "S603", # allow subprocess without shell=True - "S607", # allow subprocess without absolute path - "C901", # ignore complex-structure - "PLR0912", # ignore too-many-branches - "PLR0913", # ignore too-many-arguments - "PLR0915", # ignore too-many-statements + "D100", # missing-docstring-in-module + "D102", # missing-docstring-in-public-method + "D104", # missing-docstring-in-package + "D105", # missing-docstring-in-magic-method + "D203", # one-blank-line-before-class due to conflict with D211 + "D213", # multi-line-summary-second-line due to conflict with D212 + "E501", # line length + "C901", # complex-structure + "PLR0912", # too-many-branches + "PLR0913", # too-many-arguments + "PLR0917", # too-many-positional-arguments ] +[tool.ruff.lint.flake8-annotations] +allow-star-arg-any = true + [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" diff --git a/run.py b/run.py index 98f4831..cd1b1a5 100644 --- a/run.py +++ b/run.py @@ -12,17 +12,30 @@ ) # Common options needed for all backends parser.add_argument( - "-b", "--backend", type=OAuthBackend, help="Which OAuth backend to use." + "-b", + "--backend", + type=OAuthBackend, + help="Which OAuth backend to use.", ) parser.add_argument( - "-d", "--domain", type=str, help="Which domain users belong to." + "-d", + "--domain", + type=str, + help="Which domain users belong to.", ) parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.") parser.add_argument( - "-p", "--port", type=int, default=1389, help="Port to run on." + "-p", + "--port", + type=int, + default=1389, + help="Port to run on.", ) parser.add_argument( - "-s", "--client-secret", type=str, help="OAuth client secret." + "-s", + "--client-secret", + type=str, + help="OAuth client secret.", ) parser.add_argument( "--background-refresh", @@ -31,7 +44,9 @@ help="Refresh in the background instead of as needed per request", ) parser.add_argument( - "--debug", action="store_true", help="Enable debug logging." + "--debug", + action="store_true", + help="Enable debug logging.", ) parser.add_argument( "--disable-mirrored-groups", @@ -50,50 +65,67 @@ # Options for Microsoft Entra backend entra_group = parser.add_argument_group("Microsoft Entra") entra_group.add_argument( - "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID." + "--entra-tenant-id", + type=str, + help="Microsoft Entra tenant ID.", ) # Options for Keycloak backend keycloak_group = parser.add_argument_group("Keycloak") keycloak_group.add_argument( - "--keycloak-base-url", type=str, help="Keycloak base URL." + "--keycloak-base-url", + type=str, + help="Keycloak base URL.", ) keycloak_group.add_argument( - "--keycloak-realm", type=str, help="Keycloak Realm." + "--keycloak-realm", + type=str, + help="Keycloak Realm.", ) # Options for Redis cache redis_group = parser.add_argument_group("Redis") redis_group.add_argument( - "--redis-host", type=str, help="Host for Redis server." + "--redis-host", + type=str, + help="Host for Redis server.", ) redis_group.add_argument( - "--redis-port", type=int, help="Port for Redis server." + "--redis-port", + type=int, + help="Port for Redis server.", ) # Options for TLS tls_group = parser.add_argument_group("TLS") tls_group.add_argument( - "--tls-certificate", type=str, help="Location of TLS certificate (pem)." + "--tls-certificate", + type=str, + help="Location of TLS certificate (pem).", ) tls_group.add_argument( - "--tls-port", type=int, default=1636, help="Port to run on with encryption." + "--tls-port", + type=int, + default=1636, + help="Port to run on with encryption.", ) tls_group.add_argument( - "--tls-private-key", type=str, help="Location of TLS private key (pem)." + "--tls-private-key", + type=str, + help="Location of TLS private key (pem).", ) # Parse arguments args = parser.parse_args() # Create the Apricot server reactor = ApricotServer(**vars(args)) - except Exception as exc: + except Exception as exc: # noqa: BLE001 msg = f"Unable to initialise Apricot server.\n{exc}" - print(msg) + print(msg) # noqa: T201 sys.exit(1) # Run the Apricot server try: reactor.run() - except Exception as exc: + except Exception as exc: # noqa: BLE001 msg = f"Apricot server encountered a runtime problem.\n{exc}" - print(msg) + print(msg) # noqa: T201 sys.exit(1)