Skip to content

Commit

Permalink
Merge pull request #20 from zbw/18-upgrade-to-pydantic-v2
Browse files Browse the repository at this point in the history
Upgrade to pydantic v2
  • Loading branch information
gmmajal authored Sep 25, 2024
2 parents 3c2be3b + ec7b91e commit 4b9bf05
Show file tree
Hide file tree
Showing 9 changed files with 347 additions and 97 deletions.
235 changes: 171 additions & 64 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ python = ">=3.8,<3.12"
scikit-learn = "~1.2"
scipy = "~1.10.1"
rdflib = "~6.3.2"
pydantic = "~1.10"
fastapi = "~0.95.1"
pydantic = "~2.8"
pydantic-settings = "~2.4"
fastapi = "~0.100"
uvicorn = "~0.22"


Expand Down
45 changes: 33 additions & 12 deletions qualle/interface/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,33 @@
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Union

from pydantic import BaseSettings, root_validator, FilePath, DirectoryPath
from typing_extensions import Annotated

from pydantic import (
model_validator,
FilePath,
DirectoryPath,
TypeAdapter,
PlainValidator,
AfterValidator,
)
from pydantic_settings import BaseSettings
from pydantic.networks import AnyUrl
from qualle.features.confidence import ConfidenceFeatures
from qualle.features.text import TextFeatures

# From pydantic v2 onwards, AnyUrl no longer inherits from str. The code block
# below therefore validates a value as an AnyUrl but exposes it as a plain
# string. pydantic v2 also appends a trailing '/' during URL normalization;
# that trailing slash is stripped again below.

AnyUrlAdapter = TypeAdapter(AnyUrl)

# A string type that is validated as a URL but carried around as ``str``:
# the raw value is first parsed via the AnyUrl adapter, then converted back
# to a plain string with the trailing slash (added by URL normalization)
# removed.
HttpUrlStr = Annotated[
    str,
    PlainValidator(lambda value: AnyUrlAdapter.validate_strings(value)),
    AfterValidator(lambda url: str(url).rstrip("/")),
]


FileOrDirPath = Union[FilePath, DirectoryPath]

Expand All @@ -31,10 +52,10 @@ class RegressorSettings(BaseSettings):

class SubthesauriLabelCalibrationSettings(BaseSettings):
    """Settings for label calibration computed per subthesaurus."""

    thesaurus_file: FilePath
    # URLs are handled as plain strings (HttpUrlStr) because pydantic v2's
    # AnyUrl no longer subclasses str.
    subthesaurus_type: HttpUrlStr
    concept_type: HttpUrlStr
    concept_type_prefix: HttpUrlStr
    subthesauri: List[HttpUrlStr]
    # Use a sparse count matrix; defaults to the dense representation.
    use_sparse_count_matrix: bool = False


Expand Down Expand Up @@ -66,18 +87,18 @@ class EvalSettings(BaseSettings):
class PredictSettings(BaseSettings):
    """Settings for the predict command."""

    # Either a single input file or a directory of input files.
    predict_data_path: FileOrDirPath
    model_file: FilePath
    # Optional for directory input; mandatory for file input (see validator).
    output_path: Optional[Path] = None

    @model_validator(mode="after")
    def check_output_path_specified_for_input_file(self):
        # A single input file yields a single output, so the caller has to say
        # where to write it. Directory input presumably derives output
        # locations itself — confirm against the predict implementation.
        if self.predict_data_path.is_file() and not self.output_path:
            raise ValueError(
                "output_path has to be specified if predict_data_path "
                "refers to a file"
            )
        return self


class RESTSettings(BaseSettings):
Expand Down
20 changes: 10 additions & 10 deletions qualle/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import List, Union

import numpy as np
from pydantic import root_validator, BaseModel
from pydantic import model_validator, BaseModel
from scipy.sparse import spmatrix

Labels = List[str]
Expand All @@ -29,10 +29,10 @@ class PredictData(BaseModel):
predicted_labels: List[Labels]
scores: List[Scores]

