-
Notifications
You must be signed in to change notification settings - Fork 3
/
helpers.py
377 lines (300 loc) · 16.6 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
from __future__ import annotations
import asyncio
import inspect
import re
import unittest.mock
from collections.abc import Callable, Coroutine
from typing import Any, Generic, NamedTuple, TypeVar
import pytest
from typing_extensions import ParamSpec, TypeGuard, override
from mcproto.buffer import Buffer
from mcproto.utils.abc import Serializable
# Return type of the coroutine function wrapped by `synchronize`.
T = TypeVar("T")
# Parameters of the coroutine function wrapped by `synchronize`; the synchronous wrapper it
# returns is typed with the same parameters.
P = ParamSpec("P")
# Mock type produced for child attributes by `UnpropagatingMockMixin._get_child_mock`.
T_Mock = TypeVar("T_Mock", bound=unittest.mock.Mock)

# Names exported by ``from <this module> import *``.
# NOTE(review): `isexception` is public (no leading underscore) but not listed here — confirm intentional.
__all__ = [
    "synchronize",
    "SynchronizedMixin",
    "UnpropagatingMockMixin",
    "CustomMockMixin",
    "gen_serializable_test",
    "TestExc",
]
def synchronize(f: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
    """Turn the coroutine function ``f`` into a plain, blocking function.

    Synchronous tests sometimes need to exercise asynchronous code, but they cannot ``await`` it.
    The returned callable accepts the same arguments as ``f``, creates the coroutine, drives it to
    completion with :func:`asyncio.run` and hands back whatever the coroutine returned.

    :param f: The asynchronous function to wrap.
    :return: A synchronous callable taking the same arguments as ``f``.
    """

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        # asyncio.run sets up a fresh event loop for every call, so the wrapper cannot be used
        # from a thread where an event loop is already running.
        coroutine = f(*args, **kwargs)
        return asyncio.run(coroutine)

    return wrapper
class SynchronizedMixin:
    """Class acting as another wrapped object, with all async methods synchronized.

    This class needs :attr:`._WRAPPED_ATTRIBUTE` class variable to be set as the name of the internally
    held attribute, holding the object we'll be wrapping around.

    Child classes of this mixin will have their lookup logic changed, to instead perform a lookup
    on the wrapped attribute. Only if that lookup fails, we fallback to this class, meaning if both
    the wrapped attribute and this class have some attribute defined, the attribute from the wrapped
    object is returned. The only exceptions to this are lookup of the ``_WRAPPED_ATTRIBUTE`` variable,
    and of the attribute name stored under the ``_WRAPPED_ATTRIBUTE`` (the wrapped object). Until the
    wrapped object has been assigned (e.g. early in ``__init__``), every lookup uses the regular
    mechanism of this class.

    If the attribute held by the wrapped object is an asynchronous function, instead of returning it
    directly, the :func:`.synchronize` function will be called, returning a wrapped synchronous
    alternative for the requested async function.

    This is useful when we need to quickly create a synchronous alternative to a class holding async methods.
    However it isn't useful in production, since it will cause typing issues (attributes will be accessible, but
    type checkers won't know that they exist here, because of the dynamic nature of this implementation).
    """

    # Name of the instance attribute that holds the wrapped object; subclasses must set this.
    _WRAPPED_ATTRIBUTE: str

    @override
    def __getattribute__(self, __name: str) -> Any:
        """Return attributes of the wrapped object, if the attribute is a coroutine function, synchronize it.

        The only exception to this behavior is getting the :attr:`._WRAPPED_ATTRIBUTE` variable itself, or the
        attribute named as the content of the ``_WRAPPED_ATTRIBUTE`` variable. All other attribute access will
        be delegated to the wrapped attribute. If the wrapped object isn't set yet, or doesn't have given
        attribute, the lookup will fallback to regular lookup for variables belonging to this class.
        """
        # These two must bypass delegation: resolving them goes through this very method, so checking
        # "_WRAPPED_ATTRIBUTE" first is what stops the `self._WRAPPED_ATTRIBUTE` access from recursing.
        if __name == "_WRAPPED_ATTRIBUTE" or __name == self._WRAPPED_ATTRIBUTE:  # noqa: PLR1714 # Order is important
            return super().__getattribute__(__name)

        try:
            wrapped = getattr(self, self._WRAPPED_ATTRIBUTE)
        except AttributeError:
            # The wrapped object isn't assigned yet, so there is nothing to delegate to.
            return super().__getattribute__(__name)

        # Single lookup (EAFP) so that properties on the wrapped object are only evaluated once.
        try:
            obj = getattr(wrapped, __name)
        except AttributeError:
            return super().__getattribute__(__name)

        if inspect.iscoroutinefunction(obj):
            return synchronize(obj)
        return obj

    @override
    def __setattr__(self, __name: str, __value: object) -> None:
        """Allow for changing attributes of the wrapped object.

        * If wrapped object isn't yet set, fall back to :meth:`~object.__setattr__` of this class.
        * The ``_WRAPPED_ATTRIBUTE`` variable and the wrapped attribute itself are always set on this class,
          matching the exceptions made by :meth:`__getattribute__`.
        * If wrapped object doesn't already contain the attribute we want to set, also fallback to this class.
        * Otherwise, run ``__setattr__`` on it to update it.
        """
        try:
            wrapped_name = self._WRAPPED_ATTRIBUTE
            wrapped = getattr(self, wrapped_name)
        except AttributeError:
            return super().__setattr__(__name, __value)

        # Re-assigning the wrapped object must replace it here, never be forwarded into the old wrapped object.
        if __name not in ("_WRAPPED_ATTRIBUTE", wrapped_name) and hasattr(wrapped, __name):
            return setattr(wrapped, __name, __value)
        return super().__setattr__(__name, __value)
class UnpropagatingMockMixin(Generic[T_Mock]):
    """Provides common functionality for our :class:`~unittest.mock.Mock` classes.

    By default, mock objects propagate themselves by returning a new instance of the same mock
    class, with same initialization attributes. This is done whenever we're accessing new
    attributes of that mock class.

    This propagation makes sense for simple mocks without any additional restrictions, however when
    dealing with limited mocks to some ``spec_set``, it doesn't usually make sense to propagate
    those same ``spec_set`` restrictions, since we generally don't have attributes/methods of a
    class be of/return the same class.

    This mixin class stops this propagation, and instead returns instances of specified mock class,
    defined in :attr:`.child_mock_type` class variable, which is by default set to
    :class:`~unittest.mock.MagicMock`, as it can safely represent most objects.

    .. note::
        This propagation handling will only be done for the mock classes that inherited from this
        mixin class. That means if the :attr:`.child_mock_type` is one of the regular mock classes,
        and the mock is propagated, a regular mock class is returned as that new attribute. This
        regular class then won't have the same overrides, and will therefore propagate itself, like
        any other mock class would.

        If you wish to counteract this, you can set the :attr:`.child_mock_type` to a mock class
        that also inherits from this mixin class, perhaps to your class itself, overriding any
        propagation recursively.
    """

    # The mock *class* (not an instance) that child mocks are created from.
    child_mock_type: type[T_Mock] = unittest.mock.MagicMock

    # Since this is a mixin class, we can access some attributes defined in mock classes safely.
    # Define the types of these variables here, for proper static type analysis.
    _mock_sealed: bool
    _extract_mock_name: Callable[[], str]

    def _get_child_mock(self, **kwargs: Any) -> T_Mock:
        """Make :attr:`.child_mock_type` instances instead of instances of the same class.

        By default, this method creates a new mock instance of the same original class, and passes
        over the same initialization arguments. This overrides that behavior to instead create an
        instance of :attr:`.child_mock_type` class.

        :raises AttributeError: If this mock is sealed.
        """
        # Mocks can be sealed, in which case we wouldn't want to allow propagation of any kind
        # and rather raise an AttributeError, informing that given attr isn't accessible
        if self._mock_sealed:
            mock_name = self._extract_mock_name()
            obj_name = f"{mock_name}.{kwargs['name']}" if "name" in kwargs else f"{mock_name}()"
            raise AttributeError(f"Can't access {obj_name}, mock is sealed.")

        # Propagate any other children as `child_mock_type` instances
        # rather than `self.__class__` instances
        return self.child_mock_type(**kwargs)
class CustomMockMixin(UnpropagatingMockMixin):
    """Provides common functionality for our custom mock types.

    * Stops propagation of same ``spec_set`` restricted mock in child mocks
      (see :class:`.UnpropagatingMockMixin` for more info)
    * Allows using the ``spec_set`` attribute as class attribute
    """

    # Default ``spec_set`` for instances; subclasses override it, and an explicit ``spec_set``
    # keyword argument passed to the constructor takes precedence over it.
    spec_set = None

    def __init__(self, **kwargs: Any):
        # Prefer an explicitly passed ``spec_set``, otherwise use the class-level default. It is kept in a
        # local instead of being assigned to ``self``: mock classes' ``__setattr__`` relies on state that only
        # exists once their own ``__init__`` has run, so assigning here would raise AttributeError.
        spec_set = kwargs.pop("spec_set", self.spec_set)
        super().__init__(spec_set=spec_set, **kwargs)  # type: ignore # Mixin class, this __init__ is valid
def isexception(obj: object) -> TypeGuard[type[Exception] | TestExc]:
    """Tell whether ``obj`` describes an expected exception.

    :param obj: The object to inspect.
    :return: ``True`` when ``obj`` is an :class:`Exception` subclass (the class itself, not an instance)
        or a :class:`TestExc` wrapper, ``False`` otherwise.
    """
    is_exception_class = isinstance(obj, type) and issubclass(obj, Exception)
    if is_exception_class:
        return True
    return isinstance(obj, TestExc)
class TestExc(NamedTuple):
    """Named tuple to check if an exception is raised with a specific message.

    :param exception: The exception type, or a tuple of accepted exception types.
    :param match: If specified, a string containing a regular expression, or a regular expression object, that is
        tested against the string representation of the exception using :func:`re.search`.
    :param kwargs: The keyword arguments passed to the exception.
        If :attr:`kwargs` is not None, the exception instance will need to have the same attributes with the same
        values.
    """

    # The class name starts with "Test", so pytest would otherwise try to collect it from every test module that
    # imports it, and warn that it cannot because NamedTuple defines ``__new__``. Left unannotated on purpose,
    # so NamedTuple doesn't turn it into a field.
    __test__ = False

    exception: type[Exception] | tuple[type[Exception], ...]
    match: str | re.Pattern[str] | None = None
    kwargs: dict[str, Any] | None = None

    @classmethod
    def from_exception(cls, exception: type[Exception] | tuple[type[Exception], ...] | TestExc) -> TestExc:
        """Create a :class:`TestExc` from an exception, does nothing if the object is already a :class:`TestExc`."""
        if isinstance(exception, TestExc):
            return exception
        return cls(exception)
def _check_exception_attributes(exception: BaseException, expected: dict[str, Any] | None) -> None:
    """Assert that ``exception`` has every attribute listed in ``expected``, with equal values.

    Does nothing when ``expected`` is ``None``; a missing attribute raises :exc:`AttributeError`.
    """
    if expected is None:
        return
    for key, value in expected.items():
        actual = getattr(exception, key)
        assert value == actual, f"{key}: {value!r} != {actual!r}"


def gen_serializable_test(
    context: dict[str, Any],
    cls: type[Serializable],
    fields: list[tuple[str, type | str]],
    serialize_deserialize: list[tuple[tuple[Any, ...], bytes]] | None = None,
    validation_fail: list[tuple[tuple[Any, ...], type[Exception] | TestExc]] | None = None,
    deserialization_fail: list[tuple[bytes, type[Exception] | TestExc]] | None = None,
) -> None:
    """Generate tests for a serializable class.

    This function generates tests for the serialization, deserialization, validation, and deserialization error
    handling.

    :param context: The context to add the test functions to. This is usually ``globals()``.
    :param cls: The serializable class to test.
    :param fields: A list of tuples containing the field names and types of the serializable class. Only the names
        are used, to map the positional test data onto keyword arguments of ``cls``.
    :param serialize_deserialize: A list of tuples containing:
        - The tuple representing the arguments to pass to the :class:`mcproto.utils.abc.Serializable` class
        - The expected bytes
    :param validation_fail: A list of tuples containing the arguments to pass to the
        :class:`mcproto.utils.abc.Serializable` class and the expected exception, either as is or wrapped in a
        :class:`TestExc` object.
    :param deserialization_fail: A list of tuples containing the bytes to pass to the :meth:`deserialize` method of the
        class and the expected exception, either as is or wrapped in a :class:`TestExc` object.

    Example usage:

    .. literalinclude:: /../tests/mcproto/utils/test_serializable.py
        :start-after: # region ToyClass
        :linenos:
        :language: python

    This will add 1 class test with 4 test functions containing the tests for serialization, deserialization,
    validation, and deserialization error handling.

    .. note::
        The test cases will use :meth:`__eq__` to compare the objects, so make sure to implement it in the class if
        you are not using the autogenerated method from :func:`attrs.define`.
    """
    field_names = [f[0] for f in fields]

    # This holds the parameters for the serialization and deserialization tests
    parameters: list[tuple[dict[str, Any], bytes]] = [
        (dict(zip(field_names, data)), exp_bytes)
        for data, exp_bytes in ([] if serialize_deserialize is None else serialize_deserialize)
    ]
    # This holds the parameters for the validation tests (exceptions always wrapped in TestExc)
    validation_fail_kw: list[tuple[dict[str, Any], TestExc]] = [
        (dict(zip(field_names, data)), TestExc.from_exception(exc))
        for data, exc in ([] if validation_fail is None else validation_fail)
    ]
    # Same for the deserialization error tests: make sure that the exceptions are wrapped in TestExc
    deserialization_fail_exc: list[tuple[bytes, TestExc]] = [
        (data, TestExc.from_exception(exc))
        for data, exc in ([] if deserialization_fail is None else deserialization_fail)
    ]

    def generate_name(param: dict[str, Any] | bytes, i: int) -> str:
        """Generate a readable pytest id for the ``i``-th test case, truncating long parameters."""
        length = 30
        result = f"{i:02d}] : "  # the first [ is added by pytest
        if isinstance(param, bytes):
            result += (repr(param[:length]) + "...") if len(param) > (length + 3) else repr(param)
        elif isinstance(param, dict):
            begin = ", ".join(f"{k}={v!r}" for k, v in param.items())
            result += (begin[:length] + "...") if len(begin) > (length + 3) else begin
        else:
            raise TypeError(f"Wrong type for param : {param!r}")
        result = result.replace("\n", "\\n").replace("\r", "\\r")
        result += f" [{cls.__name__}"  # the other [ is added by pytest
        return result

    # Computed once: the serialization and deserialization tests share the same cases and ids
    serialization_ids = tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(parameters))
    validation_ids = tuple(generate_name(kwargs, i) for i, (kwargs, _) in enumerate(validation_fail_kw))
    deserialization_fail_ids = tuple(
        generate_name(content, i) for i, (content, _) in enumerate(deserialization_fail_exc)
    )

    class TestClass:
        """Test class for the generated tests."""

        @pytest.mark.parametrize(("kwargs", "expected_bytes"), parameters, ids=serialization_ids)
        def test_serialization(self, kwargs: dict[str, Any], expected_bytes: bytes):
            """Test serialization of the object."""
            obj = cls(**kwargs)
            serialized_bytes = bytes(obj.serialize())
            assert serialized_bytes == expected_bytes, f"{serialized_bytes} != {expected_bytes}"

        @pytest.mark.parametrize(("kwargs", "expected_bytes"), parameters, ids=serialization_ids)
        def test_deserialization(self, kwargs: dict[str, Any], expected_bytes: bytes):
            """Test deserialization of the object."""
            buf = Buffer(expected_bytes)
            obj = cls.deserialize(buf)
            expected = cls(**kwargs)
            equality = expected == obj

            error_message: list[str] = []
            # Try to find the mismatched fields. Compare against the expected instance (not the raw kwargs) so any
            # normalisation done by the constructor applies to both sides, and use a sentinel so that a field
            # deserialized as None is still reported when the expected value isn't None.
            if not equality:
                missing = object()
                for field in kwargs:
                    obj_val = getattr(obj, field, missing)
                    expected_val = getattr(expected, field, missing)
                    if obj_val is missing or expected_val is missing:  # Not exposed as an attribute
                        continue
                    if obj_val != expected_val:
                        error_message.append(f"{field}={obj_val!r} != {expected_val!r}")

            if error_message:
                assert equality, f"Object not equal: {', '.join(error_message)}"
            else:
                assert equality, f"Object not equal: {obj} != {expected} (expected)"
            assert buf.remaining == 0, f"Buffer has {buf.remaining} bytes remaining"

        @pytest.mark.parametrize(("kwargs", "exc"), validation_fail_kw, ids=validation_ids)
        def test_validation(self, kwargs: dict[str, Any], exc: TestExc):
            """Test validation of the object."""
            with pytest.raises(exc.exception, match=exc.match) as exc_info:
                cls(**kwargs)
            # If exc.kwargs is not None, check them against the exception
            _check_exception_attributes(exc_info.value, exc.kwargs)

        @pytest.mark.parametrize(("content", "exc"), deserialization_fail_exc, ids=deserialization_fail_ids)
        def test_deserialization_error(self, content: bytes, exc: TestExc):
            """Test deserialization error handling."""
            buf = Buffer(content)
            with pytest.raises(exc.exception, match=exc.match) as exc_info:
                cls.deserialize(buf)
            # If exc.kwargs is not None, check them against the exception
            _check_exception_attributes(exc_info.value, exc.kwargs)

    # Drop the test functions that have no cases, instead of handing pytest an empty parametrization
    if not parameters:
        del TestClass.test_serialization
        del TestClass.test_deserialization
    if not validation_fail_kw:
        del TestClass.test_validation
    if not deserialization_fail_exc:
        del TestClass.test_deserialization_error

    # Name the class after the tested class; the "Test" prefix matches pytest's default class collection pattern
    TestClass.__name__ = f"TestGen{cls.__name__}"
    # Expose the class in the caller's namespace (usually the test module's globals()) so pytest collects it
    context[TestClass.__name__] = TestClass