Skip to content
This repository has been archived by the owner on May 13, 2024. It is now read-only.

Type annotations #1071

Merged
merged 4 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions timed/authentication.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

import base64
import functools
import hashlib
from typing import TYPE_CHECKING

import requests
from django.conf import settings
Expand All @@ -10,9 +13,18 @@
from mozilla_django_oidc.auth import LOGGER, OIDCAuthenticationBackend
from rest_framework.exceptions import AuthenticationFailed

if TYPE_CHECKING:
from typing import Callable, Self

from django.db.models import QuerySet

from timed.employment.models import User


class TimedOIDCAuthenticationBackend(OIDCAuthenticationBackend):
def get_introspection(self, access_token, _id_token, _payload):
def get_introspection(
self, access_token: str, _id_token: str, _payload: dict
hairmare marked this conversation as resolved.
Show resolved Hide resolved
) -> dict:
"""Return user details dictionary."""
basic = base64.b64encode(
f"{settings.OIDC_RP_INTROSPECT_CLIENT_ID}:{settings.OIDC_RP_INTROSPECT_CLIENT_SECRET}".encode()
Expand All @@ -31,7 +43,7 @@ def get_introspection(self, access_token, _id_token, _payload):
response.raise_for_status()
return response.json()

def get_userinfo_or_introspection(self, access_token):
def get_userinfo_or_introspection(self, access_token: str) -> dict:
try:
return self.cached_request(self.get_userinfo, access_token, "auth.userinfo")
except requests.HTTPError as exc:
Expand All @@ -57,7 +69,9 @@ def get_userinfo_or_introspection(self, access_token):
return claims
raise AuthenticationFailed from exc

def get_or_create_user(self, access_token, _id_token, _payload):
def get_or_create_user(
self, access_token: str, _id_token: str, _payload: dict
) -> User | None:
"""Verify claims and return user, otherwise raise an Exception."""
claims = self.get_userinfo_or_introspection(access_token)

Expand All @@ -76,17 +90,22 @@ def get_or_create_user(self, access_token, _id_token, _payload):
)
return None

def update_user_from_claims(self, user, claims):
def update_user_from_claims(self, user: User, claims: dict[str, str]) -> None:
user.email = claims.get(settings.OIDC_EMAIL_CLAIM, "")
user.first_name = claims.get(settings.OIDC_FIRSTNAME_CLAIM, "")
user.last_name = claims.get(settings.OIDC_LASTNAME_CLAIM, "")
user.save()

def filter_users_by_claims(self, claims):
def filter_users_by_claims(self, claims: dict[str, str]) -> QuerySet[User]:
username = self.get_username(claims)
return self.UserModel.objects.filter(username__iexact=username)

def cached_request(self, method, token, cache_prefix):
def cached_request(
self,
method: Callable[[Self, str, None, None], dict],
token: str,
cache_prefix: str,
) -> dict:
token_hash = hashlib.sha256(force_bytes(token)).hexdigest()

func = functools.partial(method, token, None, None)
Expand All @@ -97,7 +116,7 @@ def cached_request(self, method, token, cache_prefix):
timeout=settings.OIDC_BEARER_TOKEN_REVALIDATION_TIME,
)

def create_user(self, claims):
def create_user(self, claims: dict[str, str]) -> User:
"""Return object for a newly created user account."""
username = self.get_username(claims)
email = claims.get(settings.OIDC_EMAIL_CLAIM, "")
Expand All @@ -108,7 +127,7 @@ def create_user(self, claims):
username=username, email=email, first_name=first_name, last_name=last_name
)

def get_username(self, claims):
def get_username(self, claims: dict[str, str]) -> str:
try:
return claims[settings.OIDC_USERNAME_CLAIM]
except KeyError as exc:
Expand Down
29 changes: 24 additions & 5 deletions timed/employment/filters.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from django.db.models import Q
from django_filters.constants import EMPTY_VALUES
from django_filters.rest_framework import DateFilter, Filter, FilterSet, NumberFilter

from timed.employment import models
from timed.employment.models import User

if TYPE_CHECKING:
from typing import TypeVar

from django.db.models import QuerySet

T = TypeVar("T", QuerySet)


class YearFilter(Filter):
"""Filter to filter a queryset by year."""

def filter(self, qs, value):
def filter(self, qs: T, value: int) -> T:
if value in EMPTY_VALUES:
return qs

Expand Down Expand Up @@ -54,15 +65,21 @@ class UserFilterSet(FilterSet):
is_accountant = NumberFilter(field_name="is_accountant")
is_external = NumberFilter(method="filter_is_external")

def filter_is_external(self, queryset, _name, value):
def filter_is_external(
self, queryset: QuerySet[models.User], _name: str, value: int
) -> QuerySet[models.User]:
return queryset.filter(employments__is_external=value)

def filter_is_reviewer(self, queryset, _name, value):
def filter_is_reviewer(
self, queryset: QuerySet[models.User], _name: str, value: int
) -> QuerySet[models.User]:
if value:
return queryset.filter(pk__in=User.objects.all_reviewers())
return queryset.exclude(pk__in=User.objects.all_reviewers())

def filter_is_supervisor(self, queryset, _name, value):
def filter_is_supervisor(
self, queryset: QuerySet[models.User], _name: str, value: int
) -> QuerySet[models.User]:
if value:
return queryset.filter(pk__in=User.objects.all_supervisors())
return queryset.exclude(pk__in=User.objects.all_supervisors())
Expand All @@ -81,7 +98,9 @@ class Meta:
class EmploymentFilterSet(FilterSet):
date = DateFilter(method="filter_date")

