diff --git a/litestar/_openapi/schema_generation/plugins/struct.py b/litestar/_openapi/schema_generation/plugins/struct.py index 7ac0dd0220..e8fa8e273e 100644 --- a/litestar/_openapi/schema_generation/plugins/struct.py +++ b/litestar/_openapi/schema_generation/plugins/struct.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import msgspec from msgspec import Struct @@ -20,10 +20,11 @@ class StructSchemaPlugin(OpenAPISchemaPlugin): def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool: return not field_definition.is_union and field_definition.is_subclass_of(Struct) - def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema: - def is_field_required(field: msgspec.inspect.Field) -> bool: - return field.required or field.default_factory is Empty + @staticmethod + def _is_field_required(field: msgspec.inspect.Field) -> bool: + return field.required or field.default_factory is Empty + def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema: type_hints = field_definition.get_type_hints(include_extras=True, resolve_generics=True) struct_info: msgspec.inspect.StructType = msgspec.inspect.type_info(field_definition.type_) # type: ignore[assignment] struct_fields = struct_info.fields @@ -41,14 +42,24 @@ def is_field_required(field: msgspec.inspect.Field) -> bool: **field_definition_kwargs, ) + required = [ + field.encode_name + for field in struct_fields + if self._is_field_required(field=field) and not is_optional_union(type_hints[field.name]) + ] + + # Support tagged unions: https://jcristharif.com/msgspec/structs.html#tagged-unions + # These structs contain a tag_field and a tag. Since these fields are added + # dynamically, they are not present within the regular struct fields and don't + # have any type annotation associated with them, so we create a FieldDefinition + # manually + if struct_info.tag_field: + # using a Literal here will set these as a const in the schema + property_fields[struct_info.tag_field] = FieldDefinition.from_annotation(Literal[struct_info.tag]) # pyright: ignore + required.append(struct_info.tag_field) + return schema_creator.create_component_schema( field_definition, - required=sorted( - [ - field.encode_name - for field in struct_fields - if is_field_required(field=field) and not is_optional_union(type_hints[field.name]) - ] - ), + required=sorted(required), property_fields=property_fields, ) diff --git a/tests/unit/test_contrib/test_msgspec.py b/tests/unit/test_contrib/test_msgspec.py index 2c0c177ec4..9c0f77ddf6 100644 --- a/tests/unit/test_contrib/test_msgspec.py +++ b/tests/unit/test_contrib/test_msgspec.py @@ -8,6 +8,7 @@ from msgspec import Meta, Struct, field from typing_extensions import Annotated +from litestar import Litestar, post from litestar.dto import DTOField, Mark, MsgspecDTO, dto_field from litestar.dto.data_structures import DTOFieldDefinition from litestar.typing import FieldDefinition @@ -131,3 +132,64 @@ class Model(Struct): fields = list(dto_type.generate_field_definitions(Model)) assert fields[0].dto_field == DTOField("read-only") assert fields[1].dto_field == DTOField("read-only") + + +def test_tag_field_included_in_schema() -> None: + # default tag field, default tag value + class Model(Struct, tag=True): + regular_field: str + + # default tag field, custom tag value + class Model2(Struct, tag=2): + regular_field: str + + # custom tag field, custom tag value + class Model3(Struct, tag_field="foo", tag="bar"): + regular_field: str + + @post("/1") + def handler(data: Model) -> None: + return None + + @post("/2") + def handler_2(data: Model2) -> None: + return None + + @post("/3") + def handler_3(data: Model3) -> None: + return None + + components = Litestar( + [handler, handler_2, handler_3], + signature_types=[Model, Model2, Model3], + ).openapi_schema.components.to_schema()["schemas"] + + assert components["test_tag_field_included_in_schema.Model"] == { + "properties": { + "regular_field": {"type": "string"}, + "type": {"type": "string", "const": "Model"}, + }, + "type": "object", + "required": ["regular_field", "type"], + "title": "Model", + } + + assert components["test_tag_field_included_in_schema.Model2"] == { + "properties": { + "regular_field": {"type": "string"}, + "type": {"type": "integer", "const": 2}, + }, + "type": "object", + "required": ["regular_field", "type"], + "title": "Model2", + } + + assert components["test_tag_field_included_in_schema.Model3"] == { + "properties": { + "regular_field": {"type": "string"}, + "foo": {"type": "string", "const": "bar"}, + }, + "type": "object", + "required": ["foo", "regular_field"], + "title": "Model3", + }