Skip to content

Commit

Permalink
refactor: improve typing
Browse files Browse the repository at this point in the history
Prefer usage of `object` over Any
Prefer usage of `Mapping` / `MutableMapping` for public data
  • Loading branch information
NiceAesth committed Sep 24, 2023
1 parent b6de876 commit 89fdadf
Show file tree
Hide file tree
Showing 19 changed files with 186 additions and 154 deletions.
30 changes: 17 additions & 13 deletions aiosu/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from typing import Callable
from typing import Optional
from typing import TypeVar
from collections.abc import Mapping
from collections.abc import MutableMapping

T = TypeVar("T")

Expand All @@ -19,13 +21,13 @@
)


def from_list(f: Callable[[Any], T], x: Any) -> list[T]:
def from_list(f: Callable[[Any], T], x: list[object]) -> list[T]:
r"""Applies a function to all elements in a list.
:param f: Function to apply on list elements
:type f: Callable[[Any], T]
:param x: List of objects
:type x: Any
:type x: list[object]
:raises TypeError: If x is not a list
:return: New list
:rtype: list[T]
Expand All @@ -36,18 +38,18 @@ def from_list(f: Callable[[Any], T], x: Any) -> list[T]:


def add_param(
params: dict[str, Any],
kwargs: dict[str, Any],
params: MutableMapping[str, Any],
kwargs: Mapping[str, Any],
key: str,
param_name: Optional[str] = None,
converter: Optional[Callable[[Any], T]] = None,
) -> bool:
r"""Adds a parameter to a dictionary if it exists in kwargs.
:param params: Dictionary to add parameter to
:type params: dict[str, Any]
:type params: Mapping[str, Any]
:param kwargs: Dictionary to get parameter from
:type kwargs: dict[str, Any]
:type kwargs: Mapping[str, Any]
:param key: Key to get parameter from
:type key: str
:param param_name: Name of parameter to add to dictionary, defaults to None
Expand All @@ -57,10 +59,12 @@ def add_param(
:return: True if parameter was added, False otherwise
:rtype: bool
"""
if key in kwargs:
value = kwargs[key]
if converter:
value = converter(value)
params[param_name or key] = value
return True
return False
if key not in kwargs:
return False

value = kwargs[key]
if converter:
value = converter(value)

params[param_name or key] = value
return True
21 changes: 21 additions & 0 deletions aiosu/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
"""
from __future__ import annotations

from typing import SupportsFloat
from typing import SupportsInt

import pydantic
from pydantic import ConfigDict

Expand Down Expand Up @@ -35,3 +38,21 @@ class FrozenModel(BaseModel):
populate_by_name=True,
frozen=True,
)


def cast_int(v: object) -> int:
if v is None:
return 0
if isinstance(v, (SupportsInt, str)):
return int(v)

raise ValueError(f"{v} is not a valid value.")


def cast_float(v: object) -> float:
if v is None:
return 0.0
if isinstance(v, (SupportsFloat, str)):
return float(v)

