Skip to content

Commit

Permalink
Merge pull request #23 from alan-turing-institute/22-make-redis-optional
Browse files Browse the repository at this point in the history
Make Redis optional
  • Loading branch information
jemrobinson authored Feb 27, 2024
2 parents c3b4066 + 62c0e58 commit 626674e
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 48 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ The name is a slightly tortured acronym for: LD**A**P **pr**oxy for Open**I**D *

## Usage

**N.B.** As Apricot uses a Redis server to store generated `uidNumber` and `gidNumber` values.

Start the `Apricot` server on port 1389 by running:

```bash
Expand All @@ -21,6 +19,11 @@ docker compose up

from the `docker` directory.

### Using Redis [Optional]

You can use a Redis server to store generated `uidNumber` and `gidNumber` values in a more persistent way.
To do this, you will need to provide the `--redis-host` and `--redis-port` arguments to `run.py`.

## Outputs

This will create an LDAP tree that looks like this:
Expand Down
19 changes: 15 additions & 4 deletions apricot/apricot_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from twisted.internet.interfaces import IReactorCore, IStreamServerEndpoint
from twisted.python import log

from apricot.cache import LocalCache, RedisCache, UidCache
from apricot.ldap import OAuthLDAPServerFactory
from apricot.oauth import OAuthBackend, OAuthClientMap

Expand All @@ -18,21 +19,31 @@ def __init__(
client_secret: str,
domain: str,
port: int,
redis_host: str,
redis_port: int,
redis_host: str | None = None,
redis_port: int | None = None,
**kwargs: Any,
) -> None:
# Log to stdout
log.startLogging(sys.stdout)

# Initialise the UID cache
uid_cache: UidCache
if redis_host and redis_port:
log.msg(
f"Using a Redis user-id cache at host '{redis_host}' on port '{redis_port}'."
)
uid_cache = RedisCache(redis_host=redis_host, redis_port=redis_port)
else:
log.msg("Using a local user-id cache.")
uid_cache = LocalCache()

# Initialize the appropriate OAuth client
try:
oauth_client = OAuthClientMap[backend](
client_id=client_id,
client_secret=client_secret,
domain=domain,
redis_host=redis_host,
redis_port=redis_port,
uid_cache=uid_cache,
**kwargs,
)
except Exception as exc:
Expand Down
4 changes: 4 additions & 0 deletions apricot/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from .local_cache import LocalCache
from .redis_cache import RedisCache
from .uid_cache import UidCache

__all__ = [
"LocalCache",
"RedisCache",
"UidCache",
]
18 changes: 18 additions & 0 deletions apricot/cache/local_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .uid_cache import UidCache


class LocalCache(UidCache):
def __init__(self) -> None:
self.cache: dict[str, int] = {}

def get(self, identifier: str) -> int | None:
return self.cache.get(identifier, None)

def keys(self) -> list[str]:
return [str(k) for k in self.cache.keys()]

def set(self, identifier: str, uid_value: int) -> None:
self.cache[identifier] = uid_value

def values(self, keys: list[str]) -> list[int]:
return [v for k, v in self.cache.items() if k in keys]
36 changes: 36 additions & 0 deletions apricot/cache/redis_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import cast

import redis

from .uid_cache import UidCache


class RedisCache(UidCache):
def __init__(self, redis_host: str, redis_port: int) -> None:
self.redis_host = redis_host
self.redis_port = redis_port
self.cache_: "redis.Redis[str]" | None = None

@property
def cache(self) -> "redis.Redis[str]":
"""
Lazy-load the cache on request
"""
if not self.cache_:
self.cache_ = redis.Redis(
host=self.redis_host, port=self.redis_port, decode_responses=True
)
return self.cache_

def get(self, identifier: str) -> int | None:
value = self.cache.get(identifier)
return None if value is None else int(value)

def keys(self) -> list[str]:
return [str(k) for k in self.cache.keys()]

def set(self, identifier: str, uid_value: int) -> None:
self.cache.set(identifier, uid_value)

def values(self, keys: list[str]) -> list[int]:
return [int(cast(str, v)) for v in self.cache.mget(keys)]
52 changes: 28 additions & 24 deletions apricot/cache/uid_cache.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,35 @@
from abc import ABC, abstractmethod
from typing import cast

import redis

class UidCache(ABC):
@abstractmethod
def get(self, identifier: str) -> int | None:
"""
Get the UID for a given identifier, returning None if it does not exist
"""
pass

class UidCache:
def __init__(self, redis_host: str, redis_port: str) -> None:
self.redis_host = redis_host
self.redis_port = redis_port
self.cache_ = None
@abstractmethod
def keys(self) -> list[str]:
"""
Get list of cached keys
"""
pass

@property
def cache(self) -> redis.Redis: # type: ignore[type-arg]
@abstractmethod
def set(self, identifier: str, uid_value: int) -> None:
"""
Lazy-load the cache on request
Set the UID for a given identifier
"""
if not self.cache_:
self.cache_ = redis.Redis( # type: ignore[call-overload]
host=self.redis_host, port=self.redis_port, decode_responses=True
)
return self.cache_ # type: ignore[return-value]
pass

@property
def keys(self) -> list[str]:
@abstractmethod
def values(self, keys: list[str]) -> list[int]:
"""
Get list of keys from the cache
Get list of cached values corresponding to requested keys
"""
return [str(k) for k in self.cache.keys()]
pass

def get_group_uid(self, identifier: str) -> int:
"""
Expand Down Expand Up @@ -54,12 +58,12 @@ def get_uid(
@param min_value: Minimum allowed value for the UID
"""
identifier_ = f"{category}-{identifier}"
uid = self.cache.get(identifier_)
uid = self.get(identifier_)
if not uid:
min_value = min_value if min_value else 0
next_uid = max(self._get_max_uid(category) + 1, min_value)
self.cache.set(identifier_, next_uid)
return cast(int, self.cache.get(identifier_))
self.set(identifier_, next_uid)
return cast(int, self.get(identifier_))

def _get_max_uid(self, category: str | None) -> int:
"""
Expand All @@ -68,8 +72,8 @@ def _get_max_uid(self, category: str | None) -> int:
@param category: Category to check UIDs for
"""
if category:
keys = [k for k in self.keys if k.startswith(category)]
keys = [k for k in self.keys() if k.startswith(category)]
else:
keys = self.keys
values = [int(cast(str, v)) for v in self.cache.mget(keys)] + [-999]
keys = self.keys()
values = [*self.values(keys), -999]
return max(values)
8 changes: 5 additions & 3 deletions apricot/oauth/oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,16 @@ def __init__(
client_secret: str,
domain: str,
redirect_uri: str,
redis_host: str,
redis_port: str,
scopes: list[str],
token_url: str,
uid_cache: UidCache,
) -> None:
# Set attributes
self.bearer_token_: str | None = None
self.client_secret = client_secret
self.domain = domain
self.token_url = token_url
self.uid_cache = UidCache(redis_host=redis_host, redis_port=redis_port)
self.uid_cache = uid_cache
# Allow token scope to not match requested scope. (Other auth libraries allow
# this, but Requests-OAuthlib raises exception on scope mismatch by default.)
os.environ["OAUTHLIB_RELAX_TOKEN_SCOPE"] = "1" # noqa: S105
Expand Down Expand Up @@ -92,6 +91,9 @@ def bearer_token(self) -> str:

@abstractmethod
def extract_token(self, json_response: JSONDict) -> str:
"""
Extract the bearer token from an OAuth2Session JSON response
"""
pass

@abstractmethod
Expand Down
20 changes: 9 additions & 11 deletions docker/entrypoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,33 @@ if [ -z "${DOMAIN}" ]; then
exit 1
fi

if [ -z "${REDIS_HOST}" ]; then
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_HOST environment variable is not set"
exit 1
fi

# Arguments with defaults
if [ -z "${PORT}" ]; then
PORT="1389"
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] PORT environment variable is not set: using default of '${PORT}'"
fi

if [ -z "${REDIS_PORT}" ]; then
REDIS_PORT="6379"
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_PORT environment variable is not set: using default of '${REDIS_PORT}'"
fi


# Optional arguments
EXTRA_OPTS=""
if [ -n "${ENTRA_TENANT_ID}" ]; then
EXTRA_OPTS="${EXTRA_OPTS} --entra-tenant-id $ENTRA_TENANT_ID"
fi

if [ -n "${REDIS_HOST}" ]; then
if [ -z "${REDIS_PORT}" ]; then
REDIS_PORT="6379"
echo "$(date +'%Y-%m-%d %H:%M:%S+0000') [-] REDIS_PORT environment variable is not set: using default of '${REDIS_PORT}'"
fi
EXTRA_OPTS="${EXTRA_OPTS} --redis-host $REDIS_HOST --redis-port $REDIS_PORT"
fi

# Run the server
hatch run python run.py \
--backend "${BACKEND}" \
--client-id "${CLIENT_ID}" \
--client-secret "${CLIENT_SECRET}" \
--domain "${DOMAIN}" \
--port "${PORT}" \
--redis-host "${REDIS_HOST}" \
--redis-port "${REDIS_PORT}" \
$EXTRA_OPTS
10 changes: 6 additions & 4 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
parser.add_argument("-s", "--client-secret", type=str, help="OAuth client secret.")
parser.add_argument("-d", "--domain", type=str, help="Which domain users belong to.")
parser.add_argument("-p", "--port", type=int, default=1389, help="Port to run on.")
parser.add_argument("--redis-host", type=str, help="Host for Redis server.")
parser.add_argument("--redis-port", type=int, help="Port for Redis server.")
# Options for Microsoft Entra backend
group = parser.add_argument_group("Microsoft Entra")
group.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False)
entra_group = parser.add_argument_group("Microsoft Entra")
entra_group.add_argument("-t", "--entra-tenant-id", type=str, help="Microsoft Entra tenant ID.", required=False)
# Options for Redis cache
redis_group = parser.add_argument_group("Redis")
redis_group.add_argument("--redis-host", type=str, help="Host for Redis server.")
redis_group.add_argument("--redis-port", type=int, help="Port for Redis server.")
# Parse arguments
args = parser.parse_args()

Expand Down

0 comments on commit 626674e

Please sign in to comment.