From a5a95056e69c88065ea279acacb43e6bfe0e6047 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 30 May 2024 14:13:49 +0100 Subject: [PATCH 01/15] :wrench: Drop unused ruff ignores --- pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0ff02da..2708039 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,13 +108,9 @@ select = [ ] ignore = [ "E501", # ignore line length - "S106", # ignore check for possible passwords - "S603", # allow subprocess without shell=True - "S607", # allow subprocess without absolute path "C901", # ignore complex-structure "PLR0912", # ignore too-many-branches "PLR0913", # ignore too-many-arguments - "PLR0915", # ignore too-many-statements ] [tool.ruff.lint.flake8-tidy-imports] From 7171621088629cf714beff5ec7ffce404854d51f Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 30 May 2024 14:30:44 +0100 Subject: [PATCH 02/15] :rotating_light: Add new Ruff rules and lint accordingly --- apricot/apricot_server.py | 6 +- apricot/cache/uid_cache.py | 9 ++- apricot/ldap/oauth_ldap_entry.py | 4 +- apricot/ldap/oauth_ldap_tree.py | 10 ++-- apricot/models/ldap_posix_account.py | 3 +- apricot/oauth/keycloak_client.py | 17 +++--- apricot/oauth/microsoft_entra_client.py | 11 ++-- apricot/oauth/oauth_client.py | 15 ++--- apricot/oauth/oauth_data_adaptor.py | 31 ++++++---- pyproject.toml | 76 +++++++++++++++++-------- 10 files changed, 113 insertions(+), 69 deletions(-) diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 8776335..152009a 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -41,7 +41,7 @@ def __init__( uid_cache: UidCache if redis_host and redis_port: log.msg( - f"Using a Redis user-id cache at host '{redis_host}' on port '{redis_port}'." + f"Using a Redis user-id cache at host '{redis_host}' on port '{redis_port}'.", ) uid_cache = RedisCache(redis_host=redis_host, redis_port=redis_port) else: @@ -54,7 +54,7 @@ def __init__( log.msg(f"Creating an OAuthClient for {backend}.") oauth_backend = OAuthClientMap[backend] oauth_backend_args = inspect.getfullargspec( - oauth_backend.__init__ # type: ignore + oauth_backend.__init__, # type: ignore ).args oauth_client = oauth_backend( client_id=client_id, @@ -81,7 +81,7 @@ def __init__( if background_refresh: if self.debug: log.msg( - f"Starting background refresh (interval={factory.adaptor.refresh_interval})" + f"Starting background refresh (interval={factory.adaptor.refresh_interval})", ) loop = task.LoopingCall(factory.adaptor.refresh) loop.start(factory.adaptor.refresh_interval) diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py index eb9c729..355465f 100644 --- a/apricot/cache/uid_cache.py +++ b/apricot/cache/uid_cache.py @@ -8,28 +8,24 @@ def get(self, identifier: str) -> int | None: """ Get the UID for a given identifier, returning None if it does not exist """ - pass @abstractmethod def keys(self) -> list[str]: """ Get list of cached keys """ - pass @abstractmethod def set(self, identifier: str, uid_value: int) -> None: """ Set the UID for a given identifier """ - pass @abstractmethod def values(self, keys: list[str]) -> list[int]: """ Get list of cached values corresponding to requested keys """ - pass def get_group_uid(self, identifier: str) -> int: """ @@ -48,7 +44,10 @@ def get_user_uid(self, identifier: str) -> int: return self.get_uid(identifier, category="user", min_value=2000) def get_uid( - self, identifier: str, category: str, min_value: int | None = None + self, + identifier: str, + category: str, + min_value: int | None = None, ) -> int: """ Get UID, constructing one if necessary. diff --git a/apricot/ldap/oauth_ldap_entry.py b/apricot/ldap/oauth_ldap_entry.py index 6845a33..a4b07a1 100644 --- a/apricot/ldap/oauth_ldap_entry.py +++ b/apricot/ldap/oauth_ldap_entry.py @@ -62,7 +62,9 @@ def oauth_client(self) -> OAuthClient: return self.oauth_client_ def add_child( - self, rdn: RelativeDistinguishedName | str, attributes: LDAPAttributeDict + self, + rdn: RelativeDistinguishedName | str, + attributes: LDAPAttributeDict, ) -> "OAuthLDAPEntry": if isinstance(rdn, str): rdn = RelativeDistinguishedName(stringValue=rdn) diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 66e649f..24b02ac 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -83,16 +83,18 @@ def refresh(self) -> None: # Add OUs for users and groups groups_ou = self.root_.add_child( - "OU=groups", {"ou": ["groups"], "objectClass": ["organizationalUnit"]} + "OU=groups", + {"ou": ["groups"], "objectClass": ["organizationalUnit"]}, ) users_ou = self.root_.add_child( - "OU=users", {"ou": ["users"], "objectClass": ["organizationalUnit"]} + "OU=users", + {"ou": ["users"], "objectClass": ["organizationalUnit"]}, ) # Add groups to the groups OU if self.debug: log.msg( - f"Attempting to add {len(oauth_adaptor.groups)} groups to the LDAP tree." + f"Attempting to add {len(oauth_adaptor.groups)} groups to the LDAP tree.", ) for group_attrs in oauth_adaptor.groups: groups_ou.add_child(f"CN={group_attrs.cn}", group_attrs.to_dict()) @@ -105,7 +107,7 @@ def refresh(self) -> None: # Add users to the users OU if self.debug: log.msg( - f"Attempting to add {len(oauth_adaptor.users)} users to the LDAP tree." + f"Attempting to add {len(oauth_adaptor.users)} users to the LDAP tree.", ) for user_attrs in oauth_adaptor.users: users_ou.add_child(f"CN={user_attrs.cn}", user_attrs.to_dict()) diff --git a/apricot/models/ldap_posix_account.py b/apricot/models/ldap_posix_account.py index 5bdd738..84344c4 100644 --- a/apricot/models/ldap_posix_account.py +++ b/apricot/models/ldap_posix_account.py @@ -22,7 +22,8 @@ class LDAPPosixAccount(NamedLDAPClass): cn: str gidNumber: int # noqa: N815 homeDirectory: Annotated[ # noqa: N815 - str, StringConstraints(strip_whitespace=True, to_lower=True) + str, + StringConstraints(strip_whitespace=True, to_lower=True), ] uid: str uidNumber: int # noqa: N815 diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index 5b584c7..9e96750 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -54,17 +54,18 @@ def groups(self) -> list[JSONDict]: # This ensures that any groups without a `gid` attribute will receive a # UID that does not overlap with existing groups if (group_gid := group_dict["attributes"]["gid"]) and len( - group_dict["attributes"]["gid"] + group_dict["attributes"]["gid"], ) == 1: self.uid_cache.overwrite_group_uid( - group_dict["id"], int(group_gid[0], 10) + group_dict["id"], + int(group_gid[0], 10), ) # Read group attributes for group_dict in group_data: if not group_dict["attributes"]["gid"]: group_dict["attributes"]["gid"] = [ - str(self.uid_cache.get_group_uid(group_dict["id"])) + str(self.uid_cache.get_group_uid(group_dict["id"])), ] self.request( f"{self.base_url}/admin/realms/{self.realm}/groups/{group_dict['id']}", @@ -110,19 +111,21 @@ def users(self) -> list[JSONDict]: # This ensures that any groups without a `gid` attribute will receive a # UID that does not overlap with existing groups if (user_uid := user_dict["attributes"]["uid"]) and len( - user_dict["attributes"]["uid"] + user_dict["attributes"]["uid"], ) == 1: self.uid_cache.overwrite_user_uid( - user_dict["id"], int(user_uid[0], 10) + user_dict["id"], + int(user_uid[0], 10), ) # Read user attributes for user_dict in sorted( - user_data, key=lambda user: user["createdTimestamp"] + user_data, + key=lambda user: user["createdTimestamp"], ): if not user_dict["attributes"]["uid"]: user_dict["attributes"]["uid"] = [ - str(self.uid_cache.get_user_uid(user_dict["id"])) + str(self.uid_cache.get_user_uid(user_dict["id"])), ] self.request( f"{self.base_url}/admin/realms/{self.realm}/users/{user_dict['id']}", diff --git a/apricot/oauth/microsoft_entra_client.py b/apricot/oauth/microsoft_entra_client.py index eecfa41..847925e 100644 --- a/apricot/oauth/microsoft_entra_client.py +++ b/apricot/oauth/microsoft_entra_client.py @@ -22,7 +22,10 @@ def __init__( ) self.tenant_id = entra_tenant_id super().__init__( - redirect_uri=redirect_uri, scopes=scopes, token_url=token_url, **kwargs + redirect_uri=redirect_uri, + scopes=scopes, + token_url=token_url, + **kwargs, ) def extract_token(self, json_response: JSONDict) -> str: @@ -36,7 +39,7 @@ def groups(self) -> list[JSONDict]: "id", ] group_data = self.query( - f"https://graph.microsoft.com/v1.0/groups?$select={','.join(queries)}" + f"https://graph.microsoft.com/v1.0/groups?$select={','.join(queries)}", ) for group_dict in cast( list[JSONDict], @@ -51,7 +54,7 @@ def groups(self) -> list[JSONDict]: attributes["oauth_id"] = group_dict.get("id", None) # Add membership attributes members = self.query( - f"https://graph.microsoft.com/v1.0/groups/{group_dict['id']}/members" + f"https://graph.microsoft.com/v1.0/groups/{group_dict['id']}/members", ) attributes["memberUid"] = [ str(user["userPrincipalName"]).split("@")[0] @@ -78,7 +81,7 @@ def users(self) -> list[JSONDict]: "userPrincipalName", ] user_data = self.query( - f"https://graph.microsoft.com/v1.0/users?$select={','.join(queries)}" + f"https://graph.microsoft.com/v1.0/users?$select={','.join(queries)}", ) for user_dict in cast( list[JSONDict], diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index b47f98c..9a712c6 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -47,8 +47,10 @@ def __init__( log.msg("Initialising application credential client.") self.session_application = OAuth2Session( client=BackendApplicationClient( - client_id=client_id, scope=scopes, redirect_uri=redirect_uri - ) + client_id=client_id, + scope=scopes, + redirect_uri=redirect_uri, + ), ) except Exception as exc: msg = f"Failed to initialise application credential client.\n{exc!s}" @@ -60,8 +62,10 @@ def __init__( log.msg("Initialising delegated credential client.") self.session_interactive = OAuth2Session( client=LegacyApplicationClient( - client_id=client_id, scope=scopes, redirect_uri=redirect_uri - ) + client_id=client_id, + scope=scopes, + redirect_uri=redirect_uri, + ), ) except Exception as exc: msg = f"Failed to initialise delegated credential client.\n{exc!s}" @@ -91,7 +95,6 @@ def extract_token(self, json_response: JSONDict) -> str: """ Extract the bearer token from an OAuth2Session JSON response """ - pass @abstractmethod def groups(self) -> list[JSONDict]: @@ -99,7 +102,6 @@ def groups(self) -> list[JSONDict]: Return JSON data about groups from the OAuth backend. This should be a list of JSON dictionaries where 'None' is used to signify missing values. """ - pass @abstractmethod def users(self) -> list[JSONDict]: @@ -107,7 +109,6 @@ def users(self) -> list[JSONDict]: Return JSON data about users from the OAuth backend. This should be a list of JSON dictionaries where 'None' is used to signify missing values. """ - pass def query(self, url: str, *, use_client_secret: bool = True) -> dict[str, Any]: """ diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index 58aaf8d..f0dec89 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -22,7 +22,11 @@ class OAuthDataAdaptor: """Adaptor for converting raw user and group data into LDAP format.""" def __init__( - self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool + self, + domain: str, + oauth_client: OAuthClient, + *, + enable_mirrored_groups: bool, ): """ Initialise an OAuthDataAdaptor @@ -42,7 +46,7 @@ def __init__( self.validated_users = self._validate_users(annotated_users) if self.debug: log.msg( - f"Validated {len(self.validated_groups)} groups and {len(self.validated_users)} users." + f"Validated {len(self.validated_groups)} groups and {len(self.validated_users)} users.", ) @property @@ -92,7 +96,7 @@ def _retrieve_entries( oauth_users = self.oauth_client.users() if self.debug: log.msg( - f"Loaded {len(oauth_groups)} groups and {len(oauth_users)} users from OAuth client." + f"Loaded {len(oauth_groups)} groups and {len(oauth_users)} users from OAuth client.", ) # Ensure member is set for groups @@ -142,7 +146,7 @@ def _retrieve_entries( if self.debug: for group_name in child_dict["memberOf"]: log.msg( - f"... user '{child_dict['cn']}' is a member of '{group_name}'" + f"... user '{child_dict['cn']}' is a member of '{group_name}'", ) # Ensure memberOf is set correctly for groups @@ -156,7 +160,7 @@ def _retrieve_entries( if self.debug: for group_name in child_dict["memberOf"]: log.msg( - f"... group '{child_dict['cn']}' is a member of '{group_name}'" + f"... group '{child_dict['cn']}' is a member of '{group_name}'", ) # Annotate group and user dicts with the appropriate LDAP classes @@ -189,7 +193,8 @@ def _retrieve_entries( return (annotated_groups, annotated_users) def _validate_groups( - self, annotated_groups: list[tuple[JSONDict, list[type[NamedLDAPClass]]]] + self, + annotated_groups: list[tuple[JSONDict, list[type[NamedLDAPClass]]]], ) -> list[LDAPAttributeAdaptor]: """ Return a list of LDAPAttributeAdaptors representing validated group data. @@ -203,19 +208,20 @@ def _validate_groups( self._extract_attributes( group_dict, required_classes=required_classes, - ) + ), ) except ValidationError as exc: name = group_dict["cn"] if "cn" in group_dict else "unknown" log.msg(f"Validation failed for group '{name}'.") for error in exc.errors(): log.msg( - f"... '{error['loc'][0]}': {error['msg']} but '{error['input']}' was provided." + f"... '{error['loc'][0]}': {error['msg']} but '{error['input']}' was provided.", ) return output def _validate_users( - self, annotated_users: list[tuple[JSONDict, list[type[NamedLDAPClass]]]] + self, + annotated_users: list[tuple[JSONDict, list[type[NamedLDAPClass]]]], ) -> list[LDAPAttributeAdaptor]: """ Return a list of LDAPAttributeAdaptors representing validated user data. @@ -227,14 +233,15 @@ def _validate_users( try: output.append( self._extract_attributes( - user_dict, required_classes=required_classes - ) + user_dict, + required_classes=required_classes, + ), ) except ValidationError as exc: name = user_dict["cn"] if "cn" in user_dict else "unknown" log.msg(f"Validation failed for user '{name}'.") for error in exc.errors(): log.msg( - f"... '{error['loc'][0]}': {error['msg']} but '{error['input']}' was provided." + f"... '{error['loc'][0]}': {error['msg']} but '{error['input']}' was provided.", ) return output diff --git a/pyproject.toml b/pyproject.toml index 2708039..cd03575 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,31 +80,57 @@ target-version = ["py310", "py311"] [tool.ruff.lint] select = [ # See https://beta.ruff.rs/docs/rules/ - "A", # flake8-builtins - "ARG", # flake8-unused-arguments - "B", # flake8-bugbear - "C", # complexity, mcabe and flake8-comprehensions - "DTZ", # flake8-datetimez - "E", # pycodestyle errors - "EM", # flake8-errmsg - "F", # pyflakes - "FBT", # flake8-boolean-trap - "I", # isort - "ICN", # flake8-import-conventions - "ISC", # flake8-implicit-str-concat - "N", # pep8-naming - "PLC", # pylint convention - "PLE", # pylint error - "PLR", # pylint refactor - "PLW", # pylint warning - "Q", # flake8-quotes - "RUF", # ruff rules - "S", # flake8-bandit - "T", # flake8-debugger and flake8-print - "TID", # flake8-tidy-imports - "UP", # pyupgrade - "W", # pycodestyle warnings - "YTT", # flake8-2020 + "A", # flake8-builtins + "AIR", # Airflow + "ARG", # flake8-unused-arguments + "ASYNC", # flake8-async + "B", # flake8-bugbear + "BLE", # flake8-blind-except + "C", # complexity, mcabe and flake8-comprehensions + "COM", # flake8-commas + "CPY", # flake8-copyright + "DTZ", # flake8-datetimez + "E", # pycodestyle errors + "EM", # flake8-errmsg + "ERA", # eradicate + "EXE", # flake8-executable + "F", # pyflakes + "FBT", # flake8-boolean-trap + "FIX", # flake8-fixme + "FLY", # flynt + "FURB", # refurb + "G", # flake8-logging-format + "I", # isort + "ICN", # flake8-import-conventions + "INP", # flake8-no-pep420 + "INT", # flake8-gettext + "ISC", # flake8-implicit-str-concat + "LOG", # flake8-logging + "N", # pep8-naming + "NPY", # numpy-specific-rules + "PD", # pandas-vet + "PIE", # flake8-pie + "PLC", # pylint convention + "PLE", # pylint error + "PLR", # pylint refactor + "PLW", # pylint warning + "PT", # flake8-pytest-style + "PTH", # flake8-use-pathlib + "PYI", # flake8-pyi + "Q", # flake8-quotes + "RET", # flake8-return + "RSE", # flake8-raise + "RUF", # ruff rules + "S", # flake8-bandit + "SLOT", # flake8-slot + "T", # flake8-debugger and flake8-print + "TCH", # flake8-type-checking + "TD", # flake8-todos + "TID", # flake8-tidy-imports + "TRIO", # flake8-trio + "UP", # pyupgrade + "W", # pycodestyle warnings + "YTT", # flake8-2020 ] ignore = [ "E501", # ignore line length From 8a92003a41c1395c21f14c14587510c39b7567c0 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 30 May 2024 15:44:18 +0100 Subject: [PATCH 03/15] :rotating_light: Add flake8-annotations --- apricot/apricot_server.py | 6 ++--- apricot/cache/local_cache.py | 12 +++++---- apricot/cache/redis_cache.py | 20 +++++++------- apricot/cache/uid_cache.py | 24 ++++++++--------- apricot/ldap/oauth_ldap_entry.py | 14 +++++----- apricot/ldap/oauth_ldap_server_factory.py | 10 ++++--- apricot/ldap/oauth_ldap_tree.py | 13 ++++----- apricot/ldap/read_only_ldap_server.py | 24 ++++++++--------- apricot/models/ldap_attribute_adaptor.py | 8 +++--- apricot/models/ldap_group_of_names.py | 4 ++- apricot/models/ldap_inetorgperson.py | 4 ++- apricot/models/ldap_organizational_person.py | 4 ++- apricot/models/ldap_person.py | 4 ++- apricot/models/ldap_posix_account.py | 9 ++++--- apricot/models/ldap_posix_group.py | 6 +++-- apricot/models/named_ldap_class.py | 4 ++- apricot/oauth/keycloak_client.py | 12 ++++----- apricot/oauth/microsoft_entra_client.py | 12 ++++----- apricot/oauth/oauth_client.py | 28 +++++++++++++------- apricot/oauth/oauth_data_adaptor.py | 21 ++++++++------- apricot/patches/ldap_string.py | 4 +-- pyproject.toml | 4 +++ 22 files changed, 141 insertions(+), 106 deletions(-) diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 152009a..248fab7 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -1,6 +1,6 @@ import inspect import sys -from typing import Any, cast +from typing import Any, Self, cast from twisted.internet import reactor, task from twisted.internet.endpoints import quoteStringArgument, serverFromString @@ -14,7 +14,7 @@ class ApricotServer: def __init__( - self, + self: Self, backend: OAuthBackend, client_id: str, client_secret: str, @@ -111,7 +111,7 @@ def __init__( # Load the Twisted reactor self.reactor = cast(IReactorCore, reactor) - def run(self) -> None: + def run(self: Self) -> None: """Start the Twisted reactor""" if self.debug: log.msg("Starting the Twisted reactor.") diff --git a/apricot/cache/local_cache.py b/apricot/cache/local_cache.py index 958217b..0bc05e7 100644 --- a/apricot/cache/local_cache.py +++ b/apricot/cache/local_cache.py @@ -1,18 +1,20 @@ +from typing import Self + from .uid_cache import UidCache class LocalCache(UidCache): - def __init__(self) -> None: + def __init__(self: Self) -> None: self.cache: dict[str, int] = {} - def get(self, identifier: str) -> int | None: + def get(self: Self, identifier: str) -> int | None: return self.cache.get(identifier, None) - def keys(self) -> list[str]: + def keys(self: Self) -> list[str]: return [str(k) for k in self.cache.keys()] - def set(self, identifier: str, uid_value: int) -> None: + def set(self: Self, identifier: str, uid_value: int) -> None: self.cache[identifier] = uid_value - def values(self, keys: list[str]) -> list[int]: + def values(self: Self, keys: list[str]) -> list[int]: return [v for k, v in self.cache.items() if k in keys] diff --git a/apricot/cache/redis_cache.py b/apricot/cache/redis_cache.py index 24ac506..6bf78c5 100644 --- a/apricot/cache/redis_cache.py +++ b/apricot/cache/redis_cache.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Self, cast import redis @@ -6,31 +6,33 @@ class RedisCache(UidCache): - def __init__(self, redis_host: str, redis_port: int) -> None: + def __init__(self: Self, redis_host: str, redis_port: int) -> None: self.redis_host = redis_host self.redis_port = redis_port - self.cache_: "redis.Redis[str]" | None = None # noqa: UP037 + self.cache_: "redis.Redis[str]" | None = None @property - def cache(self) -> "redis.Redis[str]": + def cache(self: Self) -> "redis.Redis[str]": """ Lazy-load the cache on request """ if not self.cache_: self.cache_ = redis.Redis( - host=self.redis_host, port=self.redis_port, decode_responses=True + host=self.redis_host, + port=self.redis_port, + decode_responses=True, ) return self.cache_ - def get(self, identifier: str) -> int | None: + def get(self: Self, identifier: str) -> int | None: value = self.cache.get(identifier) return None if value is None else int(value) - def keys(self) -> list[str]: + def keys(self: Self) -> list[str]: return [str(k) for k in self.cache.keys()] - def set(self, identifier: str, uid_value: int) -> None: + def set(self: Self, identifier: str, uid_value: int) -> None: self.cache.set(identifier, uid_value) - def values(self, keys: list[str]) -> list[int]: + def values(self: Self, keys: list[str]) -> list[int]: return [int(cast(str, v)) for v in self.cache.mget(keys)] diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py index 355465f..52abd54 100644 --- a/apricot/cache/uid_cache.py +++ b/apricot/cache/uid_cache.py @@ -1,33 +1,33 @@ from abc import ABC, abstractmethod -from typing import cast +from typing import Self, cast class UidCache(ABC): @abstractmethod - def get(self, identifier: str) -> int | None: + def get(self: Self, identifier: str) -> int | None: """ Get the UID for a given identifier, returning None if it does not exist """ @abstractmethod - def keys(self) -> list[str]: + def keys(self: Self) -> list[str]: """ Get list of cached keys """ @abstractmethod - def set(self, identifier: str, uid_value: int) -> None: + def set(self: Self, identifier: str, uid_value: int) -> None: """ Set the UID for a given identifier """ @abstractmethod - def values(self, keys: list[str]) -> list[int]: + def values(self: Self, keys: list[str]) -> list[int]: """ Get list of cached values corresponding to requested keys """ - def get_group_uid(self, identifier: str) -> int: + def get_group_uid(self: Self, identifier: str) -> int: """ Get UID for a group, constructing one if necessary @@ -35,7 +35,7 @@ def get_group_uid(self, identifier: str) -> int: """ return self.get_uid(identifier, category="group", min_value=3000) - def get_user_uid(self, identifier: str) -> int: + def get_user_uid(self: Self, identifier: str) -> int: """ Get UID for a user, constructing one if necessary @@ -44,7 +44,7 @@ def get_user_uid(self, identifier: str) -> int: return self.get_uid(identifier, category="user", min_value=2000) def get_uid( - self, + self: Self, identifier: str, category: str, min_value: int | None = None, @@ -64,7 +64,7 @@ def get_uid( self.set(identifier_, next_uid) return cast(int, self.get(identifier_)) - def _get_max_uid(self, category: str | None) -> int: + def _get_max_uid(self: Self, category: str | None) -> int: """ Get maximum UID for a given category @@ -77,7 +77,7 @@ def _get_max_uid(self, category: str | None) -> int: values = [*self.values(keys), -999] return max(values) - def overwrite_group_uid(self, identifier: str, uid: int) -> None: + def overwrite_group_uid(self: Self, identifier: str, uid: int) -> None: """ Set UID for a group, overwriting the existing value if there is one @@ -86,7 +86,7 @@ def overwrite_group_uid(self, identifier: str, uid: int) -> None: """ return self.overwrite_uid(identifier, category="group", uid=uid) - def overwrite_user_uid(self, identifier: str, uid: int) -> None: + def overwrite_user_uid(self: Self, identifier: str, uid: int) -> None: """ Get UID for a user, constructing one if necessary @@ -95,7 +95,7 @@ def overwrite_user_uid(self, identifier: str, uid: int) -> None: """ return self.overwrite_uid(identifier, category="user", uid=uid) - def overwrite_uid(self, identifier: str, category: str, uid: int) -> None: + def overwrite_uid(self: Self, identifier: str, category: str, uid: int) -> None: """ Set UID, overwriting the existing one if necessary. diff --git a/apricot/ldap/oauth_ldap_entry.py b/apricot/ldap/oauth_ldap_entry.py index a4b07a1..1dd4554 100644 --- a/apricot/ldap/oauth_ldap_entry.py +++ b/apricot/ldap/oauth_ldap_entry.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import Self, cast from ldaptor.inmemory import ReadOnlyInMemoryLDAPEntry from ldaptor.protocols.ldap.distinguishedname import ( @@ -20,7 +20,7 @@ class OAuthLDAPEntry(ReadOnlyInMemoryLDAPEntry): attributes: LDAPAttributeDict def __init__( - self, + self: Self, dn: DistinguishedName | str, attributes: LDAPAttributeDict, oauth_client: OAuthClient | None = None, @@ -37,7 +37,7 @@ def __init__( dn = DistinguishedName(stringValue=dn) super().__init__(dn, attributes) - def __str__(self) -> str: + def __str__(self: Self) -> str: output = bytes(self.toWire()).decode("utf-8") for child in self._children.values(): try: @@ -52,7 +52,7 @@ def __str__(self) -> str: return output @property - def oauth_client(self) -> OAuthClient: + def oauth_client(self: Self) -> OAuthClient: if not self.oauth_client_: if hasattr(self._parent, "oauth_client"): self.oauth_client_ = self._parent.oauth_client @@ -62,7 +62,7 @@ def oauth_client(self) -> OAuthClient: return self.oauth_client_ def add_child( - self, + self: Self, rdn: RelativeDistinguishedName | str, attributes: LDAPAttributeDict, ) -> "OAuthLDAPEntry": @@ -75,7 +75,7 @@ def add_child( output = self._children[rdn.getText()] return cast(OAuthLDAPEntry, output) - def bind(self, password: bytes) -> defer.Deferred["OAuthLDAPEntry"]: + def bind(self: Self, password: bytes) -> defer.Deferred["OAuthLDAPEntry"]: def _bind(password: bytes) -> "OAuthLDAPEntry": oauth_username = next(iter(self.get("oauth_username", "unknown"))) s_password = password.decode("utf-8") @@ -86,5 +86,5 @@ def _bind(password: bytes) -> "OAuthLDAPEntry": return defer.maybeDeferred(_bind, password) - def list_children(self) -> "list[OAuthLDAPEntry]": + def list_children(self: Self) -> "list[OAuthLDAPEntry]": return [cast(OAuthLDAPEntry, entry) for entry in self._children.values()] diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index 303d9e4..f86f129 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -1,3 +1,5 @@ +from typing import Self + from twisted.internet.interfaces import IAddress from twisted.internet.protocol import Protocol, ServerFactory @@ -9,14 +11,14 @@ class OAuthLDAPServerFactory(ServerFactory): def __init__( - self, + self: Self, domain: str, oauth_client: OAuthClient, *, background_refresh: bool, enable_mirrored_groups: bool, refresh_interval: int, - ): + ) -> None: """ Initialise an OAuthLDAPServerFactory @@ -35,10 +37,10 @@ def __init__( refresh_interval=refresh_interval, ) - def __repr__(self) -> str: + def __repr__(self: Self) -> str: return f"{self.__class__.__name__} using adaptor {self.adaptor}" - def buildProtocol(self, addr: IAddress) -> Protocol: # noqa: N802 + def buildProtocol(self: Self, addr: IAddress) -> Protocol: # noqa: N802 """ Create an LDAPServer instance. diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 24b02ac..3d31c10 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -1,4 +1,5 @@ import time +from typing import Self from ldaptor.interfaces import IConnectedLDAPEntry, ILDAPEntry from ldaptor.protocols.ldap.distinguishedname import DistinguishedName @@ -14,7 +15,7 @@ class OAuthLDAPTree: def __init__( - self, + self: Self, domain: str, oauth_client: OAuthClient, *, @@ -41,11 +42,11 @@ def __init__( self.root_: OAuthLDAPEntry | None = None @property - def dn(self) -> DistinguishedName: + def dn(self: Self) -> DistinguishedName: return self.root.dn @property - def root(self) -> OAuthLDAPEntry: + def root(self: Self) -> OAuthLDAPEntry: """ Lazy-load the LDAP tree on request @@ -60,7 +61,7 @@ def root(self) -> OAuthLDAPEntry: raise ValueError(msg) return self.root_ - def refresh(self) -> None: + def refresh(self: Self) -> None: if ( not self.root_ or (time.monotonic() - self.last_update) > self.refresh_interval @@ -121,10 +122,10 @@ def refresh(self) -> None: log.msg("Finished building LDAP tree.") self.last_update = time.monotonic() - def __repr__(self) -> str: + def __repr__(self: Self) -> str: return f"{self.__class__.__name__} with backend {self.oauth_client.__class__.__name__}" - def lookup(self, dn: DistinguishedName | str) -> defer.Deferred[ILDAPEntry]: + def lookup(self: Self, dn: DistinguishedName | str) -> defer.Deferred[ILDAPEntry]: """ Lookup the referred to by dn. diff --git a/apricot/ldap/read_only_ldap_server.py b/apricot/ldap/read_only_ldap_server.py index b23ba88..5a3b710 100644 --- a/apricot/ldap/read_only_ldap_server.py +++ b/apricot/ldap/read_only_ldap_server.py @@ -1,4 +1,4 @@ -from typing import Callable +from typing import Callable, Self from ldaptor.interfaces import ILDAPEntry from ldaptor.protocols.ldap.ldaperrors import LDAPProtocolError @@ -24,12 +24,12 @@ class ReadOnlyLDAPServer(LDAPServer): - def __init__(self, *, debug: bool = False) -> None: + def __init__(self: Self, *, debug: bool = False) -> None: super().__init__() self.debug = debug def getRootDSE( # noqa: N802 - self, + self: Self, request: LDAPProtocolRequest, reply: Callable[[LDAPSearchResultEntry], None] | None, ) -> LDAPSearchResultDone: @@ -45,7 +45,7 @@ def getRootDSE( # noqa: N802 raise LDAPProtocolError(msg) from exc def handle_LDAPAddRequest( # noqa: N802 - self, + self: Self, request: LDAPAddRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, @@ -60,7 +60,7 @@ def handle_LDAPAddRequest( # noqa: N802 raise LDAPProtocolError(msg) def handle_LDAPBindRequest( # noqa: N802 - self, + self: Self, request: LDAPBindRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, @@ -77,7 +77,7 @@ def handle_LDAPBindRequest( # noqa: N802 raise LDAPProtocolError(msg) from exc def handle_LDAPCompareRequest( # noqa: N802 - self, + self: Self, request: LDAPCompareRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, @@ -94,7 +94,7 @@ def handle_LDAPCompareRequest( # noqa: N802 raise LDAPProtocolError(msg) from exc def handle_LDAPDelRequest( # noqa: N802 - self, + self: Self, request: LDAPDelRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, @@ -109,7 +109,7 @@ def handle_LDAPDelRequest( # noqa: N802 raise LDAPProtocolError(msg) def handle_LDAPExtendedRequest( # noqa: N802 - self, + self: Self, request: LDAPExtendedRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, @@ -126,7 +126,7 @@ def handle_LDAPExtendedRequest( # noqa: N802 raise LDAPProtocolError(msg) from exc def handle_LDAPModifyDNRequest( # noqa: N802 - self, + self: Self, request: LDAPModifyDNRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, @@ -141,7 +141,7 @@ def handle_LDAPModifyDNRequest( # noqa: N802 raise LDAPProtocolError(msg) def handle_LDAPModifyRequest( # noqa: N802 - self, + self: Self, request: LDAPModifyRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, @@ -156,7 +156,7 @@ def handle_LDAPModifyRequest( # noqa: N802 raise LDAPProtocolError(msg) def handle_LDAPSearchRequest( # noqa: N802 - self, + self: Self, request: LDAPSearchRequest, controls: list[LDAPControlTuple] | None, reply: Callable[[LDAPSearchResultEntry], None] | None, @@ -173,7 +173,7 @@ def handle_LDAPSearchRequest( # noqa: N802 raise LDAPProtocolError(msg) from exc def handle_LDAPUnbindRequest( # noqa: N802 - self, + self: Self, request: LDAPUnbindRequest, controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, diff --git a/apricot/models/ldap_attribute_adaptor.py b/apricot/models/ldap_attribute_adaptor.py index 40f986d..83e4be8 100644 --- a/apricot/models/ldap_attribute_adaptor.py +++ b/apricot/models/ldap_attribute_adaptor.py @@ -1,10 +1,10 @@ -from typing import Any +from typing import Any, Self from apricot.types import LDAPAttributeDict class LDAPAttributeAdaptor: - def __init__(self, attributes: dict[Any, Any]) -> None: + def __init__(self: Self, attributes: dict[Any, Any]) -> None: self.attributes = { str(k): list(map(str, v)) if isinstance(v, list) else [str(v)] for k, v in attributes.items() @@ -12,10 +12,10 @@ def __init__(self, attributes: dict[Any, Any]) -> None: } @property - def cn(self) -> str: + def cn(self: Self) -> str: """Return CN for this set of LDAP attributes""" return self.attributes["cn"][0] - def to_dict(self) -> LDAPAttributeDict: + def to_dict(self: Self) -> LDAPAttributeDict: """Convert the attributes to an LDAPAttributeDict""" return self.attributes diff --git a/apricot/models/ldap_group_of_names.py b/apricot/models/ldap_group_of_names.py index b1e077b..03f8337 100644 --- a/apricot/models/ldap_group_of_names.py +++ b/apricot/models/ldap_group_of_names.py @@ -1,3 +1,5 @@ +from typing import Self + from .named_ldap_class import NamedLDAPClass @@ -15,5 +17,5 @@ class LDAPGroupOfNames(NamedLDAPClass): description: str member: list[str] - def names(self) -> list[str]: + def names(self: Self) -> list[str]: return ["groupOfNames"] diff --git a/apricot/models/ldap_inetorgperson.py b/apricot/models/ldap_inetorgperson.py index 51e5cb5..9eed838 100644 --- a/apricot/models/ldap_inetorgperson.py +++ b/apricot/models/ldap_inetorgperson.py @@ -1,3 +1,5 @@ +from typing import Self + from .ldap_organizational_person import LDAPOrganizationalPerson @@ -19,5 +21,5 @@ class LDAPInetOrgPerson(LDAPOrganizationalPerson): mail: str | None = None telephoneNumber: str | None = None # noqa: N815 - def names(self) -> list[str]: + def names(self: Self) -> list[str]: return [*super().names(), "inetOrgPerson"] diff --git a/apricot/models/ldap_organizational_person.py b/apricot/models/ldap_organizational_person.py index 064ba5a..64d102a 100644 --- a/apricot/models/ldap_organizational_person.py +++ b/apricot/models/ldap_organizational_person.py @@ -1,3 +1,5 @@ +from typing import Self + from .ldap_person import LDAPPerson @@ -13,5 +15,5 @@ class LDAPOrganizationalPerson(LDAPPerson): description: str - def names(self) -> list[str]: + def names(self: Self) -> list[str]: return [*super().names(), "organizationalPerson"] diff --git a/apricot/models/ldap_person.py b/apricot/models/ldap_person.py index 0656897..1263c85 100644 --- a/apricot/models/ldap_person.py +++ b/apricot/models/ldap_person.py @@ -1,3 +1,5 @@ +from typing import Self + from .named_ldap_class import NamedLDAPClass @@ -14,5 +16,5 @@ class LDAPPerson(NamedLDAPClass): cn: str sn: str - def names(self) -> list[str]: + def names(self: Self) -> list[str]: return ["person"] diff --git a/apricot/models/ldap_posix_account.py b/apricot/models/ldap_posix_account.py index 84344c4..9510f68 100644 --- a/apricot/models/ldap_posix_account.py +++ b/apricot/models/ldap_posix_account.py @@ -1,4 +1,5 @@ import re +from typing import Self, Type from pydantic import StringConstraints, validator from typing_extensions import Annotated @@ -30,7 +31,7 @@ class LDAPPosixAccount(NamedLDAPClass): @validator("gidNumber") # type: ignore[misc] @classmethod - def validate_gid_number(cls, gid_number: int) -> int: + def validate_gid_number(cls: Type[Self], gid_number: int) -> int: """Avoid conflicts with existing users""" if not ID_MIN <= gid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." @@ -39,17 +40,17 @@ def validate_gid_number(cls, gid_number: int) -> int: @validator("homeDirectory") # type: ignore[misc] @classmethod - def validate_home_directory(cls, home_directory: str) -> str: + def validate_home_directory(cls: Type[Self], home_directory: str) -> str: return re.sub(r"\s+", "-", home_directory) @validator("uidNumber") # type: ignore[misc] @classmethod - def validate_uid_number(cls, uid_number: int) -> int: + def validate_uid_number(cls: Type[Self], uid_number: int) -> int: """Avoid conflicts with existing users""" if not ID_MIN <= uid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." raise ValueError(msg) return uid_number - def names(self) -> list[str]: + def names(self: Self) -> list[str]: return ["posixAccount"] diff --git a/apricot/models/ldap_posix_group.py b/apricot/models/ldap_posix_group.py index e926b49..436f2b5 100644 --- a/apricot/models/ldap_posix_group.py +++ b/apricot/models/ldap_posix_group.py @@ -1,3 +1,5 @@ +from typing import Self, Type + from pydantic import validator from .named_ldap_class import NamedLDAPClass @@ -22,12 +24,12 @@ class LDAPPosixGroup(NamedLDAPClass): @validator("gidNumber") # type: ignore[misc] @classmethod - def validate_gid_number(cls, gid_number: int) -> int: + def validate_gid_number(cls: Type[Self], gid_number: int) -> int: """Avoid conflicts with existing groups""" if not ID_MIN <= gid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." raise ValueError(msg) return gid_number - def names(self) -> list[str]: + def names(self: Self) -> list[str]: return ["posixGroup"] diff --git a/apricot/models/named_ldap_class.py b/apricot/models/named_ldap_class.py index 329e771..bf551c4 100644 --- a/apricot/models/named_ldap_class.py +++ b/apricot/models/named_ldap_class.py @@ -1,7 +1,9 @@ +from typing import Self + from pydantic import BaseModel class NamedLDAPClass(BaseModel): - def names(self) -> list[str]: + def names(self: Self) -> list[str]: """List of names for this LDAP object class""" return [] diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index 9e96750..837779f 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any, Self, cast from apricot.types import JSONDict @@ -11,11 +11,11 @@ class KeycloakClient(OAuthClient): max_rows = 100 def __init__( - self, + self: Self, keycloak_base_url: str, keycloak_realm: str, **kwargs: Any, - ): + ) -> None: self.base_url = keycloak_base_url self.realm = keycloak_realm @@ -30,10 +30,10 @@ def __init__( **kwargs, ) - def extract_token(self, json_response: JSONDict) -> str: + def extract_token(self: Self, json_response: JSONDict) -> str: return str(json_response["access_token"]) - def groups(self) -> list[JSONDict]: + def groups(self: Self) -> list[JSONDict]: output = [] try: group_data: list[JSONDict] = [] @@ -90,7 +90,7 @@ def groups(self) -> list[JSONDict]: pass return output - def users(self) -> list[JSONDict]: + def users(self: Self) -> list[JSONDict]: output = [] try: user_data: list[JSONDict] = [] diff --git a/apricot/oauth/microsoft_entra_client.py b/apricot/oauth/microsoft_entra_client.py index 847925e..256be29 100644 --- a/apricot/oauth/microsoft_entra_client.py +++ b/apricot/oauth/microsoft_entra_client.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any, Self, cast from twisted.python import log @@ -11,10 +11,10 @@ class MicrosoftEntraClient(OAuthClient): """OAuth client for the Microsoft Entra backend.""" def __init__( - self, + self: Self, entra_tenant_id: str, **kwargs: Any, - ): + ) -> None: redirect_uri = "urn:ietf:wg:oauth:2.0:oob" # this is the "no redirect" URL scopes = ["https://graph.microsoft.com/.default"] # this is the default scope token_url = ( @@ -28,10 +28,10 @@ def __init__( **kwargs, ) - def extract_token(self, json_response: JSONDict) -> str: + def extract_token(self: Self, json_response: JSONDict) -> str: return str(json_response["access_token"]) - def groups(self) -> list[JSONDict]: + def groups(self: Self) -> list[JSONDict]: output = [] queries = [ "createdDateTime", @@ -69,7 +69,7 @@ def groups(self) -> list[JSONDict]: log.msg(msg) return output - def users(self) -> list[JSONDict]: + def users(self: Self) -> list[JSONDict]: output = [] try: queries = [ diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 9a712c6..b269ae7 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -1,7 +1,7 @@ import os from abc import ABC, abstractmethod from http import HTTPStatus -from typing import Any +from typing import Any, Self import requests from oauthlib.oauth2 import ( @@ -21,7 +21,7 @@ class OAuthClient(ABC): """Base class for OAuth client talking to a generic backend.""" def __init__( - self, + self: Self, client_id: str, client_secret: str, debug: bool, # noqa: FBT001 @@ -72,7 +72,7 @@ def __init__( raise RuntimeError(msg) from exc @property - def bearer_token(self) -> str: + def bearer_token(self: Self) -> str: """ Return a bearer token, requesting a new one if necessary """ @@ -91,26 +91,31 @@ def bearer_token(self) -> str: raise RuntimeError(msg) from exc @abstractmethod - def extract_token(self, json_response: JSONDict) -> str: + def extract_token(self: Self, json_response: JSONDict) -> str: """ Extract the bearer token from an OAuth2Session JSON response """ @abstractmethod - def groups(self) -> list[JSONDict]: + def groups(self: Self) -> list[JSONDict]: """ Return JSON data about groups from the OAuth backend. This should be a list of JSON dictionaries where 'None' is used to signify missing values. """ @abstractmethod - def users(self) -> list[JSONDict]: + def users(self: Self) -> list[JSONDict]: """ Return JSON data about users from the OAuth backend. This should be a list of JSON dictionaries where 'None' is used to signify missing values. """ - def query(self, url: str, *, use_client_secret: bool = True) -> dict[str, Any]: + def query( + self: Self, + url: str, + *, + use_client_secret: bool = True, + ) -> dict[str, Any]: """ Make a query against the OAuth backend """ @@ -128,7 +133,12 @@ def query(self, url: str, *, use_client_secret: bool = True) -> dict[str, Any]: **kwargs, ) - def request(self, *args: Any, method: str = "GET", **kwargs: Any) -> dict[str, Any]: + def request( + self: Self, + *args: Any, + method: str = "GET", + **kwargs: Any, + ) -> dict[str, Any]: """ Make a request to the OAuth backend """ @@ -152,7 +162,7 @@ def query_(*args: Any, **kwargs: Any) -> requests.Response: return {} return result.json() # type: ignore - def verify(self, username: str, password: str) -> bool: + def verify(self: Self, username: str, password: str) -> bool: """ Verify username and password by attempting to authenticate against the OAuth backend. """ diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index f0dec89..ddc462b 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import Self from pydantic import ValidationError from twisted.python import log @@ -22,12 +23,12 @@ class OAuthDataAdaptor: """Adaptor for converting raw user and group data into LDAP format.""" def __init__( - self, + self: Self, domain: str, oauth_client: OAuthClient, *, enable_mirrored_groups: bool, - ): + ) -> None: """ Initialise an OAuthDataAdaptor @@ -50,27 +51,27 @@ def __init__( ) @property - def groups(self) -> list[LDAPAttributeAdaptor]: + def groups(self: Self) -> list[LDAPAttributeAdaptor]: """ Return a list of LDAPAttributeAdaptors representing validated group data. """ return self.validated_groups @property - def users(self) -> list[LDAPAttributeAdaptor]: + def users(self: Self) -> list[LDAPAttributeAdaptor]: """ Return a list of LDAPAttributeAdaptors representing validated user data. """ return self.validated_users - def _dn_from_group_cn(self, group_cn: str) -> str: + def _dn_from_group_cn(self: Self, group_cn: str) -> str: return f"CN={group_cn},OU=groups,{self.root_dn}" - def _dn_from_user_cn(self, user_cn: str) -> str: + def _dn_from_user_cn(self: Self, user_cn: str) -> str: return f"CN={user_cn},OU=users,{self.root_dn}" def _extract_attributes( - self, + self: Self, input_dict: JSONDict, required_classes: Sequence[type[NamedLDAPClass]], ) -> LDAPAttributeAdaptor: @@ -83,7 +84,7 @@ def _extract_attributes( return LDAPAttributeAdaptor(attributes) def _retrieve_entries( - self, + self: Self, ) -> tuple[ list[tuple[JSONDict, list[type[NamedLDAPClass]]]], list[tuple[JSONDict, list[type[NamedLDAPClass]]]], @@ -193,7 +194,7 @@ def _retrieve_entries( return (annotated_groups, annotated_users) def _validate_groups( - self, + self: Self, annotated_groups: list[tuple[JSONDict, list[type[NamedLDAPClass]]]], ) -> list[LDAPAttributeAdaptor]: """ @@ -220,7 +221,7 @@ def _validate_groups( return output def _validate_users( - self, + self: Self, annotated_users: list[tuple[JSONDict, list[type[NamedLDAPClass]]]], ) -> list[LDAPAttributeAdaptor]: """ diff --git a/apricot/patches/ldap_string.py b/apricot/patches/ldap_string.py index 41bfc45..8c1f90e 100644 --- a/apricot/patches/ldap_string.py +++ b/apricot/patches/ldap_string.py @@ -1,13 +1,13 @@ """Patch LDAPString to avoid TypeError when parsing LDAP filter strings""" -from typing import Any +from typing import Any, Self from ldaptor.protocols.pureldap import LDAPString old_init = LDAPString.__init__ -def patched_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore[no-untyped-def] +def patched_init(self: Self, *args: Any, **kwargs: Any) -> None: # type: ignore """Patch LDAPString init to store its value as 'str' not 'bytes'""" old_init(self, *args, **kwargs) if isinstance(self.value, bytes): diff --git a/pyproject.toml b/pyproject.toml index cd03575..1ff7d8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ select = [ # See https://beta.ruff.rs/docs/rules/ "A", # flake8-builtins "AIR", # Airflow + "ANN", # flake8-annotations "ARG", # flake8-unused-arguments "ASYNC", # flake8-async "B", # flake8-bugbear @@ -139,6 +140,9 @@ ignore = [ "PLR0913", # ignore too-many-arguments ] +[tool.ruff.lint.flake8-annotations] +allow-star-arg-any = true + [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" From c62fec9ef916278b524fa4f09fdef197f04a370c Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 30 May 2024 16:32:03 +0100 Subject: [PATCH 04/15] :memo: Add pydocstyle checks --- apricot/apricot_server.py | 21 ++++++++- apricot/cache/local_cache.py | 3 ++ apricot/cache/redis_cache.py | 11 +++-- apricot/cache/uid_cache.py | 39 ++++++----------- apricot/ldap/oauth_ldap_entry.py | 5 ++- apricot/ldap/oauth_ldap_server_factory.py | 8 ++-- apricot/ldap/oauth_ldap_tree.py | 11 +++-- apricot/ldap/read_only_ldap_server.py | 46 +++++++------------- apricot/models/ldap_attribute_adaptor.py | 10 ++++- apricot/models/ldap_group_of_names.py | 3 +- apricot/models/ldap_inetorgperson.py | 3 +- apricot/models/ldap_organizational_person.py | 3 +- apricot/models/ldap_person.py | 3 +- apricot/models/ldap_posix_account.py | 7 ++- apricot/models/ldap_posix_group.py | 5 +-- apricot/models/named_ldap_class.py | 4 +- apricot/models/overlay_memberof.py | 3 +- apricot/models/overlay_oauthentry.py | 3 +- apricot/oauth/keycloak_client.py | 5 +++ apricot/oauth/microsoft_entra_client.py | 4 ++ apricot/oauth/oauth_client.py | 38 ++++++++-------- apricot/oauth/oauth_data_adaptor.py | 25 +++-------- apricot/patches/ldap_string.py | 4 +- pyproject.toml | 15 +++++-- 24 files changed, 142 insertions(+), 137 deletions(-) diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 248fab7..56fa9f9 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -13,6 +13,8 @@ class ApricotServer: + """The Apricot server running via Twisted.""" + def __init__( self: Self, backend: OAuthBackend, @@ -32,6 +34,23 @@ def __init__( tls_private_key: str | None = None, **kwargs: Any, ) -> None: + """Initialise an ApricotServer. + + @param backend: An OAuth backend, + @param client_id: An OAuth client ID + @param client_secret: An OAuth client secret + @param domain: The OAuth domain + @param port: Port to expose LDAP on + @param background_refresh: Whether to refresh the LDAP tree in the background + @param debug: Enable debug output + @param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users + @param redis_host: Host for a Redis cache (if used) + @param redis_port: Port for a Redis cache (if used) + @param refresh_interval: Interval after which the LDAP information is stale + @param tls_port: Port to expose LDAPS on + @param tls_certificate: TLS certificate for LDAPS + @param tls_private_key: TLS private key for LDAPS + """ self.debug = debug # Log to stdout @@ -112,7 +131,7 @@ def __init__( self.reactor = cast(IReactorCore, reactor) def run(self: Self) -> None: - """Start the Twisted reactor""" + """Start the Twisted reactor.""" if self.debug: log.msg("Starting the Twisted reactor.") self.reactor.run() diff --git a/apricot/cache/local_cache.py b/apricot/cache/local_cache.py index 0bc05e7..b6835d9 100644 --- a/apricot/cache/local_cache.py +++ b/apricot/cache/local_cache.py @@ -4,7 +4,10 @@ class LocalCache(UidCache): + """Implementation of UidCache using an in-memory dictionary.""" + def __init__(self: Self) -> None: + """Initialise a RedisCache.""" self.cache: dict[str, int] = {} def get(self: Self, identifier: str) -> int | None: diff --git a/apricot/cache/redis_cache.py b/apricot/cache/redis_cache.py index 6bf78c5..016c2b7 100644 --- a/apricot/cache/redis_cache.py +++ b/apricot/cache/redis_cache.py @@ -6,16 +6,21 @@ class RedisCache(UidCache): + """Implementation of UidCache using a Redis backend.""" + def __init__(self: Self, redis_host: str, redis_port: int) -> None: + """Initialise a RedisCache. + + @param redis_host: Host for the Redis cache + @param redis_port: Port for the Redis cache + """ self.redis_host = redis_host self.redis_port = redis_port self.cache_: "redis.Redis[str]" | None = None @property def cache(self: Self) -> "redis.Redis[str]": - """ - Lazy-load the cache on request - """ + """Lazy-load the cache on request.""" if not self.cache_: self.cache_ = redis.Redis( host=self.redis_host, diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py index 52abd54..8c3cbb2 100644 --- a/apricot/cache/uid_cache.py +++ b/apricot/cache/uid_cache.py @@ -3,41 +3,33 @@ class UidCache(ABC): + """Abstract cache for storing UIDs.""" + @abstractmethod def get(self: Self, identifier: str) -> int | None: - """ - Get the UID for a given identifier, returning None if it does not exist - """ + """Get the UID for a given identifier, returning None if it does not exist.""" @abstractmethod def keys(self: Self) -> list[str]: - """ - Get list of cached keys - """ + """Get list of cached keys.""" @abstractmethod def set(self: Self, identifier: str, uid_value: int) -> None: - """ - Set the UID for a given identifier - """ + """Set the UID for a given identifier.""" @abstractmethod def values(self: Self, keys: list[str]) -> list[int]: - """ - Get list of cached values corresponding to requested keys - """ + """Get list of cached values corresponding to requested keys.""" def get_group_uid(self: Self, identifier: str) -> int: - """ - Get UID for a group, constructing one if necessary + """Get UID for a group, constructing one if necessary. @param identifier: Identifier for group needing a UID """ return self.get_uid(identifier, category="group", min_value=3000) def get_user_uid(self: Self, identifier: str) -> int: - """ - Get UID for a user, constructing one if necessary + """Get UID for a user, constructing one if necessary. @param identifier: Identifier for user needing a UID """ @@ -49,8 +41,7 @@ def get_uid( category: str, min_value: int | None = None, ) -> int: - """ - Get UID, constructing one if necessary. + """Get UID, constructing one if necessary. @param identifier: Identifier for object needing a UID @param category: Category the object belongs to @@ -65,8 +56,7 @@ def get_uid( return cast(int, self.get(identifier_)) def _get_max_uid(self: Self, category: str | None) -> int: - """ - Get maximum UID for a given category + """Get maximum UID for a given category. @param category: Category to check UIDs for """ @@ -78,8 +68,7 @@ def _get_max_uid(self: Self, category: str | None) -> int: return max(values) def overwrite_group_uid(self: Self, identifier: str, uid: int) -> None: - """ - Set UID for a group, overwriting the existing value if there is one + """Set UID for a group, overwriting the existing value if there is one. @param identifier: Identifier for group @param uid: Desired UID @@ -87,8 +76,7 @@ def overwrite_group_uid(self: Self, identifier: str, uid: int) -> None: return self.overwrite_uid(identifier, category="group", uid=uid) def overwrite_user_uid(self: Self, identifier: str, uid: int) -> None: - """ - Get UID for a user, constructing one if necessary + """Get UID for a user, constructing one if necessary. @param identifier: Identifier for user @param uid: Desired UID @@ -96,8 +84,7 @@ def overwrite_user_uid(self: Self, identifier: str, uid: int) -> None: return self.overwrite_uid(identifier, category="user", uid=uid) def overwrite_uid(self: Self, identifier: str, category: str, uid: int) -> None: - """ - Set UID, overwriting the existing one if necessary. + """Set UID, overwriting the existing one if necessary. @param identifier: Identifier for object @param category: Category the object belongs to diff --git a/apricot/ldap/oauth_ldap_entry.py b/apricot/ldap/oauth_ldap_entry.py index 1dd4554..2d746fe 100644 --- a/apricot/ldap/oauth_ldap_entry.py +++ b/apricot/ldap/oauth_ldap_entry.py @@ -16,6 +16,8 @@ class OAuthLDAPEntry(ReadOnlyInMemoryLDAPEntry): + """An LDAP entry that represents a view of an OAuth object.""" + dn: DistinguishedName attributes: LDAPAttributeDict @@ -25,8 +27,7 @@ def __init__( attributes: LDAPAttributeDict, oauth_client: OAuthClient | None = None, ) -> None: - """ - Initialize the object. + """Initialize the object. @param dn: Distinguished Name of the object @param attributes: Attributes of the object. diff --git a/apricot/ldap/oauth_ldap_server_factory.py b/apricot/ldap/oauth_ldap_server_factory.py index f86f129..0744c22 100644 --- a/apricot/ldap/oauth_ldap_server_factory.py +++ b/apricot/ldap/oauth_ldap_server_factory.py @@ -10,6 +10,8 @@ class OAuthLDAPServerFactory(ServerFactory): + """A Twisted ServerFactory that provides an LDAP tree.""" + def __init__( self: Self, domain: str, @@ -19,8 +21,7 @@ def __init__( enable_mirrored_groups: bool, refresh_interval: int, ) -> None: - """ - Initialise an OAuthLDAPServerFactory + """Initialise an OAuthLDAPServerFactory. @param background_refresh: Whether to refresh the LDAP tree in the background rather than on access @param domain: The root domain of the LDAP tree @@ -41,8 +42,7 @@ def __repr__(self: Self) -> str: return f"{self.__class__.__name__} using adaptor {self.adaptor}" def buildProtocol(self: Self, addr: IAddress) -> Protocol: # noqa: N802 - """ - Create an LDAPServer instance. + """Create an LDAPServer instance. This instance will use self.adaptor to produce LDAP entries. diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 3d31c10..6674019 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -13,6 +13,7 @@ @implementer(IConnectedLDAPEntry) class OAuthLDAPTree: + """An LDAP tree that represents a view of an OAuth directory.""" def __init__( self: Self, @@ -23,8 +24,7 @@ def __init__( enable_mirrored_groups: bool, refresh_interval: int, ) -> None: - """ - Initialise an OAuthLDAPTree + """Initialise an OAuthLDAPTree. @param background_refresh: Whether to refresh the LDAP tree in the background rather than on access @param domain: The root domain of the LDAP tree @@ -47,8 +47,7 @@ def dn(self: Self) -> DistinguishedName: @property def root(self: Self) -> OAuthLDAPEntry: - """ - Lazy-load the LDAP tree on request + """Lazy-load the LDAP tree on request. @return: An OAuthLDAPEntry for the tree @@ -62,6 +61,7 @@ def root(self: Self) -> OAuthLDAPEntry: return self.root_ def refresh(self: Self) -> None: + """Refresh the LDAP tree.""" if ( not self.root_ or (time.monotonic() - self.last_update) > self.refresh_interval @@ -126,8 +126,7 @@ def __repr__(self: Self) -> str: return f"{self.__class__.__name__} with backend {self.oauth_client.__class__.__name__}" def lookup(self: Self, dn: DistinguishedName | str) -> defer.Deferred[ILDAPEntry]: - """ - Lookup the referred to by dn. + """Lookup the referred to by dn. @return: A Deferred returning an ILDAPEntry. diff --git a/apricot/ldap/read_only_ldap_server.py b/apricot/ldap/read_only_ldap_server.py index 5a3b710..20bf665 100644 --- a/apricot/ldap/read_only_ldap_server.py +++ b/apricot/ldap/read_only_ldap_server.py @@ -24,7 +24,13 @@ class ReadOnlyLDAPServer(LDAPServer): + """A read-only LDAP server.""" + def __init__(self: Self, *, debug: bool = False) -> None: + """Initialise a ReadOnlyLDAPServer. + + @param debug: Enable debug output + """ super().__init__() self.debug = debug @@ -33,9 +39,7 @@ def getRootDSE( # noqa: N802 request: LDAPProtocolRequest, reply: Callable[[LDAPSearchResultEntry], None] | None, ) -> LDAPSearchResultDone: - """ - Handle an LDAP Root DSE request - """ + """Handle an LDAP Root DSE request.""" if self.debug: log.msg("Handling an LDAP Root DSE request.") try: @@ -50,9 +54,7 @@ def handle_LDAPAddRequest( # noqa: N802 controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Refuse to handle an LDAP add request - """ + """Refuse to handle an LDAP add request.""" if self.debug: log.msg("Handling an LDAP add request.") id((request, controls, reply)) # ignore unused arguments @@ -65,9 +67,7 @@ def handle_LDAPBindRequest( # noqa: N802 controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Handle an LDAP bind request - """ + """Handle an LDAP bind request.""" if self.debug: log.msg("Handling an LDAP bind request.") try: @@ -82,9 +82,7 @@ def handle_LDAPCompareRequest( # noqa: N802 controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Handle an LDAP compare request - """ + """Handle an LDAP compare request.""" if self.debug: log.msg("Handling an LDAP compare request.") try: @@ -99,9 +97,7 @@ def handle_LDAPDelRequest( # noqa: N802 controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Refuse to handle an LDAP delete request - """ + """Refuse to handle an LDAP delete request.""" if self.debug: log.msg("Handling an LDAP delete request.") id((request, controls, reply)) # ignore unused arguments @@ -114,9 +110,7 @@ def handle_LDAPExtendedRequest( # noqa: N802 controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Handle an LDAP extended request - """ + """Handle an LDAP extended request.""" if self.debug: log.msg("Handling an LDAP extended request.") try: @@ -131,9 +125,7 @@ def handle_LDAPModifyDNRequest( # noqa: N802 controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Refuse to handle an LDAP modify DN request - """ + """Refuse to handle an LDAP modify DN request.""" if self.debug: log.msg("Handling an LDAP modify DN request.") id((request, controls, reply)) # ignore unused arguments @@ -146,9 +138,7 @@ def handle_LDAPModifyRequest( # noqa: N802 controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Refuse to handle an LDAP modify request - """ + """Refuse to handle an LDAP modify request.""" if self.debug: log.msg("Handling an LDAP modify request.") id((request, controls, reply)) # ignore unused arguments @@ -161,9 +151,7 @@ def handle_LDAPSearchRequest( # noqa: N802 controls: list[LDAPControlTuple] | None, reply: Callable[[LDAPSearchResultEntry], None] | None, ) -> defer.Deferred[ILDAPEntry]: - """ - Handle an LDAP search request - """ + """Handle an LDAP search request.""" if self.debug: log.msg("Handling an LDAP search request.") try: @@ -178,9 +166,7 @@ def handle_LDAPUnbindRequest( # noqa: N802 controls: list[LDAPControlTuple] | None, reply: Callable[..., None] | None, ) -> None: - """ - Handle an LDAP unbind request - """ + """Handle an LDAP unbind request.""" if self.debug: log.msg("Handling an LDAP unbind request.") try: diff --git a/apricot/models/ldap_attribute_adaptor.py b/apricot/models/ldap_attribute_adaptor.py index 83e4be8..d6f8c00 100644 --- a/apricot/models/ldap_attribute_adaptor.py +++ b/apricot/models/ldap_attribute_adaptor.py @@ -4,7 +4,13 @@ class LDAPAttributeAdaptor: + """A class to convert attributes into LDAP format.""" + def __init__(self: Self, attributes: dict[Any, Any]) -> None: + """Initialise an LDAPAttributeAdaptor. + + @param attributes: A dictionary of attributes to be converted into str: list[str] + """ self.attributes = { str(k): list(map(str, v)) if isinstance(v, list) else [str(v)] for k, v in attributes.items() @@ -13,9 +19,9 @@ def __init__(self: Self, attributes: dict[Any, Any]) -> None: @property def cn(self: Self) -> str: - """Return CN for this set of LDAP attributes""" + """Return CN for this set of LDAP attributes.""" return self.attributes["cn"][0] def to_dict(self: Self) -> LDAPAttributeDict: - """Convert the attributes to an LDAPAttributeDict""" + """Convert the attributes to an LDAPAttributeDict.""" return self.attributes diff --git a/apricot/models/ldap_group_of_names.py b/apricot/models/ldap_group_of_names.py index 03f8337..b1374d6 100644 --- a/apricot/models/ldap_group_of_names.py +++ b/apricot/models/ldap_group_of_names.py @@ -4,8 +4,7 @@ class LDAPGroupOfNames(NamedLDAPClass): - """ - A group with named members + """A group with named members. OID: 2.5.6.9 Object class: Structural diff --git a/apricot/models/ldap_inetorgperson.py b/apricot/models/ldap_inetorgperson.py index 9eed838..80b82c8 100644 --- a/apricot/models/ldap_inetorgperson.py +++ b/apricot/models/ldap_inetorgperson.py @@ -4,8 +4,7 @@ class LDAPInetOrgPerson(LDAPOrganizationalPerson): - """ - A person belonging to an internet/intranet directory service + """A person belonging to an internet/intranet directory service. OID: 2.16.840.1.113730.3.2.2 Object class: Structural diff --git a/apricot/models/ldap_organizational_person.py b/apricot/models/ldap_organizational_person.py index 64d102a..0ea2a46 100644 --- a/apricot/models/ldap_organizational_person.py +++ b/apricot/models/ldap_organizational_person.py @@ -4,8 +4,7 @@ class LDAPOrganizationalPerson(LDAPPerson): - """ - A person belonging to an organisation + """A person belonging to an organisation. OID: 2.5.6.7 Object class: Structural diff --git a/apricot/models/ldap_person.py b/apricot/models/ldap_person.py index 1263c85..28e7539 100644 --- a/apricot/models/ldap_person.py +++ b/apricot/models/ldap_person.py @@ -4,8 +4,7 @@ class LDAPPerson(NamedLDAPClass): - """ - A named person + """A named person. OID: 2.5.6.6 Object class: Structural diff --git a/apricot/models/ldap_posix_account.py b/apricot/models/ldap_posix_account.py index 9510f68..d881179 100644 --- a/apricot/models/ldap_posix_account.py +++ b/apricot/models/ldap_posix_account.py @@ -11,8 +11,7 @@ class LDAPPosixAccount(NamedLDAPClass): - """ - Abstraction of an account with POSIX attributes + """Abstraction of an account with POSIX attributes. OID: 1.3.6.1.1.1.2.0 Object class: Auxiliary @@ -32,7 +31,7 @@ class LDAPPosixAccount(NamedLDAPClass): @validator("gidNumber") # type: ignore[misc] @classmethod def validate_gid_number(cls: Type[Self], gid_number: int) -> int: - """Avoid conflicts with existing users""" + """Avoid conflicts with existing users.""" if not ID_MIN <= gid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." raise ValueError(msg) @@ -46,7 +45,7 @@ def validate_home_directory(cls: Type[Self], home_directory: str) -> str: @validator("uidNumber") # type: ignore[misc] @classmethod def validate_uid_number(cls: Type[Self], uid_number: int) -> int: - """Avoid conflicts with existing users""" + """Avoid conflicts with existing users.""" if not ID_MIN <= uid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." raise ValueError(msg) diff --git a/apricot/models/ldap_posix_group.py b/apricot/models/ldap_posix_group.py index 436f2b5..459abf6 100644 --- a/apricot/models/ldap_posix_group.py +++ b/apricot/models/ldap_posix_group.py @@ -9,8 +9,7 @@ class LDAPPosixGroup(NamedLDAPClass): - """ - Abstraction of a group of accounts + """Abstraction of a group of accounts. OID: 1.3.6.1.1.1.2.2 Object class: Auxiliary @@ -25,7 +24,7 @@ class LDAPPosixGroup(NamedLDAPClass): @validator("gidNumber") # type: ignore[misc] @classmethod def validate_gid_number(cls: Type[Self], gid_number: int) -> int: - """Avoid conflicts with existing groups""" + """Avoid conflicts with existing groups.""" if not ID_MIN <= gid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." raise ValueError(msg) diff --git a/apricot/models/named_ldap_class.py b/apricot/models/named_ldap_class.py index bf551c4..57fc56e 100644 --- a/apricot/models/named_ldap_class.py +++ b/apricot/models/named_ldap_class.py @@ -4,6 +4,8 @@ class NamedLDAPClass(BaseModel): + """An LDAP class that has a name.""" + def names(self: Self) -> list[str]: - """List of names for this LDAP object class""" + """List of names for this LDAP object class.""" return [] diff --git a/apricot/models/overlay_memberof.py b/apricot/models/overlay_memberof.py index 3e78f71..c3fc414 100644 --- a/apricot/models/overlay_memberof.py +++ b/apricot/models/overlay_memberof.py @@ -2,8 +2,7 @@ class OverlayMemberOf(NamedLDAPClass): - """ - Abstraction for tracking the groups that an individual belongs to + """Abstraction for tracking the groups that an individual belongs to. OID: n/a Object class: Auxiliary diff --git a/apricot/models/overlay_oauthentry.py b/apricot/models/overlay_oauthentry.py index 3eabc37..2c188d9 100644 --- a/apricot/models/overlay_oauthentry.py +++ b/apricot/models/overlay_oauthentry.py @@ -2,8 +2,7 @@ class OverlayOAuthEntry(NamedLDAPClass): - """ - Abstraction for tracking an OAuth entry + """Abstraction for tracking an OAuth entry. OID: n/a Object class: Auxiliary diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index 837779f..bb06ed1 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -16,6 +16,11 @@ def __init__( keycloak_realm: str, **kwargs: Any, ) -> None: + """Initialise a KeycloakClient. + + @param keycloak_base_url: Base URL for Keycloak server + @param keycloak_realm: Realm for Keycloak server + """ self.base_url = keycloak_base_url self.realm = keycloak_realm diff --git a/apricot/oauth/microsoft_entra_client.py b/apricot/oauth/microsoft_entra_client.py index 256be29..d389a23 100644 --- a/apricot/oauth/microsoft_entra_client.py +++ b/apricot/oauth/microsoft_entra_client.py @@ -15,6 +15,10 @@ def __init__( entra_tenant_id: str, **kwargs: Any, ) -> None: + """Initialise a MicrosoftEntraClient. + + @param entra_tenant_id: Tenant ID for the Entra ID + """ redirect_uri = "urn:ietf:wg:oauth:2.0:oob" # this is the "no redirect" URL scopes = ["https://graph.microsoft.com/.default"] # this is the default scope token_url = ( diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index b269ae7..7466309 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -30,6 +30,16 @@ def __init__( token_url: str, uid_cache: UidCache, ) -> None: + """Initialise an OAuthClient. + + @param client_id: OAuth client ID + @param client_secret: OAuth client secret + @param debug: Enable debug output + @param redirect_uri: OAuth redirect URI + @param scopes: OAuth scopes + @param token_url: OAuth token URL + @param uid_cache: Cache for UIDs + """ # Set attributes self.bearer_token_: str | None = None self.client_secret = client_secret @@ -73,9 +83,7 @@ def __init__( @property def bearer_token(self: Self) -> str: - """ - Return a bearer token, requesting a new one if necessary - """ + """Return a bearer token, requesting a new one if necessary.""" try: if not self.bearer_token_: log.msg("Requesting a new authentication token from the OAuth backend.") @@ -92,21 +100,19 @@ def bearer_token(self: Self) -> str: @abstractmethod def extract_token(self: Self, json_response: JSONDict) -> str: - """ - Extract the bearer token from an OAuth2Session JSON response - """ + """Extract the bearer token from an OAuth2Session JSON response.""" @abstractmethod def groups(self: Self) -> list[JSONDict]: - """ - Return JSON data about groups from the OAuth backend. + """Return JSON data about groups from the OAuth backend. + This should be a list of JSON dictionaries where 'None' is used to signify missing values. """ @abstractmethod def users(self: Self) -> list[JSONDict]: - """ - Return JSON data about users from the OAuth backend. + """Return JSON data about users from the OAuth backend. + This should be a list of JSON dictionaries where 'None' is used to signify missing values. """ @@ -116,9 +122,7 @@ def query( *, use_client_secret: bool = True, ) -> dict[str, Any]: - """ - Make a query against the OAuth backend - """ + """Make a query against the OAuth backend.""" kwargs = ( { "client_id": self.session_application._client.client_id, @@ -139,9 +143,7 @@ def request( method: str = "GET", **kwargs: Any, ) -> dict[str, Any]: - """ - Make a request to the OAuth backend - """ + """Make a request to the OAuth backend.""" def query_(*args: Any, **kwargs: Any) -> requests.Response: return self.session_application.request( # type: ignore[no-any-return] @@ -163,9 +165,7 @@ def query_(*args: Any, **kwargs: Any) -> requests.Response: return result.json() # type: ignore def verify(self: Self, username: str, password: str) -> bool: - """ - Verify username and password by attempting to authenticate against the OAuth backend. - """ + """Verify username and password by attempting to authenticate against the OAuth backend.""" try: self.session_interactive.fetch_token( token_url=self.token_url, diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index ddc462b..b387438 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -29,8 +29,7 @@ def __init__( *, enable_mirrored_groups: bool, ) -> None: - """ - Initialise an OAuthDataAdaptor + """Initialise an OAuthDataAdaptor. @param domain: The root domain of the LDAP tree @param enable_mirrored_groups: Create a mirrored LDAP group-of-groups for each group-of-users @@ -52,16 +51,12 @@ def __init__( @property def groups(self: Self) -> list[LDAPAttributeAdaptor]: - """ - Return a list of LDAPAttributeAdaptors representing validated group data. - """ + """Return a list of LDAPAttributeAdaptors representing validated group data.""" return self.validated_groups @property def users(self: Self) -> list[LDAPAttributeAdaptor]: - """ - Return a list of LDAPAttributeAdaptors representing validated user data. - """ + """Return a list of LDAPAttributeAdaptors representing validated user data.""" return self.validated_users def _dn_from_group_cn(self: Self, group_cn: str) -> str: @@ -75,7 +70,7 @@ def _extract_attributes( input_dict: JSONDict, required_classes: Sequence[type[NamedLDAPClass]], ) -> LDAPAttributeAdaptor: - """Add appropriate LDAP class attributes""" + """Add appropriate LDAP class attributes.""" attributes = {"objectclass": ["top"]} for ldap_class in required_classes: model = ldap_class(**input_dict) @@ -89,9 +84,7 @@ def _retrieve_entries( list[tuple[JSONDict, list[type[NamedLDAPClass]]]], list[tuple[JSONDict, list[type[NamedLDAPClass]]]], ]: - """ - Obtain lists of users and groups, and construct necessary meta-entries. - """ + """Obtain lists of users and groups, and construct necessary meta-entries.""" # Get the initial set of users and groups oauth_groups = self.oauth_client.groups() oauth_users = self.oauth_client.users() @@ -197,9 +190,7 @@ def _validate_groups( self: Self, annotated_groups: list[tuple[JSONDict, list[type[NamedLDAPClass]]]], ) -> list[LDAPAttributeAdaptor]: - """ - Return a list of LDAPAttributeAdaptors representing validated group data. - """ + """Return a list of LDAPAttributeAdaptors representing validated group data.""" if self.debug: log.msg(f"Attempting to validate {len(annotated_groups)} groups.") output = [] @@ -224,9 +215,7 @@ def _validate_users( self: Self, annotated_users: list[tuple[JSONDict, list[type[NamedLDAPClass]]]], ) -> list[LDAPAttributeAdaptor]: - """ - Return a list of LDAPAttributeAdaptors representing validated user data. - """ + """Return a list of LDAPAttributeAdaptors representing validated user data.""" if self.debug: log.msg(f"Attempting to validate {len(annotated_users)} users.") output = [] diff --git a/apricot/patches/ldap_string.py b/apricot/patches/ldap_string.py index 8c1f90e..5754255 100644 --- a/apricot/patches/ldap_string.py +++ b/apricot/patches/ldap_string.py @@ -1,4 +1,4 @@ -"""Patch LDAPString to avoid TypeError when parsing LDAP filter strings""" +"""Patch LDAPString to avoid TypeError when parsing LDAP filter strings.""" from typing import Any, Self @@ -8,7 +8,7 @@ def patched_init(self: Self, *args: Any, **kwargs: Any) -> None: # type: ignore - """Patch LDAPString init to store its value as 'str' not 'bytes'""" + """Patch LDAPString init to store its value as 'str' not 'bytes'.""" old_init(self, *args, **kwargs) if isinstance(self.value, bytes): self.value = self.value.decode() diff --git a/pyproject.toml b/pyproject.toml index 1ff7d8f..038ebba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ select = [ "C", # complexity, mcabe and flake8-comprehensions "COM", # flake8-commas "CPY", # flake8-copyright + "D", # pydocstyle "DTZ", # flake8-datetimez "E", # pycodestyle errors "EM", # flake8-errmsg @@ -134,10 +135,16 @@ select = [ "YTT", # flake8-2020 ] ignore = [ - "E501", # ignore line length - "C901", # ignore complex-structure - "PLR0912", # ignore too-many-branches - "PLR0913", # ignore too-many-arguments + "D100", # missing-docstring-in-module + "D102", # missing-docstring-in-public-method + "D104", # missing-docstring-in-package + "D105", # missing-docstring-in-magic-method + "D203", # one-blank-line-before-class due to conflict with D211 + "D213", # multi-line-summary-second-line due to conflict with D212 + "E501", # line length + "C901", # complex-structure + "PLR0912", # too-many-branches + "PLR0913", # too-many-arguments ] [tool.ruff.lint.flake8-annotations] From a12b74329a88f797895053a7d35740c115bb1d6f Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 30 May 2024 16:42:41 +0100 Subject: [PATCH 05/15] :rotating_light: Add annotations check --- apricot/apricot_server.py | 2 + apricot/cache/local_cache.py | 2 + apricot/cache/redis_cache.py | 6 ++- apricot/cache/uid_cache.py | 2 + apricot/ldap/oauth_ldap_entry.py | 10 +++-- apricot/ldap/oauth_ldap_tree.py | 8 +++- apricot/ldap/read_only_ldap_server.py | 40 +++++++++++--------- apricot/models/ldap_attribute_adaptor.py | 7 +++- apricot/models/ldap_group_of_names.py | 2 + apricot/models/ldap_inetorgperson.py | 2 + apricot/models/ldap_organizational_person.py | 2 + apricot/models/ldap_person.py | 2 + apricot/models/ldap_posix_account.py | 10 +++-- apricot/models/ldap_posix_group.py | 6 ++- apricot/models/named_ldap_class.py | 2 + apricot/models/overlay_memberof.py | 2 + apricot/models/overlay_oauthentry.py | 2 + apricot/oauth/keycloak_client.py | 2 + apricot/oauth/microsoft_entra_client.py | 2 + apricot/oauth/oauth_client.py | 9 +++-- apricot/oauth/oauth_data_adaptor.py | 13 +++++-- pyproject.toml | 1 + 22 files changed, 93 insertions(+), 41 deletions(-) diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index 56fa9f9..d369f9d 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect import sys from typing import Any, Self, cast diff --git a/apricot/cache/local_cache.py b/apricot/cache/local_cache.py index b6835d9..b148453 100644 --- a/apricot/cache/local_cache.py +++ b/apricot/cache/local_cache.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Self from .uid_cache import UidCache diff --git a/apricot/cache/redis_cache.py b/apricot/cache/redis_cache.py index 016c2b7..c337f8e 100644 --- a/apricot/cache/redis_cache.py +++ b/apricot/cache/redis_cache.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Self, cast import redis @@ -16,10 +18,10 @@ def __init__(self: Self, redis_host: str, redis_port: int) -> None: """ self.redis_host = redis_host self.redis_port = redis_port - self.cache_: "redis.Redis[str]" | None = None + self.cache_: redis.Redis[str] | None = None @property - def cache(self: Self) -> "redis.Redis[str]": + def cache(self: Self) -> redis.Redis[str]: """Lazy-load the cache on request.""" if not self.cache_: self.cache_ = redis.Redis( diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py index 8c3cbb2..a96d951 100644 --- a/apricot/cache/uid_cache.py +++ b/apricot/cache/uid_cache.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import Self, cast diff --git a/apricot/ldap/oauth_ldap_entry.py b/apricot/ldap/oauth_ldap_entry.py index 2d746fe..8e31b5a 100644 --- a/apricot/ldap/oauth_ldap_entry.py +++ b/apricot/ldap/oauth_ldap_entry.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Self, cast from ldaptor.inmemory import ReadOnlyInMemoryLDAPEntry @@ -66,7 +68,7 @@ def add_child( self: Self, rdn: RelativeDistinguishedName | str, attributes: LDAPAttributeDict, - ) -> "OAuthLDAPEntry": + ) -> OAuthLDAPEntry: if isinstance(rdn, str): rdn = RelativeDistinguishedName(stringValue=rdn) try: @@ -76,8 +78,8 @@ def add_child( output = self._children[rdn.getText()] return cast(OAuthLDAPEntry, output) - def bind(self: Self, password: bytes) -> defer.Deferred["OAuthLDAPEntry"]: - def _bind(password: bytes) -> "OAuthLDAPEntry": + def bind(self: Self, password: bytes) -> defer.Deferred[OAuthLDAPEntry]: + def _bind(password: bytes) -> OAuthLDAPEntry: oauth_username = next(iter(self.get("oauth_username", "unknown"))) s_password = password.decode("utf-8") if self.oauth_client.verify(username=oauth_username, password=s_password): @@ -87,5 +89,5 @@ def _bind(password: bytes) -> "OAuthLDAPEntry": return defer.maybeDeferred(_bind, password) - def list_children(self: Self) -> "list[OAuthLDAPEntry]": + def list_children(self: Self) -> list[OAuthLDAPEntry]: return [cast(OAuthLDAPEntry, entry) for entry in self._children.values()] diff --git a/apricot/ldap/oauth_ldap_tree.py b/apricot/ldap/oauth_ldap_tree.py index 6674019..f27cb53 100644 --- a/apricot/ldap/oauth_ldap_tree.py +++ b/apricot/ldap/oauth_ldap_tree.py @@ -1,15 +1,19 @@ +from __future__ import annotations + import time -from typing import Self +from typing import TYPE_CHECKING, Self from ldaptor.interfaces import IConnectedLDAPEntry, ILDAPEntry from ldaptor.protocols.ldap.distinguishedname import DistinguishedName -from twisted.internet import defer from twisted.python import log from zope.interface import implementer from apricot.ldap.oauth_ldap_entry import OAuthLDAPEntry from apricot.oauth import OAuthClient, OAuthDataAdaptor +if TYPE_CHECKING: + from twisted.internet import defer + @implementer(IConnectedLDAPEntry) class OAuthLDAPTree: diff --git a/apricot/ldap/read_only_ldap_server.py b/apricot/ldap/read_only_ldap_server.py index 20bf665..48b965c 100644 --- a/apricot/ldap/read_only_ldap_server.py +++ b/apricot/ldap/read_only_ldap_server.py @@ -1,26 +1,30 @@ -from typing import Callable, Self +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Self -from ldaptor.interfaces import ILDAPEntry from ldaptor.protocols.ldap.ldaperrors import LDAPProtocolError from ldaptor.protocols.ldap.ldapserver import LDAPServer -from ldaptor.protocols.pureldap import ( - LDAPAddRequest, - LDAPBindRequest, - LDAPCompareRequest, - LDAPDelRequest, - LDAPExtendedRequest, - LDAPModifyDNRequest, - LDAPModifyRequest, - LDAPProtocolRequest, - LDAPSearchRequest, - LDAPSearchResultDone, - LDAPSearchResultEntry, - LDAPUnbindRequest, -) -from twisted.internet import defer from twisted.python import log -from apricot.oauth import LDAPControlTuple +if TYPE_CHECKING: + from ldaptor.interfaces import ILDAPEntry + from ldaptor.protocols.pureldap import ( + LDAPAddRequest, + LDAPBindRequest, + LDAPCompareRequest, + LDAPDelRequest, + LDAPExtendedRequest, + LDAPModifyDNRequest, + LDAPModifyRequest, + LDAPProtocolRequest, + LDAPSearchRequest, + LDAPSearchResultDone, + LDAPSearchResultEntry, + LDAPUnbindRequest, + ) + from twisted.internet import defer + + from apricot.oauth import LDAPControlTuple class ReadOnlyLDAPServer(LDAPServer): diff --git a/apricot/models/ldap_attribute_adaptor.py b/apricot/models/ldap_attribute_adaptor.py index d6f8c00..5d2412a 100644 --- a/apricot/models/ldap_attribute_adaptor.py +++ b/apricot/models/ldap_attribute_adaptor.py @@ -1,6 +1,9 @@ -from typing import Any, Self +from __future__ import annotations -from apricot.types import LDAPAttributeDict +from typing import TYPE_CHECKING, Any, Self + +if TYPE_CHECKING: + from apricot.types import LDAPAttributeDict class LDAPAttributeAdaptor: diff --git a/apricot/models/ldap_group_of_names.py b/apricot/models/ldap_group_of_names.py index b1374d6..85bc8b8 100644 --- a/apricot/models/ldap_group_of_names.py +++ b/apricot/models/ldap_group_of_names.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Self from .named_ldap_class import NamedLDAPClass diff --git a/apricot/models/ldap_inetorgperson.py b/apricot/models/ldap_inetorgperson.py index 80b82c8..461b218 100644 --- a/apricot/models/ldap_inetorgperson.py +++ b/apricot/models/ldap_inetorgperson.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Self from .ldap_organizational_person import LDAPOrganizationalPerson diff --git a/apricot/models/ldap_organizational_person.py b/apricot/models/ldap_organizational_person.py index 0ea2a46..8f1687b 100644 --- a/apricot/models/ldap_organizational_person.py +++ b/apricot/models/ldap_organizational_person.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Self from .ldap_person import LDAPPerson diff --git a/apricot/models/ldap_person.py b/apricot/models/ldap_person.py index 28e7539..b82ebd4 100644 --- a/apricot/models/ldap_person.py +++ b/apricot/models/ldap_person.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Self from .named_ldap_class import NamedLDAPClass diff --git a/apricot/models/ldap_posix_account.py b/apricot/models/ldap_posix_account.py index d881179..65bb994 100644 --- a/apricot/models/ldap_posix_account.py +++ b/apricot/models/ldap_posix_account.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import re -from typing import Self, Type +from typing import Self from pydantic import StringConstraints, validator from typing_extensions import Annotated @@ -30,7 +32,7 @@ class LDAPPosixAccount(NamedLDAPClass): @validator("gidNumber") # type: ignore[misc] @classmethod - def validate_gid_number(cls: Type[Self], gid_number: int) -> int: + def validate_gid_number(cls: type[Self], gid_number: int) -> int: """Avoid conflicts with existing users.""" if not ID_MIN <= gid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." @@ -39,12 +41,12 @@ def validate_gid_number(cls: Type[Self], gid_number: int) -> int: @validator("homeDirectory") # type: ignore[misc] @classmethod - def validate_home_directory(cls: Type[Self], home_directory: str) -> str: + def validate_home_directory(cls: type[Self], home_directory: str) -> str: return re.sub(r"\s+", "-", home_directory) @validator("uidNumber") # type: ignore[misc] @classmethod - def validate_uid_number(cls: Type[Self], uid_number: int) -> int: + def validate_uid_number(cls: type[Self], uid_number: int) -> int: """Avoid conflicts with existing users.""" if not ID_MIN <= uid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." diff --git a/apricot/models/ldap_posix_group.py b/apricot/models/ldap_posix_group.py index 459abf6..1c83344 100644 --- a/apricot/models/ldap_posix_group.py +++ b/apricot/models/ldap_posix_group.py @@ -1,4 +1,6 @@ -from typing import Self, Type +from __future__ import annotations + +from typing import Self from pydantic import validator @@ -23,7 +25,7 @@ class LDAPPosixGroup(NamedLDAPClass): @validator("gidNumber") # type: ignore[misc] @classmethod - def validate_gid_number(cls: Type[Self], gid_number: int) -> int: + def validate_gid_number(cls: type[Self], gid_number: int) -> int: """Avoid conflicts with existing groups.""" if not ID_MIN <= gid_number <= ID_MAX: msg = f"Must be in range {ID_MIN} to {ID_MAX}." diff --git a/apricot/models/named_ldap_class.py b/apricot/models/named_ldap_class.py index 57fc56e..89ff32a 100644 --- a/apricot/models/named_ldap_class.py +++ b/apricot/models/named_ldap_class.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Self from pydantic import BaseModel diff --git a/apricot/models/overlay_memberof.py b/apricot/models/overlay_memberof.py index c3fc414..4209afe 100644 --- a/apricot/models/overlay_memberof.py +++ b/apricot/models/overlay_memberof.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .named_ldap_class import NamedLDAPClass diff --git a/apricot/models/overlay_oauthentry.py b/apricot/models/overlay_oauthentry.py index 2c188d9..ba8c879 100644 --- a/apricot/models/overlay_oauthentry.py +++ b/apricot/models/overlay_oauthentry.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from .named_ldap_class import NamedLDAPClass diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index bb06ed1..5405755 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Self, cast from apricot.types import JSONDict diff --git a/apricot/oauth/microsoft_entra_client.py b/apricot/oauth/microsoft_entra_client.py index d389a23..40e61e6 100644 --- a/apricot/oauth/microsoft_entra_client.py +++ b/apricot/oauth/microsoft_entra_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any, Self, cast from twisted.python import log diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 7466309..7817fde 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import os from abc import ABC, abstractmethod from http import HTTPStatus -from typing import Any, Self +from typing import TYPE_CHECKING, Any, Self import requests from oauthlib.oauth2 import ( @@ -13,8 +15,9 @@ from requests_oauthlib import OAuth2Session from twisted.python import log -from apricot.cache import UidCache -from apricot.types import JSONDict +if TYPE_CHECKING: + from apricot.cache import UidCache + from apricot.types import JSONDict class OAuthClient(ABC): diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index b387438..ad153d4 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -1,5 +1,6 @@ -from collections.abc import Sequence -from typing import Self +from __future__ import annotations + +from typing import TYPE_CHECKING, Self from pydantic import ValidationError from twisted.python import log @@ -14,9 +15,13 @@ OverlayMemberOf, OverlayOAuthEntry, ) -from apricot.types import JSONDict -from .oauth_client import OAuthClient +if TYPE_CHECKING: + from collections.abc import Sequence + + from apricot.types import JSONDict + + from .oauth_client import OAuthClient class OAuthDataAdaptor: diff --git a/pyproject.toml b/pyproject.toml index 038ebba..8dc1ce3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,7 @@ select = [ "ERA", # eradicate "EXE", # flake8-executable "F", # pyflakes + "FA", # flake8-future-annotations "FBT", # flake8-boolean-trap "FIX", # flake8-fixme "FLY", # flynt From 1cea9913b0a8673d07e53a508961b7b00d01fe5c Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 30 May 2024 16:46:03 +0100 Subject: [PATCH 06/15] :rotating_light: Added pygrep-hooks --- apricot/apricot_server.py | 2 +- apricot/oauth/oauth_client.py | 2 +- apricot/patches/ldap_string.py | 2 +- pyproject.toml | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/apricot/apricot_server.py b/apricot/apricot_server.py index d369f9d..577ff88 100644 --- a/apricot/apricot_server.py +++ b/apricot/apricot_server.py @@ -75,7 +75,7 @@ def __init__( log.msg(f"Creating an OAuthClient for {backend}.") oauth_backend = OAuthClientMap[backend] oauth_backend_args = inspect.getfullargspec( - oauth_backend.__init__, # type: ignore + oauth_backend.__init__, # type: ignore[misc] ).args oauth_client = oauth_backend( client_id=client_id, diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 7817fde..28b5508 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -165,7 +165,7 @@ def query_(*args: Any, **kwargs: Any) -> requests.Response: result = query_(*args, **kwargs) if result.status_code == HTTPStatus.NO_CONTENT: return {} - return result.json() # type: ignore + return result.json() # type: ignore[no-any-return] def verify(self: Self, username: str, password: str) -> bool: """Verify username and password by attempting to authenticate against the OAuth backend.""" diff --git a/apricot/patches/ldap_string.py b/apricot/patches/ldap_string.py index 5754255..8d0f204 100644 --- a/apricot/patches/ldap_string.py +++ b/apricot/patches/ldap_string.py @@ -7,7 +7,7 @@ old_init = LDAPString.__init__ -def patched_init(self: Self, *args: Any, **kwargs: Any) -> None: # type: ignore +def patched_init(self: Self, *args: Any, **kwargs: Any) -> None: # type: ignore[misc] """Patch LDAPString init to store its value as 'str' not 'bytes'.""" old_init(self, *args, **kwargs) if isinstance(self.value, bytes): diff --git a/pyproject.toml b/pyproject.toml index 8dc1ce3..32a14ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ select = [ "N", # pep8-naming "NPY", # numpy-specific-rules "PD", # pandas-vet + "PGH", # pygrep-hooks "PIE", # flake8-pie "PLC", # pylint convention "PLE", # pylint error From 34fcd76a8f0bfb676c7574741b12ee8ad4201e50 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 30 May 2024 16:48:49 +0100 Subject: [PATCH 07/15] :rotating_light: Add flake8-simplify --- apricot/cache/local_cache.py | 2 +- apricot/cache/redis_cache.py | 2 +- apricot/ldap/oauth_ldap_entry.py | 5 ++--- apricot/oauth/oauth_data_adaptor.py | 4 ++-- pyproject.toml | 1 + 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/apricot/cache/local_cache.py b/apricot/cache/local_cache.py index b148453..65b4cc5 100644 --- a/apricot/cache/local_cache.py +++ b/apricot/cache/local_cache.py @@ -16,7 +16,7 @@ def get(self: Self, identifier: str) -> int | None: return self.cache.get(identifier, None) def keys(self: Self) -> list[str]: - return [str(k) for k in self.cache.keys()] + return [str(k) for k in self.cache] def set(self: Self, identifier: str, uid_value: int) -> None: self.cache[identifier] = uid_value diff --git a/apricot/cache/redis_cache.py b/apricot/cache/redis_cache.py index c337f8e..47ccc97 100644 --- a/apricot/cache/redis_cache.py +++ b/apricot/cache/redis_cache.py @@ -36,7 +36,7 @@ def get(self: Self, identifier: str) -> int | None: return None if value is None else int(value) def keys(self: Self) -> list[str]: - return [str(k) for k in self.cache.keys()] + return [str(k) for k in self.cache.keys()] # noqa: SIM118 def set(self: Self, identifier: str, uid_value: int) -> None: self.cache.set(identifier, uid_value) diff --git a/apricot/ldap/oauth_ldap_entry.py b/apricot/ldap/oauth_ldap_entry.py index 8e31b5a..0058bf7 100644 --- a/apricot/ldap/oauth_ldap_entry.py +++ b/apricot/ldap/oauth_ldap_entry.py @@ -56,9 +56,8 @@ def __str__(self: Self) -> str: @property def oauth_client(self: Self) -> OAuthClient: - if not self.oauth_client_: - if hasattr(self._parent, "oauth_client"): - self.oauth_client_ = self._parent.oauth_client + if not self.oauth_client_ and hasattr(self._parent, "oauth_client"): + self.oauth_client_ = self._parent.oauth_client if not isinstance(self.oauth_client_, OAuthClient): msg = f"OAuthClient is of incorrect type {type(self.oauth_client_)}" raise TypeError(msg) diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index ad153d4..4b4a251 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -208,7 +208,7 @@ def _validate_groups( ), ) except ValidationError as exc: - name = group_dict["cn"] if "cn" in group_dict else "unknown" + name = group_dict.get("cn", "unknown") log.msg(f"Validation failed for group '{name}'.") for error in exc.errors(): log.msg( @@ -233,7 +233,7 @@ def _validate_users( ), ) except ValidationError as exc: - name = user_dict["cn"] if "cn" in user_dict else "unknown" + name = user_dict.get("cn", "unknown") log.msg(f"Validation failed for user '{name}'.") for error in exc.errors(): log.msg( diff --git a/pyproject.toml b/pyproject.toml index 32a14ad..a7296d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ select = [ "RSE", # flake8-raise "RUF", # ruff rules "S", # flake8-bandit + "SIM", # flake8-simplify "SLOT", # flake8-slot "T", # flake8-debugger and flake8-print "TCH", # flake8-type-checking From ee017fac98c9f2cdee64740e38203fab19adafa8 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 30 May 2024 16:51:50 +0100 Subject: [PATCH 08/15] :rotating_light: Added tryceratops --- apricot/oauth/oauth_client.py | 8 +++++--- pyproject.toml | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index 28b5508..f480c3d 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -96,10 +96,11 @@ def bearer_token(self: Self) -> str: client_secret=self.client_secret, ) self.bearer_token_ = self.extract_token(json_response) - return self.bearer_token_ except Exception as exc: msg = f"Failed to fetch bearer token from OAuth endpoint.\n{exc!s}" raise RuntimeError(msg) from exc + else: + return self.bearer_token_ @abstractmethod def extract_token(self: Self, json_response: JSONDict) -> str: @@ -177,7 +178,8 @@ def verify(self: Self, username: str, password: str) -> bool: client_id=self.session_interactive._client.client_id, client_secret=self.client_secret, ) - return True except InvalidGrantError as exc: log.msg(f"Authentication failed.\n{exc}") - return False + return False + else: + return True diff --git a/pyproject.toml b/pyproject.toml index a7296d3..be5bb71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,6 +133,7 @@ select = [ "TD", # flake8-todos "TID", # flake8-tidy-imports "TRIO", # flake8-trio + "TRY", # tryceratops "UP", # pyupgrade "W", # pycodestyle warnings "YTT", # flake8-2020 From d81af3e1c7d33e0598d2ccf3bc118abb1db6ca5a Mon Sep 17 00:00:00 2001 From: James Robinson Date: Thu, 30 May 2024 16:54:58 +0100 Subject: [PATCH 09/15] :rotating_light: Extend linting to cover run.py --- pyproject.toml | 10 ++++---- run.py | 66 +++++++++++++++++++++++++++++++++++++------------- 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index be5bb71..f31883e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,15 +58,15 @@ dependencies = [ ] [tool.hatch.envs.lint.scripts] -typing = "mypy {args:apricot}" +typing = "mypy {args:apricot} run.py" style = [ - "ruff check {args:apricot}", - "black --check --diff {args:apricot}", + "ruff check {args:apricot} run.py", + "black --check --diff {args:apricot} run.py", ] fmt = [ - "black {args:apricot}", - "ruff check --fix {args:apricot}", + "black {args:apricot} run.py", + "ruff check --fix {args:apricot} run.py", "style", ] all = [ diff --git a/run.py b/run.py index 98f4831..cd1b1a5 100644 --- a/run.py +++ b/run.py @@ -12,17 +12,30 @@ ) # Common options needed for all backends parser.add_argument( - "-b", "--backend", type=OAuthBackend, help="Which OAuth backend to use." + "-b", + "--backend", + type=OAuthBackend, + help="Which OAuth backend to use.", ) parser.add_argument( - "-d", "--domain", type=str, help="Which domain users belong to." + "-d", + "--domain", + type=str, + help="Which domain users belong to.", ) parser.add_argument("-i", "--client-id", type=str, help="OAuth client ID.") parser.add_argument( - "-p", "--port", type=int, default=1389, help="Port to run on." + "-p", + "--port", + type=int, + default=1389, + help="Port to run on.", ) parser.add_argument( - "-s", "--client-secret", type=str, help="OAuth client secret." + "-s", + "--client-secret", + type=str, + help="OAuth client secret.", ) parser.add_argument( "--background-refresh", @@ -31,7 +44,9 @@ help="Refresh in the background instead of as needed per request", ) parser.add_argument( - "--debug", action="store_true", help="Enable debug logging." + "--debug", + action="store_true", + help="Enable debug logging.", ) parser.add_argument( "--disable-mirrored-groups", @@ -50,50 +65,67 @@ # Options for Microsoft Entra backend entra_group = parser.add_argument_group("Microsoft Entra") entra_group.add_argument( - "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID." + "--entra-tenant-id", + type=str, + help="Microsoft Entra tenant ID.", ) # Options for Keycloak backend keycloak_group = parser.add_argument_group("Keycloak") keycloak_group.add_argument( - "--keycloak-base-url", type=str, help="Keycloak base URL." + "--keycloak-base-url", + type=str, + help="Keycloak base URL.", ) keycloak_group.add_argument( - "--keycloak-realm", type=str, help="Keycloak Realm." + "--keycloak-realm", + type=str, + help="Keycloak Realm.", ) # Options for Redis cache redis_group = parser.add_argument_group("Redis") redis_group.add_argument( - "--redis-host", type=str, help="Host for Redis server." + "--redis-host", + type=str, + help="Host for Redis server.", ) redis_group.add_argument( - "--redis-port", type=int, help="Port for Redis server." + "--redis-port", + type=int, + help="Port for Redis server.", ) # Options for TLS tls_group = parser.add_argument_group("TLS") tls_group.add_argument( - "--tls-certificate", type=str, help="Location of TLS certificate (pem)." + "--tls-certificate", + type=str, + help="Location of TLS certificate (pem).", ) tls_group.add_argument( - "--tls-port", type=int, default=1636, help="Port to run on with encryption." + "--tls-port", + type=int, + default=1636, + help="Port to run on with encryption.", ) tls_group.add_argument( - "--tls-private-key", type=str, help="Location of TLS private key (pem)." + "--tls-private-key", + type=str, + help="Location of TLS private key (pem).", ) # Parse arguments args = parser.parse_args() # Create the Apricot server reactor = ApricotServer(**vars(args)) - except Exception as exc: + except Exception as exc: # noqa: BLE001 msg = f"Unable to initialise Apricot server.\n{exc}" - print(msg) + print(msg) # noqa: T201 sys.exit(1) # Run the Apricot server try: reactor.run() - except Exception as exc: + except Exception as exc: # noqa: BLE001 msg = f"Apricot server encountered a runtime problem.\n{exc}" - print(msg) + print(msg) # noqa: T201 sys.exit(1) From e297e85fd9b6142eca8c4c86edb26efd6c155e38 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Fri, 31 May 2024 11:07:52 +0100 Subject: [PATCH 10/15] :rotating_light: Enable ruff preview features --- apricot/__init__.py | 2 +- apricot/cache/uid_cache.py | 2 +- apricot/oauth/keycloak_client.py | 7 ++++--- apricot/oauth/microsoft_entra_client.py | 13 +++++++------ pyproject.toml | 7 ++++--- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/apricot/__init__.py b/apricot/__init__.py index 0a47852..67a4c7d 100644 --- a/apricot/__init__.py +++ b/apricot/__init__.py @@ -3,7 +3,7 @@ from .patches import LDAPString # noqa: F401 __all__ = [ + "ApricotServer", "__version__", "__version_info__", - "ApricotServer", ] diff --git a/apricot/cache/uid_cache.py b/apricot/cache/uid_cache.py index a96d951..82621f3 100644 --- a/apricot/cache/uid_cache.py +++ b/apricot/cache/uid_cache.py @@ -52,7 +52,7 @@ def get_uid( identifier_ = f"{category}-{identifier}" uid = self.get(identifier_) if not uid: - min_value = min_value if min_value else 0 + min_value = min_value or 0 next_uid = max(self._get_max_uid(category) + 1, min_value) self.set(identifier_, next_uid) return cast(int, self.get(identifier_)) diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index 5405755..d859231 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator from typing import Any, Self, cast from apricot.types import JSONDict @@ -128,7 +129,7 @@ def users(self: Self) -> list[JSONDict]: # Read user attributes for user_dict in sorted( user_data, - key=lambda user: user["createdTimestamp"], + key=operator.itemgetter("createdTimestamp"), ): if not user_dict["attributes"]["uid"]: user_dict["attributes"]["uid"] = [ @@ -154,10 +155,10 @@ def users(self: Self) -> list[JSONDict]: attributes["mail"] = user_dict.get("email") attributes["description"] = "" attributes["gidNumber"] = user_dict["attributes"]["uid"][0] - attributes["givenName"] = first_name if first_name else "" + attributes["givenName"] = first_name or "" attributes["homeDirectory"] = f"/home/{username}" if username else None attributes["oauth_id"] = user_dict.get("id", None) - attributes["sn"] = last_name if last_name else "" + attributes["sn"] = last_name or "" attributes["uidNumber"] = user_dict["attributes"]["uid"][0] output.append(attributes) except KeyError: diff --git a/apricot/oauth/microsoft_entra_client.py b/apricot/oauth/microsoft_entra_client.py index 40e61e6..2c65375 100644 --- a/apricot/oauth/microsoft_entra_client.py +++ b/apricot/oauth/microsoft_entra_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +import operator from typing import Any, Self, cast from twisted.python import log @@ -49,7 +50,7 @@ def groups(self: Self) -> list[JSONDict]: ) for group_dict in cast( list[JSONDict], - sorted(group_data["value"], key=lambda group: group["createdDateTime"]), + sorted(group_data["value"], key=operator.itemgetter("createdDateTime")), ): try: group_uid = self.uid_cache.get_group_uid(group_dict["id"]) @@ -91,7 +92,7 @@ def users(self: Self) -> list[JSONDict]: ) for user_dict in cast( list[JSONDict], - sorted(user_data["value"], key=lambda user: user["createdDateTime"]), + sorted(user_data["value"], key=operator.itemgetter("createdDateTime")), ): # Get user attributes given_name = user_dict.get("givenName", None) @@ -99,17 +100,17 @@ def users(self: Self) -> list[JSONDict]: uid, domain = str(user_dict.get("userPrincipalName", "@")).split("@") user_uid = self.uid_cache.get_user_uid(user_dict["id"]) attributes: JSONDict = {} - attributes["cn"] = uid if uid else None + attributes["cn"] = uid or None attributes["description"] = user_dict.get("displayName", None) attributes["displayName"] = user_dict.get("displayName", None) attributes["domain"] = domain attributes["gidNumber"] = user_uid - attributes["givenName"] = given_name if given_name else "" + attributes["givenName"] = given_name or "" attributes["homeDirectory"] = f"/home/{uid}" if uid else None attributes["oauth_id"] = user_dict.get("id", None) attributes["oauth_username"] = user_dict.get("userPrincipalName", None) - attributes["sn"] = surname if surname else "" - attributes["uid"] = uid if uid else None + attributes["sn"] = surname or "" + attributes["uid"] = uid or None attributes["uidNumber"] = user_uid output.append(attributes) except KeyError: diff --git a/pyproject.toml b/pyproject.toml index f31883e..0668fac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,12 +61,12 @@ dependencies = [ typing = "mypy {args:apricot} run.py" style = [ - "ruff check {args:apricot} run.py", "black --check --diff {args:apricot} run.py", + "ruff check --preview {args:apricot} run.py", ] fmt = [ "black {args:apricot} run.py", - "ruff check --fix {args:apricot} run.py", + "ruff check --preview --fix {args:apricot} run.py", "style", ] all = [ @@ -89,7 +89,6 @@ select = [ "BLE", # flake8-blind-except "C", # complexity, mcabe and flake8-comprehensions "COM", # flake8-commas - "CPY", # flake8-copyright "D", # pydocstyle "DTZ", # flake8-datetimez "E", # pycodestyle errors @@ -149,6 +148,8 @@ ignore = [ "C901", # complex-structure "PLR0912", # too-many-branches "PLR0913", # too-many-arguments + "PLR0917", # too-many-positional-arguments + "PLR6301", # method-could-be-function ] [tool.ruff.lint.flake8-annotations] From c18b40715021db3fe332358802ad5606f7b96397 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Fri, 31 May 2024 11:12:35 +0100 Subject: [PATCH 11/15] :recycle: Refactored _extract_attributes as a class-level constructor --- apricot/models/ldap_attribute_adaptor.py | 20 ++++++++++++++++++-- apricot/oauth/oauth_data_adaptor.py | 18 ++---------------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/apricot/models/ldap_attribute_adaptor.py b/apricot/models/ldap_attribute_adaptor.py index 5d2412a..af91d30 100644 --- a/apricot/models/ldap_attribute_adaptor.py +++ b/apricot/models/ldap_attribute_adaptor.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any, Self, Sequence if TYPE_CHECKING: - from apricot.types import LDAPAttributeDict + from apricot.models import NamedLDAPClass + from apricot.types import JSONDict, LDAPAttributeDict class LDAPAttributeAdaptor: @@ -25,6 +26,21 @@ def cn(self: Self) -> str: """Return CN for this set of LDAP attributes.""" return self.attributes["cn"][0] + @classmethod + def from_attributes( + cls: type[Self], + input_dict: JSONDict, + *, + required_classes: Sequence[type[NamedLDAPClass]], + ) -> LDAPAttributeAdaptor: + """Construct an LDAPAttributeAdaptor from attributes and required classes.""" + attributes = {"objectclass": ["top"]} + for ldap_class in required_classes: + model = ldap_class(**input_dict) + attributes.update(model.model_dump()) + attributes["objectclass"] += model.names() + return cls(attributes) + def to_dict(self: Self) -> LDAPAttributeDict: """Convert the attributes to an LDAPAttributeDict.""" return self.attributes diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index 4b4a251..16cc663 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -17,7 +17,6 @@ ) if TYPE_CHECKING: - from collections.abc import Sequence from apricot.types import JSONDict @@ -70,19 +69,6 @@ def _dn_from_group_cn(self: Self, group_cn: str) -> str: def _dn_from_user_cn(self: Self, user_cn: str) -> str: return f"CN={user_cn},OU=users,{self.root_dn}" - def _extract_attributes( - self: Self, - input_dict: JSONDict, - required_classes: Sequence[type[NamedLDAPClass]], - ) -> LDAPAttributeAdaptor: - """Add appropriate LDAP class attributes.""" - attributes = {"objectclass": ["top"]} - for ldap_class in required_classes: - model = ldap_class(**input_dict) - attributes.update(model.model_dump()) - attributes["objectclass"] += model.names() - return LDAPAttributeAdaptor(attributes) - def _retrieve_entries( self: Self, ) -> tuple[ @@ -202,7 +188,7 @@ def _validate_groups( for group_dict, required_classes in annotated_groups: try: output.append( - self._extract_attributes( + LDAPAttributeAdaptor.from_attributes( group_dict, required_classes=required_classes, ), @@ -227,7 +213,7 @@ def _validate_users( for user_dict, required_classes in annotated_users: try: output.append( - self._extract_attributes( + LDAPAttributeAdaptor.from_attributes( user_dict, required_classes=required_classes, ), From 0553654e489afd23fb421607bc440bc6ae4e7e8f Mon Sep 17 00:00:00 2001 From: James Robinson Date: Fri, 31 May 2024 11:25:29 +0100 Subject: [PATCH 12/15] :recycle: Convert extract_token to staticmethod --- apricot/oauth/keycloak_client.py | 3 ++- apricot/oauth/microsoft_entra_client.py | 3 ++- apricot/oauth/oauth_client.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/apricot/oauth/keycloak_client.py b/apricot/oauth/keycloak_client.py index d859231..d55fad4 100644 --- a/apricot/oauth/keycloak_client.py +++ b/apricot/oauth/keycloak_client.py @@ -38,7 +38,8 @@ def __init__( **kwargs, ) - def extract_token(self: Self, json_response: JSONDict) -> str: + @staticmethod + def extract_token(json_response: JSONDict) -> str: return str(json_response["access_token"]) def groups(self: Self) -> list[JSONDict]: diff --git a/apricot/oauth/microsoft_entra_client.py b/apricot/oauth/microsoft_entra_client.py index 2c65375..681abb0 100644 --- a/apricot/oauth/microsoft_entra_client.py +++ b/apricot/oauth/microsoft_entra_client.py @@ -35,7 +35,8 @@ def __init__( **kwargs, ) - def extract_token(self: Self, json_response: JSONDict) -> str: + @staticmethod + def extract_token(json_response: JSONDict) -> str: return str(json_response["access_token"]) def groups(self: Self) -> list[JSONDict]: diff --git a/apricot/oauth/oauth_client.py b/apricot/oauth/oauth_client.py index f480c3d..49797b3 100644 --- a/apricot/oauth/oauth_client.py +++ b/apricot/oauth/oauth_client.py @@ -102,8 +102,9 @@ def bearer_token(self: Self) -> str: else: return self.bearer_token_ + @staticmethod @abstractmethod - def extract_token(self: Self, json_response: JSONDict) -> str: + def extract_token(json_response: JSONDict) -> str: """Extract the bearer token from an OAuth2Session JSON response.""" @abstractmethod From 8a858fd3650a060f7b189b87884b3b5bc5397a39 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Fri, 31 May 2024 14:41:29 +0100 Subject: [PATCH 13/15] :truck: Rename NamedLDAPClass to LDAPObjectClass --- apricot/models/__init__.py | 4 ++-- apricot/models/ldap_attribute_adaptor.py | 4 ++-- apricot/models/ldap_group_of_names.py | 4 ++-- .../{named_ldap_class.py => ldap_object_class.py} | 4 ++-- apricot/models/ldap_person.py | 4 ++-- apricot/models/ldap_posix_account.py | 4 ++-- apricot/models/ldap_posix_group.py | 4 ++-- apricot/models/overlay_memberof.py | 4 ++-- apricot/models/overlay_oauthentry.py | 4 ++-- apricot/oauth/oauth_data_adaptor.py | 10 +++++----- 10 files changed, 23 insertions(+), 23 deletions(-) rename apricot/models/{named_ldap_class.py => ldap_object_class.py} (70%) diff --git a/apricot/models/__init__.py b/apricot/models/__init__.py index 615f99b..fd84e3c 100644 --- a/apricot/models/__init__.py +++ b/apricot/models/__init__.py @@ -1,9 +1,9 @@ from .ldap_attribute_adaptor import LDAPAttributeAdaptor from .ldap_group_of_names import LDAPGroupOfNames from .ldap_inetorgperson import LDAPInetOrgPerson +from .ldap_object_class import LDAPObjectClass from .ldap_posix_account import LDAPPosixAccount from .ldap_posix_group import LDAPPosixGroup -from .named_ldap_class import NamedLDAPClass from .overlay_memberof import OverlayMemberOf from .overlay_oauthentry import OverlayOAuthEntry @@ -11,9 +11,9 @@ "LDAPAttributeAdaptor", "LDAPGroupOfNames", "LDAPInetOrgPerson", + "LDAPObjectClass", "LDAPPosixAccount", "LDAPPosixGroup", - "NamedLDAPClass", "OverlayMemberOf", "OverlayOAuthEntry", ] diff --git a/apricot/models/ldap_attribute_adaptor.py b/apricot/models/ldap_attribute_adaptor.py index af91d30..a559f9d 100644 --- a/apricot/models/ldap_attribute_adaptor.py +++ b/apricot/models/ldap_attribute_adaptor.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Self, Sequence if TYPE_CHECKING: - from apricot.models import NamedLDAPClass + from apricot.models import LDAPObjectClass from apricot.types import JSONDict, LDAPAttributeDict @@ -31,7 +31,7 @@ def from_attributes( cls: type[Self], input_dict: JSONDict, *, - required_classes: Sequence[type[NamedLDAPClass]], + required_classes: Sequence[type[LDAPObjectClass]], ) -> LDAPAttributeAdaptor: """Construct an LDAPAttributeAdaptor from attributes and required classes.""" attributes = {"objectclass": ["top"]} diff --git a/apricot/models/ldap_group_of_names.py b/apricot/models/ldap_group_of_names.py index 85bc8b8..92e56e8 100644 --- a/apricot/models/ldap_group_of_names.py +++ b/apricot/models/ldap_group_of_names.py @@ -2,10 +2,10 @@ from typing import Self -from .named_ldap_class import NamedLDAPClass +from .ldap_object_class import LDAPObjectClass -class LDAPGroupOfNames(NamedLDAPClass): +class LDAPGroupOfNames(LDAPObjectClass): """A group with named members. OID: 2.5.6.9 diff --git a/apricot/models/named_ldap_class.py b/apricot/models/ldap_object_class.py similarity index 70% rename from apricot/models/named_ldap_class.py rename to apricot/models/ldap_object_class.py index 89ff32a..2e08c88 100644 --- a/apricot/models/named_ldap_class.py +++ b/apricot/models/ldap_object_class.py @@ -5,8 +5,8 @@ from pydantic import BaseModel -class NamedLDAPClass(BaseModel): - """An LDAP class that has a name.""" +class LDAPObjectClass(BaseModel): + """An LDAP object-class that may have a name.""" def names(self: Self) -> list[str]: """List of names for this LDAP object class.""" diff --git a/apricot/models/ldap_person.py b/apricot/models/ldap_person.py index b82ebd4..4ae0ebd 100644 --- a/apricot/models/ldap_person.py +++ b/apricot/models/ldap_person.py @@ -2,10 +2,10 @@ from typing import Self -from .named_ldap_class import NamedLDAPClass +from .ldap_object_class import LDAPObjectClass -class LDAPPerson(NamedLDAPClass): +class LDAPPerson(LDAPObjectClass): """A named person. OID: 2.5.6.6 diff --git a/apricot/models/ldap_posix_account.py b/apricot/models/ldap_posix_account.py index 65bb994..0df7411 100644 --- a/apricot/models/ldap_posix_account.py +++ b/apricot/models/ldap_posix_account.py @@ -6,13 +6,13 @@ from pydantic import StringConstraints, validator from typing_extensions import Annotated -from .named_ldap_class import NamedLDAPClass +from .ldap_object_class import LDAPObjectClass ID_MIN = 2000 ID_MAX = 60000 -class LDAPPosixAccount(NamedLDAPClass): +class LDAPPosixAccount(LDAPObjectClass): """Abstraction of an account with POSIX attributes. OID: 1.3.6.1.1.1.2.0 diff --git a/apricot/models/ldap_posix_group.py b/apricot/models/ldap_posix_group.py index 1c83344..e7ae694 100644 --- a/apricot/models/ldap_posix_group.py +++ b/apricot/models/ldap_posix_group.py @@ -4,13 +4,13 @@ from pydantic import validator -from .named_ldap_class import NamedLDAPClass +from .ldap_object_class import LDAPObjectClass ID_MIN = 2000 ID_MAX = 4294967295 -class LDAPPosixGroup(NamedLDAPClass): +class LDAPPosixGroup(LDAPObjectClass): """Abstraction of a group of accounts. OID: 1.3.6.1.1.1.2.2 diff --git a/apricot/models/overlay_memberof.py b/apricot/models/overlay_memberof.py index 4209afe..8731a06 100644 --- a/apricot/models/overlay_memberof.py +++ b/apricot/models/overlay_memberof.py @@ -1,9 +1,9 @@ from __future__ import annotations -from .named_ldap_class import NamedLDAPClass +from .ldap_object_class import LDAPObjectClass -class OverlayMemberOf(NamedLDAPClass): +class OverlayMemberOf(LDAPObjectClass): """Abstraction for tracking the groups that an individual belongs to. OID: n/a diff --git a/apricot/models/overlay_oauthentry.py b/apricot/models/overlay_oauthentry.py index ba8c879..fa24e9d 100644 --- a/apricot/models/overlay_oauthentry.py +++ b/apricot/models/overlay_oauthentry.py @@ -1,9 +1,9 @@ from __future__ import annotations -from .named_ldap_class import NamedLDAPClass +from .ldap_object_class import LDAPObjectClass -class OverlayOAuthEntry(NamedLDAPClass): +class OverlayOAuthEntry(LDAPObjectClass): """Abstraction for tracking an OAuth entry. OID: n/a diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index 16cc663..078dbfe 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -11,7 +11,7 @@ LDAPInetOrgPerson, LDAPPosixAccount, LDAPPosixGroup, - NamedLDAPClass, + LDAPObjectClass, OverlayMemberOf, OverlayOAuthEntry, ) @@ -72,8 +72,8 @@ def _dn_from_user_cn(self: Self, user_cn: str) -> str: def _retrieve_entries( self: Self, ) -> tuple[ - list[tuple[JSONDict, list[type[NamedLDAPClass]]]], - list[tuple[JSONDict, list[type[NamedLDAPClass]]]], + list[tuple[JSONDict, list[type[LDAPObjectClass]]]], + list[tuple[JSONDict, list[type[LDAPObjectClass]]]], ]: """Obtain lists of users and groups, and construct necessary meta-entries.""" # Get the initial set of users and groups @@ -179,7 +179,7 @@ def _retrieve_entries( def _validate_groups( self: Self, - annotated_groups: list[tuple[JSONDict, list[type[NamedLDAPClass]]]], + annotated_groups: list[tuple[JSONDict, list[type[LDAPObjectClass]]]], ) -> list[LDAPAttributeAdaptor]: """Return a list of LDAPAttributeAdaptors representing validated group data.""" if self.debug: @@ -204,7 +204,7 @@ def _validate_groups( def _validate_users( self: Self, - annotated_users: list[tuple[JSONDict, list[type[NamedLDAPClass]]]], + annotated_users: list[tuple[JSONDict, list[type[LDAPObjectClass]]]], ) -> list[LDAPAttributeAdaptor]: """Return a list of LDAPAttributeAdaptors representing validated user data.""" if self.debug: From b9a181e51e4e5c36af7bd1a329d5b8bd707eaacd Mon Sep 17 00:00:00 2001 From: James Robinson Date: Fri, 31 May 2024 14:59:53 +0100 Subject: [PATCH 14/15] :recycle: Move LDAP object-class name construction into LDAPObjectClass base class --- apricot/models/ldap_group_of_names.py | 7 ++----- apricot/models/ldap_inetorgperson.py | 7 ++----- apricot/models/ldap_object_class.py | 18 +++++++++++++++--- apricot/models/ldap_organizational_person.py | 7 ++----- apricot/models/ldap_person.py | 7 ++----- apricot/models/ldap_posix_account.py | 5 ++--- apricot/models/ldap_posix_group.py | 5 ++--- apricot/oauth/oauth_data_adaptor.py | 2 +- 8 files changed, 28 insertions(+), 30 deletions(-) diff --git a/apricot/models/ldap_group_of_names.py b/apricot/models/ldap_group_of_names.py index 92e56e8..b5e35db 100644 --- a/apricot/models/ldap_group_of_names.py +++ b/apricot/models/ldap_group_of_names.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Self - from .ldap_object_class import LDAPObjectClass @@ -14,9 +12,8 @@ class LDAPGroupOfNames(LDAPObjectClass): Schema: rfc4519 """ + _ldap_object_class_name: str = "groupOfNames" + cn: str description: str member: list[str] - - def names(self: Self) -> list[str]: - return ["groupOfNames"] diff --git a/apricot/models/ldap_inetorgperson.py b/apricot/models/ldap_inetorgperson.py index 461b218..e66cc5d 100644 --- a/apricot/models/ldap_inetorgperson.py +++ b/apricot/models/ldap_inetorgperson.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Self - from .ldap_organizational_person import LDAPOrganizationalPerson @@ -14,6 +12,8 @@ class LDAPInetOrgPerson(LDAPOrganizationalPerson): Schema: rfc2798 """ + _ldap_object_class_name: str = "inetOrgPerson" + cn: str displayName: str | None = None # noqa: N815 employeeNumber: str | None = None # noqa: N815 @@ -21,6 +21,3 @@ class LDAPInetOrgPerson(LDAPOrganizationalPerson): sn: str mail: str | None = None telephoneNumber: str | None = None # noqa: N815 - - def names(self: Self) -> list[str]: - return [*super().names(), "inetOrgPerson"] diff --git a/apricot/models/ldap_object_class.py b/apricot/models/ldap_object_class.py index 2e08c88..6b9eeed 100644 --- a/apricot/models/ldap_object_class.py +++ b/apricot/models/ldap_object_class.py @@ -8,6 +8,18 @@ class LDAPObjectClass(BaseModel): """An LDAP object-class that may have a name.""" - def names(self: Self) -> list[str]: - """List of names for this LDAP object class.""" - return [] + @classmethod + def names(cls: type[Self]) -> list[str]: + """List of object-class names for this LDAP object-class. + + We iterate through the parent classes in MRO order, getting an + `_ldap_object_class_name` from each class that has one. We then sort these + before returning a list of names. + """ + return sorted( + [ + cls_._ldap_object_class_name.default + for cls_ in cls.__mro__ + if hasattr(cls_, "_ldap_object_class_name") + ], + ) diff --git a/apricot/models/ldap_organizational_person.py b/apricot/models/ldap_organizational_person.py index 8f1687b..1d05382 100644 --- a/apricot/models/ldap_organizational_person.py +++ b/apricot/models/ldap_organizational_person.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Self - from .ldap_person import LDAPPerson @@ -14,7 +12,6 @@ class LDAPOrganizationalPerson(LDAPPerson): Schema: rfc4519 """ - description: str + _ldap_object_class_name: str = "organizationalPerson" - def names(self: Self) -> list[str]: - return [*super().names(), "organizationalPerson"] + description: str diff --git a/apricot/models/ldap_person.py b/apricot/models/ldap_person.py index 4ae0ebd..4affe81 100644 --- a/apricot/models/ldap_person.py +++ b/apricot/models/ldap_person.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Self - from .ldap_object_class import LDAPObjectClass @@ -14,8 +12,7 @@ class LDAPPerson(LDAPObjectClass): Schema: rfc4519 """ + _ldap_object_class_name: str = "person" + cn: str sn: str - - def names(self: Self) -> list[str]: - return ["person"] diff --git a/apricot/models/ldap_posix_account.py b/apricot/models/ldap_posix_account.py index 0df7411..0062468 100644 --- a/apricot/models/ldap_posix_account.py +++ b/apricot/models/ldap_posix_account.py @@ -21,6 +21,8 @@ class LDAPPosixAccount(LDAPObjectClass): Schema: rfc2307bis """ + _ldap_object_class_name: str = "posixAccount" + cn: str gidNumber: int # noqa: N815 homeDirectory: Annotated[ # noqa: N815 @@ -52,6 +54,3 @@ def validate_uid_number(cls: type[Self], uid_number: int) -> int: msg = f"Must be in range {ID_MIN} to {ID_MAX}." raise ValueError(msg) return uid_number - - def names(self: Self) -> list[str]: - return ["posixAccount"] diff --git a/apricot/models/ldap_posix_group.py b/apricot/models/ldap_posix_group.py index e7ae694..7d3db08 100644 --- a/apricot/models/ldap_posix_group.py +++ b/apricot/models/ldap_posix_group.py @@ -19,6 +19,8 @@ class LDAPPosixGroup(LDAPObjectClass): Schema: rfc2307bis """ + _ldap_object_class_name: str = "posixGroup" + description: str gidNumber: int # noqa: N815 memberUid: list[str] # noqa: N815 @@ -31,6 +33,3 @@ def validate_gid_number(cls: type[Self], gid_number: int) -> int: msg = f"Must be in range {ID_MIN} to {ID_MAX}." raise ValueError(msg) return gid_number - - def names(self: Self) -> list[str]: - return ["posixGroup"] diff --git a/apricot/oauth/oauth_data_adaptor.py b/apricot/oauth/oauth_data_adaptor.py index 078dbfe..31aecb5 100644 --- a/apricot/oauth/oauth_data_adaptor.py +++ b/apricot/oauth/oauth_data_adaptor.py @@ -9,9 +9,9 @@ LDAPAttributeAdaptor, LDAPGroupOfNames, LDAPInetOrgPerson, + LDAPObjectClass, LDAPPosixAccount, LDAPPosixGroup, - LDAPObjectClass, OverlayMemberOf, OverlayOAuthEntry, ) From 2d3458945ec261b63b6deac52af0006f54031e17 Mon Sep 17 00:00:00 2001 From: James Robinson Date: Fri, 31 May 2024 15:00:56 +0100 Subject: [PATCH 15/15] :safety_vest: Include PLR6301 --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0668fac..39e69ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,7 +149,6 @@ ignore = [ "PLR0912", # too-many-branches "PLR0913", # too-many-arguments "PLR0917", # too-many-positional-arguments - "PLR6301", # method-could-be-function ] [tool.ruff.lint.flake8-annotations]