@root_validator
def check_equal_length(cls, values):
@model_validator(mode="after")
def check_equal_length(self):
length = None
for v in values.values():
for v in self.__dict__.values():
if length is None:
length = len(v)
else:
Expand All @@ -41,21 +41,21 @@ def check_equal_length(cls, values):
"docs, predicted_labels and scores "
"should have the same length"
)
return values
return self


class TrainData(BaseModel):
    """Training data: predict data plus the gold-standard labels."""

    predict_data: PredictData
    true_labels: List[Labels]

    @model_validator(mode="after")
    def check_equal_length(self):
        # Every document needs a gold-standard label list, so both sequences
        # must align one-to-one.
        if len(self.predict_data.predicted_labels) != len(self.true_labels):
            raise ValueError("length of true labels and predicted labels do not match")
        return self


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions tests/interface/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
from pathlib import Path

# Dummy SKOS URIs used across the interface tests. Lowercase on purpose:
# pydantic v2's URL normalization lowercases the host part, so uppercase
# values would no longer round-trip through HttpUrlStr validation.
DUMMY_SUBTHESAURUS_B = "http://subb"
DUMMY_SUBTHESAURUS_A = "http://suba"
DUMMY_CONCEPT_TYPE = "http://ctype"
DUMMY_CONCEPT_TYPE_PREFIX = "http://prefix"
DUMMY_SUBTHESAURUS_TYPE = "http://stype"
Expand Down
21 changes: 19 additions & 2 deletions tests/interface/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@
def test_predict_settings_input_file_but_no_output_raises_exc(tmp_path):
    """A file as predict_data_path without an output_path must be rejected."""
    fp = tmp_path / "fp.tsv"
    fp.write_text("t\tc:0\tc")

    mp = tmp_path / "model"
    mp.write_text("modelInfo")
    with pytest.raises(ValidationError):
        PredictSettings(predict_data_path=fp, model_file=mp)


def test_predict_settings_input_path_no_exc_1(tmp_path):
    """A directory as predict_data_path is valid without an output_path."""
    data_dir = tmp_path / "data"
    data_dir.mkdir()
    model_file = tmp_path / "model"
    model_file.write_text("modelInfo")
    PredictSettings(predict_data_path=data_dir, model_file=model_file)


def test_predict_settings_input_path_no_exc_2(tmp_path):
    """A file input together with an explicit output_path is accepted."""
    input_file = tmp_path / "fp.tsv"
    input_file.write_text("t\tc:0\tc")
    model_file = tmp_path / "model"
    model_file.write_text("modelInfo")
    PredictSettings(
        predict_data_path=input_file,
        model_file=model_file,
        output_path=tmp_path,
    )
14 changes: 11 additions & 3 deletions tests/interface/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,20 @@ def documents(train_data):


def test_return_http_200_for_predict(client, documents):
    """A well-formed predict payload is answered with HTTP 200."""
    # pydantic v2: model_dump() replaces the removed v1 .dict() method.
    resp = client.post(PREDICT_ENDPOINT, json=documents.model_dump())
    assert resp.status_code == status.HTTP_200_OK


def test_return_http_422_for_predict(client, documents):
    """Documents missing predicted_labels are rejected with 422."""
    payload = documents.model_dump()["documents"]
    for document in payload:
        document.pop("predicted_labels")
    resp = client.post(PREDICT_ENDPOINT, json=payload)
    assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


def test_return_scores_for_predict(client, documents):
    # NOTE(review): diff-scrape artifact — the next two lines are the pydantic
    # v1 (.dict()) and v2 (.model_dump()) versions of the same call; only the
    # v2 line is current code.
    resp = client.post(PREDICT_ENDPOINT, json=documents.dict())
    resp = client.post(PREDICT_ENDPOINT, json=documents.model_dump())

    expected_scores = QualityEstimation(
        scores=[
Expand All @@ -79,7 +87,7 @@ def test_return_scores_for_predict(client, documents):
            )
        ]
    )
    # NOTE(review): same artifact below — v1 (.json()) vs v2
    # (.model_dump_json()); only the second assert is current code.
    assert resp.json() == json.loads(expected_scores.json())
    assert resp.json() == json.loads(expected_scores.model_dump_json())