raise ValueError(f"{v} is not a valid value.")
29 changes: 16 additions & 13 deletions aiosu/models/beatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
"""
from __future__ import annotations

from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from enum import unique
from functools import cached_property
from typing import Any
from typing import Literal
from typing import Optional

Expand All @@ -16,6 +16,7 @@
from pydantic import model_validator

from .base import BaseModel
from .base import cast_int
from .common import CurrentUserAttributes
from .common import CursorModel
from .gamemode import Gamemode
Expand Down Expand Up @@ -133,7 +134,7 @@ def __str__(self) -> str:
return self.name_api

@classmethod
def _missing_(cls, query: object) -> Any:
def _missing_(cls, query: object) -> BeatmapRankStatus:
if isinstance(query, int):
for status in list(BeatmapRankStatus):
if status.id == query:
Expand Down Expand Up @@ -165,7 +166,7 @@ class BeatmapAvailability(BaseModel):
download_disabled: Optional[bool] = None

@classmethod
def _from_api_v1(cls, data: Any) -> BeatmapAvailability:
def _from_api_v1(cls, data: Mapping[str, object]) -> BeatmapAvailability:
return cls.model_validate({"download_disabled": data["download_unavailable"]})


Expand Down Expand Up @@ -208,8 +209,8 @@ def from_beatmapset_id(cls, beatmapset_id: int) -> BeatmapCovers:
)

@classmethod
def _from_api_v1(cls, data: Any) -> BeatmapCovers:
return cls.from_beatmapset_id(data["beatmapset_id"])
def _from_api_v1(cls, data: Mapping[str, object]) -> BeatmapCovers:
return cls.from_beatmapset_id(cast_int(data["beatmapset_id"]))


class BeatmapHype(BaseModel):
Expand Down Expand Up @@ -294,7 +295,7 @@ def count_objects(self) -> Optional[int]:

@model_validator(mode="before")
@classmethod
def _set_url(cls, values: dict[str, Any]) -> dict[str, Any]:
def _set_url(cls, values: dict[str, object]) -> dict[str, object]:
if values.get("url") is None:
id = values["id"]
beatmapset_id = values["beatmapset_id"]
Expand All @@ -305,14 +306,14 @@ def _set_url(cls, values: dict[str, Any]) -> dict[str, Any]:
return values

@classmethod
def _from_api_v1(cls, data: Any) -> Beatmap:
def _from_api_v1(cls, data: Mapping[str, object]) -> Beatmap:
return cls.model_validate(
{
"beatmapset_id": data["beatmapset_id"],
"difficulty_rating": data["difficultyrating"],
"id": data["beatmap_id"],
"mode": int(data["mode"]),
"status": int(data["approved"]),
"mode": cast_int(data["mode"]),
"status": cast_int(data["approved"]),
"total_length": data["total_length"],
"hit_length": data["total_length"],
"user_id": data["creator_id"],
Expand Down Expand Up @@ -388,7 +389,7 @@ def discussion_url(self) -> str:
return f"https://osu.ppy.sh/beatmapsets/{self.id}/discussion"

@classmethod
def _from_api_v1(cls, data: Any) -> Beatmapset:
def _from_api_v1(cls, data: Mapping[str, object]) -> Beatmapset:
return cls.model_validate(
{
"id": data["beatmapset_id"],
Expand All @@ -400,7 +401,7 @@ def _from_api_v1(cls, data: Any) -> Beatmapset:
"play_count": data["playcount"],
"preview_url": f"//b.ppy.sh/preview/{data['beatmapset_id']}.mp3",
"source": data["source"],
"status": int(data["approved"]),
"status": cast_int(data["approved"]),
"title": data["title"],
"title_unicode": data["title"],
"user_id": data["creator_id"],
Expand Down Expand Up @@ -505,8 +506,10 @@ class BeatmapsetDiscussionResponse(CursorModel):

@model_validator(mode="before")
@classmethod
def _set_max_blocks(cls, values: dict[str, Any]) -> dict[str, Any]:
values["max_blocks"] = values["reviews_config"]["max_blocks"]
def _set_max_blocks(cls, values: dict[str, object]) -> dict[str, object]:
if isinstance(values["reviews_config"], Mapping):
values["max_blocks"] = values["reviews_config"]["max_blocks"]

return values


Expand Down
10 changes: 6 additions & 4 deletions aiosu/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from datetime import datetime
from functools import cached_property
from functools import partial
from typing import Any
from typing import Literal
from typing import Optional

Expand Down Expand Up @@ -53,10 +52,13 @@ class TimestampedCount(BaseModel):

@field_validator("start_date", mode="before")
@classmethod
def _date_validate(cls, v: Any) -> Any:
def _date_validate(cls, v: object) -> datetime:
if isinstance(v, str):
return datetime.strptime(v, "%Y-%m-%d")
return v
if isinstance(v, datetime):
return v

raise ValueError(f"{v} is not a valid value.")


class Achievement(BaseModel):
Expand Down Expand Up @@ -131,7 +133,7 @@ class CursorModel(BaseModel):
"""