def filter_date(self, queryset, _name, value):
def filter_date(
self, queryset: QuerySet[models.Employment], _name: str, value: int
) -> QuerySet[models.Employment]:
return queryset.filter(
Q(start_date__lte=value)
& Q(Q(end_date__gte=value) | Q(end_date__isnull=True))
Expand Down
48 changes: 25 additions & 23 deletions timed/employment/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Models for the employment app."""

from __future__ import annotations

from datetime import date, timedelta
from typing import TYPE_CHECKING

from dateutil import rrule
from django.conf import settings
Expand All @@ -12,7 +15,10 @@

from timed.models import WeekdaysField
from timed.projects.models import CustomerAssignee, ProjectAssignee, TaskAssignee
from timed.tracking.models import Absence
from timed.tracking.models import Absence, Report

if TYPE_CHECKING:
from django.db.models import QuerySet


class Location(models.Model):
Expand Down Expand Up @@ -76,7 +82,7 @@ def __str__(self) -> str:
"""Represent the model as a string."""
return self.name

def calculate_credit(self, user, start, end):
def calculate_credit(self, user: User, start: date, end: date) -> int | None:
"""Calculate approved days of type for user in given time frame.

For absence types which fill worktime this will be None.
Expand All @@ -90,7 +96,7 @@ def calculate_credit(self, user, start, end):
data = credits.aggregate(credit=Sum("days"))
return data["credit"] or 0

def calculate_used_days(self, user, start, end):
def calculate_used_days(self, user: User, start: date, end: date) -> int | None:
"""Calculate used days of type for user in given time frame.

For absence types which fill worktime this will be None.
Expand Down Expand Up @@ -157,7 +163,7 @@ def __str__(self) -> str:
class EmploymentManager(models.Manager):
"""Custom manager for employments."""

def get_at(self, user, date):
def get_at(self, user: User, date: date) -> Employment:
"""Get employment of user at given date.

:param User user: The user of the searched employments
Expand All @@ -170,7 +176,7 @@ def get_at(self, user, date):
user=user,
)

def for_user(self, user, start, end):
def for_user(self, user: User, start: date, end: date) -> QuerySet[Employment]:
"""Get employments in given time frame for current user.

This includes overlapping employments.
Expand Down Expand Up @@ -228,7 +234,9 @@ def __str__(self) -> str:
self.end_date.strftime("%d.%m.%Y") if self.end_date else "today",
)

def calculate_worktime(self, start, end):
def calculate_worktime(
self, start: date, end: date
) -> tuple[timedelta, timedelta, timedelta]:
"""Calculate reported, expected and balance for employment.

1. It shortens the time frame so it is within given employment
Expand All @@ -245,13 +253,8 @@ def calculate_worktime(self, start, end):
7. The balance is the reported time plus the absences plus the
overtime credit minus the expected worktime

hairmare marked this conversation as resolved.
Show resolved Hide resolved
:param start: calculate worktime starting on given day.
:param end: calculate worktime till given day
:returns: tuple of 3 values reported, expected and delta in given
time frame
Return a tuple with 3 timedeltas of reported, expected and the delta.
"""
from timed.tracking.models import Absence, Report

# shorten time frame to employment
start = max(start, self.start_date)
end = min(self.end_date or date.today(), end)
Expand Down Expand Up @@ -303,13 +306,13 @@ def calculate_worktime(self, start, end):


class UserManager(UserManager):
def all_supervisors(self):
def all_supervisors(self) -> QuerySet[User]:
objects = self.model.objects.annotate(
supervisees_count=models.Count("supervisees")
)
return objects.filter(supervisees_count__gt=0)

def all_reviewers(self):
def all_reviewers(self) -> QuerySet[User]:
return self.all().filter(
models.Q(
pk__in=TaskAssignee.objects.filter(is_reviewer=True).values("user")
Expand All @@ -322,7 +325,7 @@ def all_reviewers(self):
)
)

def all_supervisees(self):
def all_supervisees(self) -> QuerySet[User]:
objects = self.model.objects.annotate(
supervisors_count=models.Count("supervisors")
)
Expand Down Expand Up @@ -352,28 +355,27 @@ class User(AbstractUser):
objects = UserManager()

@property
def is_reviewer(self):
def is_reviewer(self) -> bool:
return (
TaskAssignee.objects.filter(user=self, is_reviewer=True).exists()
or ProjectAssignee.objects.filter(user=self, is_reviewer=True).exists()
or CustomerAssignee.objects.filter(user=self, is_reviewer=True).exists()
)

@property
def user_id(self):
def user_id(self) -> int:
"""Map to id to be able to use generic permissions."""
return self.id

def calculate_worktime(self, start, end):
def calculate_worktime(
self, start: date, end: date
) -> tuple[timedelta, timedelta, timedelta]:
"""Calculate reported, expected and balance for user.

This calculates summarizes worktime for all employments of users which
are in given time frame.

:param start: calculate worktime starting on given day.
:param end: calculate worktime till given day
:returns: tuple of 3 values reported, expected and delta in given
time frame
Return a tuple with 3 timedeltas of reported, expected and balance.
"""
employments = Employment.objects.for_user(self, start, end).select_related(
"location"
Expand All @@ -389,7 +391,7 @@ def calculate_worktime(self, start, end):

return (reported, expected, balance)

def get_active_employment(self):
def get_active_employment(self) -> Employment | None:
"""Get current employment of the user.

Get current active employment of the user.
Expand Down
Loading