def test_return_http_200_for_up(client):
Expand Down
6 changes: 4 additions & 2 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,15 @@ def test_rest(train_data_file, model_path):
scores=[0.5, 1],
)
]
).dict(),
).model_dump(),
)

assert res.status_code == status.HTTP_200_OK
# We can make following assumption due to the construction of train data
assert res.json() == json.loads(
QualityEstimation(scores=[QualityScores(name=Metric.RECALL, scores=[1])]).json()
QualityEstimation(
scores=[QualityScores(name=Metric.RECALL, scores=[1])]
).model_dump_json()
)


Expand Down
94 changes: 94 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2021-2023 ZBW – Leibniz Information Centre for Economics
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from pydantic import ValidationError

from qualle.models import PredictData, TrainData


# Two dummy documents with matching gold-standard labels.
DUMMY_DOCS = ["doc1", "doc2"]
DUMMY_TRUE_LABELS = [["true label 1"], ["true label 2"]]
# Predicted label lists of length 1 and 2 (to trigger / avoid length errors).
DUMMY_PRED_LABELS_1 = [["label 1"]]
DUMMY_PRED_LABELS_2 = [["label 1"], ["label 2"]]
# Score lists of length 1, 2 and 3 for the same purpose.
DUMMY_SCORES_1 = [[2.5, 3.0]]
DUMMY_SCORES_2 = [[2.5, 3.0], [5.5, 7.2]]
DUMMY_SCORES_3 = [[2.5, 3.0], [5.5, 7.2], [11.5, 23.1]]


def test_unequal_length_in_predict_data_raises_validator_exc_1():
    """Two docs but single-entry labels and scores must be rejected."""
    kwargs = dict(
        docs=DUMMY_DOCS,
        predicted_labels=DUMMY_PRED_LABELS_1,
        scores=DUMMY_SCORES_1,
    )
    with pytest.raises(ValidationError):
        PredictData(**kwargs)


def test_unequal_length_in_predict_data_raises_validator_exc_2():
    """Label and score lists of different lengths must be rejected."""
    kwargs = dict(
        docs=DUMMY_DOCS,
        predicted_labels=DUMMY_PRED_LABELS_1,
        scores=DUMMY_SCORES_3,
    )
    with pytest.raises(ValidationError):
        PredictData(**kwargs)


def test_missing_attribute_in_predict_data_raises_validator_exc():
    """Omitting the required scores field raises a ValidationError."""
    # Renamed from "attrbute" to fix the typo in the test name.
    with pytest.raises(ValidationError):
        PredictData(
            docs=DUMMY_DOCS,
            predicted_labels=DUMMY_PRED_LABELS_1,
        )


def test_predict_data_no_validation_exc():
    """Matching lengths for all fields must not raise."""
    kwargs = {
        "docs": DUMMY_DOCS,
        "predicted_labels": DUMMY_PRED_LABELS_2,
        "scores": DUMMY_SCORES_2,
    }
    PredictData(**kwargs)


def test_unequal_length_in_train_data_raises_validator_exc_():
    """True labels shorter than predicted labels must be rejected."""
    predict_data = PredictData(
        docs=DUMMY_DOCS,
        predicted_labels=DUMMY_PRED_LABELS_2,
        scores=DUMMY_SCORES_2,
    )
    with pytest.raises(ValidationError):
        TrainData(predict_data=predict_data, true_labels=DUMMY_PRED_LABELS_1)


def test_missing_attribute_in_train_data_raises_validator_exc_():
    """Omitting the required true_labels field raises a ValidationError."""
    predict_data = PredictData(
        docs=DUMMY_DOCS,
        predicted_labels=DUMMY_PRED_LABELS_2,
        scores=DUMMY_SCORES_2,
    )
    with pytest.raises(ValidationError):
        TrainData(predict_data=predict_data)


def test_train_data_no_validator_exc_():
    """Equal numbers of true and predicted label lists must not raise."""
    predict_data = PredictData(
        docs=DUMMY_DOCS,
        predicted_labels=DUMMY_PRED_LABELS_2,
        scores=DUMMY_SCORES_2,
    )
    TrainData(predict_data=predict_data, true_labels=DUMMY_TRUE_LABELS)

0 comments on commit 4b9bf05

Please sign in to comment.