diff --git a/mcproto/packets/packet_map.py b/mcproto/packets/packet_map.py index 9ba600df..9dde1310 100644 --- a/mcproto/packets/packet_map.py +++ b/mcproto/packets/packet_map.py @@ -8,6 +8,7 @@ from typing import Literal, NamedTuple, NoReturn, overload from mcproto.packets.packet import ClientBoundPacket, GameState, Packet, PacketDirection, ServerBoundPacket +from mcproto.utils.decorators import copied_return __all__ = ["generate_packet_map"] @@ -94,8 +95,9 @@ def generate_packet_map( ... +@copied_return @lru_cache() -def generate_packet_map(direction: PacketDirection, state: GameState) -> Mapping[int, type[Packet]]: +def generate_packet_map(direction: PacketDirection, state: GameState) -> dict[int, type[Packet]]: """Dynamically generated a packet map for given ``direction`` and ``state``. This generation is done by dynamically importing all of the modules containing these packets, @@ -105,10 +107,6 @@ def generate_packet_map(direction: PacketDirection, state: GameState) -> Mapping As this fucntion is likely to be called quite often, and it uses dynamic importing to obtain the packet classes, this function is cached, which means the logic only actually runs once, after which, for the same arguments, the same dict will be returned. - - ..warning: As this function is cached, make sure to avoid modifying the returned dictionary, - as that will directly modify the stored version in the cache, leading to future calls - with the same arguments returning a wrong (modified) version of this dictionary. """ module = importlib.import_module(MODULE_PATHS[state]) diff --git a/mcproto/utils/decorators.py b/mcproto/utils/decorators.py new file mode 100644 index 00000000..e7c04909 --- /dev/null +++ b/mcproto/utils/decorators.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from collections.abc import Callable +from functools import wraps +from typing import Protocol, TypeVar + +from typing_extensions import ParamSpec + +__all__ = ["copied_return"] + +T = TypeVar("T") + + +class SupportsCopy(Protocol): + def copy(self: T) -> T: + ... + + +P = ParamSpec("P") +R_Copyable = TypeVar("R_Copyable", bound=SupportsCopy) + + +def copied_return(func: Callable[P, R_Copyable]) -> Callable[P, R_Copyable]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R_Copyable: + ret = func(*args, **kwargs) + return ret.copy() + + return wrapper