Skip to content

Commit

Permalink
fix(OpenAPI): Correctly handle msgspec.Struct tagged unions (#3742)
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut authored Sep 15, 2024
1 parent f2ed95d commit bb1d0d4
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 11 deletions.
33 changes: 22 additions & 11 deletions litestar/_openapi/schema_generation/plugins/struct.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

import msgspec
from msgspec import Struct
Expand All @@ -20,10 +20,11 @@ class StructSchemaPlugin(OpenAPISchemaPlugin):
def is_plugin_supported_field(self, field_definition: FieldDefinition) -> bool:
return not field_definition.is_union and field_definition.is_subclass_of(Struct)

def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema:
def is_field_required(field: msgspec.inspect.Field) -> bool:
return field.required or field.default_factory is Empty
@staticmethod
def _is_field_required(field: msgspec.inspect.Field) -> bool:
return field.required or field.default_factory is Empty

def to_openapi_schema(self, field_definition: FieldDefinition, schema_creator: SchemaCreator) -> Schema:
type_hints = field_definition.get_type_hints(include_extras=True, resolve_generics=True)
struct_info: msgspec.inspect.StructType = msgspec.inspect.type_info(field_definition.type_) # type: ignore[assignment]
struct_fields = struct_info.fields
Expand All @@ -41,14 +42,24 @@ def is_field_required(field: msgspec.inspect.Field) -> bool:
**field_definition_kwargs,
)

required = [
field.encode_name
for field in struct_fields
if self._is_field_required(field=field) and not is_optional_union(type_hints[field.name])
]

# Support tagged unions: https://jcristharif.com/msgspec/structs.html#tagged-unions
# These structs contain a tag_field and a tag. Since these fields are added
# dynamically, they are not present within the regular struct fields and don't
# have any type annotation associated with them, so we create a FieldDefinition
# manually
if struct_info.tag_field:
# using a Literal here will set these as a const in the schema
property_fields[struct_info.tag_field] = FieldDefinition.from_annotation(Literal[struct_info.tag]) # pyright: ignore
required.append(struct_info.tag_field)

return schema_creator.create_component_schema(
field_definition,
required=sorted(
[
field.encode_name
for field in struct_fields
if is_field_required(field=field) and not is_optional_union(type_hints[field.name])
]
),
required=sorted(required),
property_fields=property_fields,
)
62 changes: 62 additions & 0 deletions tests/unit/test_contrib/test_msgspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from msgspec import Meta, Struct, field
from typing_extensions import Annotated

from litestar import Litestar, post
from litestar.dto import DTOField, Mark, MsgspecDTO, dto_field
from litestar.dto.data_structures import DTOFieldDefinition
from litestar.typing import FieldDefinition
Expand Down Expand Up @@ -131,3 +132,64 @@ class Model(Struct):
fields = list(dto_type.generate_field_definitions(Model))
assert fields[0].dto_field == DTOField("read-only")
assert fields[1].dto_field == DTOField("read-only")


def test_tag_field_included_in_schema() -> None:
# default tag field, default tag value
class Model(Struct, tag=True):
regular_field: str

# default tag field, custom tag value
class Model2(Struct, tag=2):
regular_field: str

# custom tag field, custom tag value
class Model3(Struct, tag_field="foo", tag="bar"):
regular_field: str

@post("/1")
def handler(data: Model) -> None:
return None

@post("/2")
def handler_2(data: Model2) -> None:
return None

@post("/3")
def handler_3(data: Model3) -> None:
return None

components = Litestar(
[handler, handler_2, handler_3],
signature_types=[Model, Model2, Model3],
).openapi_schema.components.to_schema()["schemas"]

assert components["test_tag_field_included_in_schema.Model"] == {
"properties": {
"regular_field": {"type": "string"},
"type": {"type": "string", "const": "Model"},
},
"type": "object",
"required": ["regular_field", "type"],
"title": "Model",
}

assert components["test_tag_field_included_in_schema.Model2"] == {
"properties": {
"regular_field": {"type": "string"},
"type": {"type": "integer", "const": 2},
},
"type": "object",
"required": ["regular_field", "type"],
"title": "Model2",
}

assert components["test_tag_field_included_in_schema.Model3"] == {
"properties": {
"regular_field": {"type": "string"},
"foo": {"type": "string", "const": "bar"},
},
"type": "object",
"required": ["foo", "regular_field"],
"title": "Model3",
}

0 comments on commit bb1d0d4

Please sign in to comment.