
Upgrade to pydantic v2 #20

Merged: 20 commits, merged on Sep 25, 2024

Commits
7811809
Added unit tests for model validation related to PredictData and Trai…
Sep 20, 2024
7d0a4f0
Provided more unit tests for PredictSettings class. The attribute out…
Sep 20, 2024
645b54f
Additional whitespace added after commas to avoid linting errors.
Sep 20, 2024
baf5c3b
Remove commas after last entry inside lists.
Sep 20, 2024
6c3a036
Add comma to avoid linting errors from black.
Sep 20, 2024
f726ea0
Add pydantic v2 inside the pyproject.toml file. root_validator and An…
Sep 23, 2024
5f78b99
Fix flake8 errors and add pydantic-settings package in the pyproject.t…
Sep 23, 2024
c42e915
Remove comma to avoid flake8 error. Annotated is now imported from t…
Sep 23, 2024
448c5b9
Refactor code to avoid linting errors from the package Black.
Sep 23, 2024
9bd1fab
Refactor code to comply with Black.
Sep 23, 2024
a357dd0
Modify dummy subthesauri in tests/common.py to lowercase, so that the…
Sep 23, 2024
979a4bb
Replace dict() and json() with model_dump() and model_dump_json(), re…
Sep 23, 2024
7546484
In order to avoid a flake8 error by spreading a lengthy line inside t…
Sep 23, 2024
a003cd7
Ignore lengthy lines in flake8 from github actions.
Sep 23, 2024
350f3e5
Fix linting issues flagged by Black.
Sep 23, 2024
71f94a9
Put all imports from pydantic in one line for the interface/config.py…
Sep 23, 2024
9fde662
Fix linting issue raised by Black.
Sep 23, 2024
e628306
Pull changes from master and update poetry.lock file.
Sep 23, 2024
b9d3b65
Remove ignore command from github actions inside step associated with…
Sep 24, 2024
ec7b91e
Extra unit test added to check status 422 for rest api.
Sep 25, 2024
235 changes: 171 additions & 64 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
@@ -16,8 +16,9 @@
 python = ">=3.8,<3.12"
 scikit-learn = "~1.2"
 scipy = "~1.10.1"
 rdflib = "~6.3.2"
-pydantic = "~1.10"
-fastapi = "~0.95.1"
+pydantic = "~2.8"
+pydantic-settings = "~2.4"
+fastapi = "~0.100"
 uvicorn = "~0.22"
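
The separate pydantic-settings dependency is needed because BaseSettings was split out of pydantic core in v2. A minimal sketch of the import change (the ExampleSettings class is hypothetical, shown only for illustration):

# pydantic v1
# from pydantic import BaseSettings

# pydantic v2: BaseSettings lives in the pydantic-settings package
from pydantic_settings import BaseSettings


class ExampleSettings(BaseSettings):
    # reads the DEBUG environment variable if set, defaults to False
    debug: bool = False
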
45 changes: 33 additions & 12 deletions qualle/interface/config.py
@@ -14,12 +14,33 @@
 from enum import Enum
 from pathlib import Path
 from typing import Dict, List, Optional, Union

-from pydantic import BaseSettings, root_validator, FilePath, DirectoryPath
+from typing_extensions import Annotated
+
+from pydantic import (
+    model_validator,
+    FilePath,
+    DirectoryPath,
+    TypeAdapter,
+    PlainValidator,
+    AfterValidator,
+)
+from pydantic_settings import BaseSettings
 from pydantic.networks import AnyUrl
 from qualle.features.confidence import ConfidenceFeatures
 from qualle.features.text import TextFeatures

+# From pydantic v2 onwards, AnyUrl no longer inherits from str.
+# The code below therefore validates a pydantic AnyUrl object as if it
+# were a string. Pydantic v2 also appends a trailing '/' character to
+# URLs, which the AfterValidator strips again.
+
+AnyUrlAdapter = TypeAdapter(AnyUrl)
+HttpUrlStr = Annotated[
+    str,
+    PlainValidator(lambda x: AnyUrlAdapter.validate_strings(x)),
+    AfterValidator(lambda x: str(x).rstrip("/")),
+]


 FileOrDirPath = Union[FilePath, DirectoryPath]
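
As a standalone sketch of what the HttpUrlStr annotation above does at runtime (assuming pydantic v2 and typing_extensions are installed):

from typing_extensions import Annotated

from pydantic import AfterValidator, AnyUrl, PlainValidator, TypeAdapter

AnyUrlAdapter = TypeAdapter(AnyUrl)
HttpUrlStr = Annotated[
    str,
    # validate the raw string as an AnyUrl; malformed URLs raise a ValidationError
    PlainValidator(lambda x: AnyUrlAdapter.validate_strings(x)),
    # convert back to str and strip the trailing '/' that pydantic v2 appends
    AfterValidator(lambda x: str(x).rstrip("/")),
]

print(TypeAdapter(HttpUrlStr).validate_python("http://subb"))  # -> http://subb
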

@@ -31,10 +31,10 @@ class RegressorSettings(BaseSettings):

 class SubthesauriLabelCalibrationSettings(BaseSettings):
     thesaurus_file: FilePath
-    subthesaurus_type: AnyUrl
-    concept_type: AnyUrl
-    concept_type_prefix: AnyUrl
-    subthesauri: List[AnyUrl]
+    subthesaurus_type: HttpUrlStr
+    concept_type: HttpUrlStr
+    concept_type_prefix: HttpUrlStr
+    subthesauri: List[HttpUrlStr]
     use_sparse_count_matrix: bool = False

@@ -66,18 +87,18 @@ class EvalSettings(BaseSettings):

 class PredictSettings(BaseSettings):
     predict_data_path: FileOrDirPath
     model_file: FilePath
-    output_path: Optional[Path]
+    output_path: Optional[Path] = None

-    @root_validator
-    def check_output_path_specified_for_input_file(cls, values):
-        predict_data_path = values.get("predict_data_path")
-        output_path = values.get("output_path")
+    @model_validator(mode="after")
+    def check_output_path_specified_for_input_file(self):
+        predict_data_path = self.predict_data_path
+        output_path = self.output_path
         if predict_data_path.is_file() and not output_path:
             raise ValueError(
                 "output_path has to be specified if predict_data_path "
                 "refers to a file"
             )
-        return values
+        return self


 class RESTSettings(BaseSettings):
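
The pattern above is the standard v1-to-v2 validator migration: root_validator received (cls, values) and returned a values dict, whereas model_validator(mode="after") runs on the fully constructed instance and returns self. A minimal sketch with a hypothetical Pair model:

from typing import Optional

from pydantic import BaseModel, model_validator


class Pair(BaseModel):
    low: int
    high: Optional[int] = None

    @model_validator(mode="after")
    def check_order(self):
        # all fields are already populated when an "after" validator runs
        if self.high is not None and self.high < self.low:
            raise ValueError("high must not be smaller than low")
        return self
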
20 changes: 10 additions & 10 deletions qualle/models.py
@@ -15,7 +15,7 @@
 from typing import List, Union

 import numpy as np
-from pydantic import root_validator, BaseModel
+from pydantic import model_validator, BaseModel
 from scipy.sparse import spmatrix

 Labels = List[str]

@@ -29,10 +29,10 @@ class PredictData(BaseModel):
     predicted_labels: List[Labels]
     scores: List[Scores]

-    @root_validator
-    def check_equal_length(cls, values):
+    @model_validator(mode="after")
+    def check_equal_length(self):
         length = None
-        for v in values.values():
+        for v in self.__dict__.values():
             if length is None:
                 length = len(v)
             else:

@@ -41,21 +41,21 @@ def check_equal_length(cls, values):
                 raise ValueError(
                     "docs, predicted_labels and scores "
                     "should have the same length"
                 )
-        return values
+        return self


 class TrainData(BaseModel):

     predict_data: PredictData
     true_labels: List[Labels]

-    @root_validator
-    def check_equal_length(cls, values):
-        p_data = values.get("predict_data")
-        t_labels = values.get("true_labels")
+    @model_validator(mode="after")
+    def check_equal_length(self):
+        p_data = self.predict_data
+        t_labels = self.true_labels
         if len(p_data.predicted_labels) != len(t_labels):
             raise ValueError("length of true labels and predicted labels do not match")
-        return values
+        return self


 @dataclass
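
A quick sketch of how the ported validator behaves at runtime (assuming the qualle package shown above is installed):

from pydantic import ValidationError

from qualle.models import PredictData

# two docs but only one entry in predicted_labels and scores -> validator raises
try:
    PredictData(docs=["d1", "d2"], predicted_labels=[["a"]], scores=[[0.5]])
except ValidationError as exc:
    print(exc)  # "docs, predicted_labels and scores should have the same length"
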
4 changes: 2 additions & 2 deletions tests/interface/common.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 from pathlib import Path

-DUMMY_SUBTHESAURUS_B = "http://subB"
-DUMMY_SUBTHESAURUS_A = "http://subA"
+DUMMY_SUBTHESAURUS_B = "http://subb"
+DUMMY_SUBTHESAURUS_A = "http://suba"
 DUMMY_CONCEPT_TYPE = "http://ctype"
 DUMMY_CONCEPT_TYPE_PREFIX = "http://prefix"
 DUMMY_SUBTHESAURUS_TYPE = "http://stype"
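
The lowercase change ties in with the URL handling above: pydantic v2 normalizes the host part of a URL to lowercase, so mixed-case dummy URIs would no longer round-trip unchanged. A small illustration (assuming pydantic v2):

from pydantic import AnyUrl, TypeAdapter

# the host is lowercased and a trailing '/' is appended during validation
print(TypeAdapter(AnyUrl).validate_python("http://subB"))  # -> http://subb/
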
21 changes: 19 additions & 2 deletions tests/interface/test_config.py
@@ -21,6 +21,23 @@
 def test_predict_settings_input_file_but_no_output_raises_exc(tmp_path):
     fp = tmp_path / "fp.tsv"
     fp.write_text("t\tc:0\tc")
+    mp = tmp_path / "model"
+    mp.write_text("modelInfo")
     with pytest.raises(ValidationError):
-        PredictSettings(predict_data_path=fp, model=tmp_path / "model")
+        PredictSettings(predict_data_path=fp, model_file=mp)
+
+
+def test_predict_settings_input_path_no_exc_1(tmp_path):
+    fp = tmp_path / "data"
+    fp.mkdir()
+    mp = tmp_path / "model"
+    mp.write_text("modelInfo")
+    PredictSettings(predict_data_path=fp, model_file=mp)
+
+
+def test_predict_settings_input_path_no_exc_2(tmp_path):
+    fp = tmp_path / "fp.tsv"
+    fp.write_text("t\tc:0\tc")
+    mp = tmp_path / "model"
+    mp.write_text("modelInfo")
+    PredictSettings(predict_data_path=fp, model_file=mp, output_path=tmp_path)
14 changes: 11 additions & 3 deletions tests/interface/test_rest.py
@@ -65,12 +65,20 @@ def documents(train_data):


 def test_return_http_200_for_predict(client, documents):
-    resp = client.post(PREDICT_ENDPOINT, json=documents.dict())
+    resp = client.post(PREDICT_ENDPOINT, json=documents.model_dump())
     assert resp.status_code == status.HTTP_200_OK


+def test_return_http_422_for_predict(client, documents):
+    docs = documents.model_dump()["documents"]
+    for doc in docs:
+        del doc["predicted_labels"]
+    resp = client.post(PREDICT_ENDPOINT, json=docs)
+    assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
+
+
 def test_return_scores_for_predict(client, documents):
-    resp = client.post(PREDICT_ENDPOINT, json=documents.dict())
+    resp = client.post(PREDICT_ENDPOINT, json=documents.model_dump())

     expected_scores = QualityEstimation(
         scores=[

@@ -79,7 +87,7 @@ def test_return_scores_for_predict(client, documents):
             )
         ]
     )
-    assert resp.json() == json.loads(expected_scores.json())
+    assert resp.json() == json.loads(expected_scores.model_dump_json())


 def test_return_http_200_for_up(client):
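
The .dict() and .json() calls replaced above were deprecated in pydantic v2 in favour of model_dump() and model_dump_json(). A minimal sketch of the renamed serialization API, with a hypothetical Doc model:

from typing import List

from pydantic import BaseModel


class Doc(BaseModel):
    title: str
    predicted_labels: List[str]


doc = Doc(title="t", predicted_labels=["c"])
print(doc.model_dump())       # {'title': 't', 'predicted_labels': ['c']}  (was .dict())
print(doc.model_dump_json())  # {"title":"t","predicted_labels":["c"]}     (was .json())
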
6 changes: 4 additions & 2 deletions tests/test_integration.py
@@ -92,13 +92,15 @@ def test_rest(train_data_file, model_path):
                 scores=[0.5, 1],
             )
         ]
-        ).dict(),
+        ).model_dump(),
     )

     assert res.status_code == status.HTTP_200_OK
     # We can make following assumption due to the construction of train data
     assert res.json() == json.loads(
-        QualityEstimation(scores=[QualityScores(name=Metric.RECALL, scores=[1])]).json()
+        QualityEstimation(
+            scores=[QualityScores(name=Metric.RECALL, scores=[1])]
+        ).model_dump_json()
     )


94 changes: 94 additions & 0 deletions tests/test_models.py
@@ -0,0 +1,94 @@
# Copyright 2021-2023 ZBW – Leibniz Information Centre for Economics
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from pydantic import ValidationError

from qualle.models import PredictData, TrainData


DUMMY_DOCS = ["doc1", "doc2"]
DUMMY_PRED_LABELS_1 = [["label 1"]]
DUMMY_PRED_LABELS_2 = [["label 1"], ["label 2"]]
DUMMY_SCORES_1 = [[2.5, 3.0]]
DUMMY_SCORES_2 = [[2.5, 3.0], [5.5, 7.2]]
DUMMY_SCORES_3 = [[2.5, 3.0], [5.5, 7.2], [11.5, 23.1]]
DUMMY_TRUE_LABELS = [["true label 1"], ["true label 2"]]


def test_unequal_length_in_predict_data_raises_validator_exc_1():
with pytest.raises(ValidationError):
PredictData(
docs=DUMMY_DOCS,
predicted_labels=DUMMY_PRED_LABELS_1,
scores=DUMMY_SCORES_1,
)


def test_unequal_length_in_predict_data_raises_validator_exc_2():
with pytest.raises(ValidationError):
PredictData(
docs=DUMMY_DOCS,
predicted_labels=DUMMY_PRED_LABELS_1,
scores=DUMMY_SCORES_3,
)


def test_missing_attribute_in_predict_data_raises_validator_exc():
with pytest.raises(ValidationError):
PredictData(
docs=DUMMY_DOCS,
predicted_labels=DUMMY_PRED_LABELS_1,
)


def test_predict_data_no_validation_exc():
PredictData(
docs=DUMMY_DOCS,
predicted_labels=DUMMY_PRED_LABELS_2,
scores=DUMMY_SCORES_2,
)


def test_unequal_length_in_train_data_raises_validator_exc_():
with pytest.raises(ValidationError):
TrainData(
predict_data=PredictData(
docs=DUMMY_DOCS,
predicted_labels=DUMMY_PRED_LABELS_2,
scores=DUMMY_SCORES_2,
),
true_labels=DUMMY_PRED_LABELS_1,
)


def test_missing_attribute_in_train_data_raises_validator_exc_():
with pytest.raises(ValidationError):
TrainData(
predict_data=PredictData(
docs=DUMMY_DOCS,
predicted_labels=DUMMY_PRED_LABELS_2,
scores=DUMMY_SCORES_2,
),
)


def test_train_data_no_validator_exc_():
TrainData(
predict_data=PredictData(
docs=DUMMY_DOCS,
predicted_labels=DUMMY_PRED_LABELS_2,
scores=DUMMY_SCORES_2,
),
true_labels=DUMMY_TRUE_LABELS,
)