cursor_string: Optional[str] = None
next: Optional[partial[Coroutine[Any, Any, CursorModel]]] = Field(
next: Optional[partial[Coroutine[object, object, CursorModel]]] = Field(
default=None,
exclude=True,
)
Expand Down
6 changes: 1 addition & 5 deletions aiosu/models/gamemode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@

from enum import Enum
from enum import unique
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any

__all__ = ("Gamemode",)

Expand Down Expand Up @@ -88,5 +84,5 @@ def from_type(cls, __o: object) -> Gamemode:
raise ValueError(f"Gamemode {__o} does not exist.")

@classmethod
def _missing_(cls, query: object) -> Any:
def _missing_(cls, query: object) -> Gamemode:
return cls.from_type(query)
5 changes: 2 additions & 3 deletions aiosu/models/lazer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from datetime import datetime
from functools import cached_property
from typing import Any
from typing import Optional

from pydantic import computed_field
Expand Down Expand Up @@ -63,7 +62,7 @@ class LazerMod(BaseModel):
"""Temporary model for lazer mods."""

acronym: str
settings: dict[str, Any] = Field(default_factory=dict)
settings: dict[str, object] = Field(default_factory=dict)

def __str__(self) -> str:
return self.acronym
Expand Down Expand Up @@ -216,7 +215,7 @@ def mods_str(self) -> str:

@model_validator(mode="before")
@classmethod
def _fail_rank(cls, values: dict[str, Any]) -> dict[str, Any]:
def _fail_rank(cls, values: dict[str, object]) -> dict[str, object]:
if not values["passed"]:
values["rank"] = "F"
return values
43 changes: 21 additions & 22 deletions aiosu/models/legacy/match.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
from __future__ import annotations

from collections.abc import Mapping
from datetime import datetime
from enum import IntEnum
from enum import unique
from typing import Optional
from typing import TYPE_CHECKING

from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator

from ..base import BaseModel
from ..base import cast_int
from ..gamemode import Gamemode
from ..mods import Mods
from ..score import ScoreStatistics

if TYPE_CHECKING:
from typing import Any

__all__ = (
"MatchTeam",
Expand Down Expand Up @@ -74,21 +73,19 @@ def get_full_mods(self, game: MatchGame) -> Mods:

@model_validator(mode="before")
@classmethod
def _set_statistics(cls, values: dict[str, Any]) -> dict[str, Any]:
def _set_statistics(cls, values: dict[str, object]) -> dict[str, object]:
values["statistics"] = ScoreStatistics._from_api_v1(values)
return values

@field_validator("enabled_mods", mode="before")
@classmethod
def _set_enabled_mods(cls, v: Any) -> int:
if v is not None:
return int(v)
return 0
def _set_enabled_mods(cls, v: object) -> int:
return cast_int(v)

@field_validator("team", mode="before")
@classmethod
def _set_team(cls, v: Any) -> int:
return int(v)
def _set_team(cls, v: object) -> int:
return cast_int(v)


class MatchGame(BaseModel):
Expand All @@ -108,25 +105,23 @@ class MatchGame(BaseModel):

@field_validator("mode", mode="before")
@classmethod
def _set_mode(cls, v: Any) -> int:
return int(v)
def _set_mode(cls, v: object) -> int:
return cast_int(v)

@field_validator("mods", mode="before")
@classmethod
def _set_mods(cls, v: Any) -> int:
if v is not None:
return int(v)
return 0
def _set_mods(cls, v: object) -> int:
return cast_int(v)

@field_validator("scoring_type", mode="before")
@classmethod
def _set_scoring_type(cls, v: Any) -> int:
return int(v)
def _set_scoring_type(cls, v: object) -> int:
return cast_int(v)

@field_validator("team_type", mode="before")
@classmethod
def _set_team_type(cls, v: Any) -> int:
return int(v)
def _set_team_type(cls, v: object) -> int:
return cast_int(v)


class Match(BaseModel):
Expand All @@ -141,5 +136,9 @@ class Match(BaseModel):

@model_validator(mode="before")
@classmethod
def _format_values(cls, values: dict[str, Any]) -> dict[str, Any]:
return {**values["match"], "games": values["games"]}
def _format_values(cls, values: dict[str, object]) -> dict[str, object]:
match = values["match"]
if not isinstance(match, Mapping):
raise ValueError(f"Invalid match type: {type(match)}")

return {**match, "games": values["games"]}
Loading

0 comments on commit 89fdadf

Please sign in to comment.