From 8fb29074115879536ccb7ff5d6c17bee2ce74d4a Mon Sep 17 00:00:00 2001 From: Andrew Perminov Date: Thu, 13 Jul 2023 17:31:13 +0300 Subject: [PATCH] TLDR-379 update bold classifier (#294) * add bold classifier based on rules * replace nn classifier with ruled classifier * remove downloading deleted model * fix style tests * review fixes * add test for bold classifier --- dedoc/download_models.py | 5 - .../bold_classifier/__init__.py | 0 .../agglomerative_clusterizer.py | 66 ++++++ .../bold_classifier/bold_classifier.py | 107 +++++++++ .../valley_emphasis_binarizer.py | 46 ++++ .../font_type_classifier.py | 59 +---- .../metadata_extractor.py | 10 +- .../train/train_line_metadata_classifier.py | 223 ------------------ tests/unit_tests/test_font_classifier.py | 52 ++-- 9 files changed, 258 insertions(+), 310 deletions(-) create mode 100644 dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/__init__.py create mode 100644 dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/agglomerative_clusterizer.py create mode 100644 dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/bold_classifier.py create mode 100644 dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/valley_emphasis_binarizer.py delete mode 100644 dedoc/scripts/train/train_line_metadata_classifier.py diff --git a/dedoc/download_models.py b/dedoc/download_models.py index cb65eb7a..376d82bc 100644 --- a/dedoc/download_models.py +++ b/dedoc/download_models.py @@ -37,11 +37,6 @@ def download(resources_path: str) -> None: repo_name="scan_orientation_efficient_net_b0", hub_name="model.pth") - download_from_hub(out_dir=resources_path, - out_name="font_classifier.pth", - repo_name="font_classifier", - hub_name="model.pth") - download_from_hub(out_dir=resources_path, out_name="paragraph_classifier.pkl.gz", repo_name="paragraph_classifier", diff --git a/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/__init__.py b/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/agglomerative_clusterizer.py b/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/agglomerative_clusterizer.py new file mode 100644 index 00000000..e1a37cb7 --- /dev/null +++ b/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/agglomerative_clusterizer.py @@ -0,0 +1,66 @@ +import numpy as np +from scipy.stats import norm +from sklearn.cluster import AgglomerativeClustering + + +class BoldAgglomerativeClusterizer: + def __init__(self) -> None: + self.significance_level = 0.2 + + def clusterize(self, x: np.ndarray) -> np.ndarray: + x_vectors = self.__get_prop_vectors(x) + x_clusters = self.__get_clusters(x_vectors) + x_indicator = self.__get_indicator(x, x_clusters) + return x_indicator + + def __get_prop_vectors(self, x: np.ndarray) -> np.ndarray: + nearby_x = x.copy() + nearby_x[:-1] += x[1:] + nearby_x[1:] += x[:-1] + nearby_x[0] += x[0] + nearby_x[-1] += x[-1] + nearby_x = nearby_x / 3. 
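+        # pair every evaluation with the mean of its 3-neighborhood so that the
+        # clusterizer sees both the value itself and its local context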
+        return np.stack((x, nearby_x), 1)
+
+    def __get_clusters(self, x_vectors: np.ndarray) -> np.ndarray:
+        agg = AgglomerativeClustering()
+        agg.fit(x_vectors)
+        x_clusters = agg.labels_
+        return x_clusters
+
+    def __get_indicator(self, x: np.ndarray, x_clusters: np.ndarray) -> np.ndarray:
+        # https://www.tsi.lv/sites/default/files/editor/science/Research_journals/Tr_Tel/2003/V1/yatskiv_gousarova.pdf
+        # https://www.svms.org/classification/DuHS95.pdf
+        # Pattern Classification and Scene Analysis (2nd ed.)
+        # Part 1: Pattern Classification
+        # Richard O. Duda, Peter E. Hart and David G. Stork
+        # February 27, 1995
+        f1 = self.__get_f1_homogeneous(x, x_clusters)
+        f_cr = self.__get_f_criterion_homogeneous(n=len(x))
+
+        if f_cr < f1:
+            # the two clusters do not differ significantly: treat everything as regular
+            return np.zeros_like(x)
+
+        # the cluster with the lower mean evaluation contains the bold words;
+        # compute the mask before relabeling so the two assignments do not clash
+        bold_mask = x_clusters == 1
+        if np.mean(x[bold_mask]) >= np.mean(x[~bold_mask]):
+            bold_mask = ~bold_mask
+
+        return bold_mask.astype(float)
+
+    def __get_f1_homogeneous(self, x: np.ndarray, x_clusters: np.ndarray) -> float:
+        x_clust0 = x[x_clusters == 0]
+        x_clust1 = x[x_clusters == 1]
+        if len(x_clust0) == 0 or len(x_clust1) == 0:
+            return 1.0
+
+        w1 = np.std(x) * len(x)
+        w2 = np.std(x_clust0) * len(x_clust0) + np.std(x_clust1) * len(x_clust1)
+        f1 = w2 / w1
+        return f1
+
+    def __get_f_criterion_homogeneous(self, n: int, p: int = 2) -> float:
+        za1 = norm.ppf(1 - self.significance_level, loc=0, scale=1)
+        f_cr = 1 - 2 / (np.pi * p) - za1 * np.sqrt(2 * (1 - 8 / (np.pi ** 2 * p)) / (n * p))
+        return f_cr
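For intuition, the clusterizer can be exercised in isolation. A minimal sketch on synthetic evaluation values (assuming numpy, scipy and scikit-learn are installed and the package path added above is importable; the exact cluster labels depend on the agglomerative clustering, but clearly separated values should come out as indicated):

    import numpy as np

    from dedoc.readers.pdf_reader.pdf_image_reader.line_metadata_extractor.bold_classifier.agglomerative_clusterizer import \
        BoldAgglomerativeClusterizer

    # low evaluations correspond to thick (bold) strokes, high ones to regular text
    evaluations = np.array([0.35, 0.30, 0.33, 0.80, 0.85, 0.78, 0.82])

    clusterizer = BoldAgglomerativeClusterizer()
    indicator = clusterizer.clusterize(evaluations)
    print(indicator)  # expected: 1.0 for the three low values, 0.0 for the four high ones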
diff --git a/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/bold_classifier.py b/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/bold_classifier.py
new file mode 100644
index 00000000..18dbdee2
--- /dev/null
+++ b/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/bold_classifier.py
@@ -0,0 +1,107 @@
+from typing import List
+
+import numpy as np
+
+from dedoc.data_structures import BBox
+from dedoc.readers.pdf_reader.pdf_image_reader.line_metadata_extractor.bold_classifier.agglomerative_clusterizer import BoldAgglomerativeClusterizer
+from dedoc.readers.pdf_reader.pdf_image_reader.line_metadata_extractor.bold_classifier.valley_emphasis_binarizer import ValleyEmphasisBinarizer
+
+
+class BoldClassifier:
+    """
+    This class classifies words (or lines) in bboxes as bold or regular.
+    Given an image and a list of bboxes, it returns a list of boldness probabilities
+    (currently these are hard 0.0 / 1.0 indicators produced by clustering).
+    """
+    def __init__(self) -> None:
+        self.permissible_h_bbox = 5
+        self.binarizer = ValleyEmphasisBinarizer()
+        self.clusterizer = BoldAgglomerativeClusterizer()
+
+    def classify(self, image: np.ndarray, bboxes: List[BBox]) -> List[float]:
+        if len(bboxes) == 0:
+            return []
+
+        # clustering needs at least two observations, a single bbox is assumed regular
+        if len(bboxes) == 1:
+            return [0.0]
+
+        bboxes_evaluation = self.__get_bboxes_evaluation(image, bboxes)
+        bold_probabilities = self.__clusterize(bboxes_evaluation)
+        return bold_probabilities
+
+    def __get_bboxes_evaluation(self, image: np.ndarray, bboxes: List[BBox]) -> List[float]:
+        processed_image = self.__preprocessing(image)
+        bboxes_evaluation = self.__get_evaluation_bboxes(processed_image, bboxes)
+        return bboxes_evaluation
+
+    def __preprocessing(self, image: np.ndarray) -> np.ndarray:
+        return self.binarizer.binarize(image)
+
+    def __get_evaluation_bboxes(self, image: np.ndarray, bboxes: List[BBox]) -> List[float]:
+        bboxes_evaluation = [self.__evaluation_one_bbox(image, bbox) for bbox in bboxes]
+        return bboxes_evaluation
+
+    def __evaluation_one_bbox(self, image: np.ndarray, bbox: BBox) -> float:
+        bbox_image = image[bbox.y_top_left:bbox.y_bottom_right, bbox.x_top_left:bbox.x_bottom_right]
+        # crops that are too small cannot be evaluated reliably, give them the "regular" evaluation
+        return self.__evaluation_one_bbox_image(bbox_image) if self.__is_correct_bbox_image(bbox_image) else 1.
+
+    def __evaluation_one_bbox_image(self, image: np.ndarray) -> float:
+        base_line_image = self.__get_base_line_image(image)
+        base_line_image_without_spaces = self.__get_rid_spaces(base_line_image)
+
+        # p is the density of white-to-black transitions along the rows (one per stroke);
+        # the cast to int avoids uint8 wrap-around in the subtraction
+        p_img = base_line_image[:, :-1].astype(int) - base_line_image[:, 1:].astype(int)
+        p_img[p_img < 0] = 0
+        p_img[p_img > 0] = 1
+        p = p_img.mean()
+
+        # s is the ink density, i.e. stroke width times the number of strokes per row,
+        # so p / s decreases as the strokes get thicker (bolder)
+        s = 1 - base_line_image_without_spaces.mean()
+
+        if p > s or s == 0:
+            evaluation = 1.
+        else:
+            evaluation = p / s
+        return evaluation
+
+    def __clusterize(self, bboxes_evaluation: List[float]) -> List[float]:
+        vector_bbox_evaluation = np.array(bboxes_evaluation)
+        vector_bbox_indicators = self.clusterizer.clusterize(vector_bbox_evaluation)
+        bboxes_indicators = list(vector_bbox_indicators)
+        return bboxes_indicators
+
+    def __get_rid_spaces(self, image: np.ndarray) -> np.ndarray:
+        x = image.mean(0)
+        not_space = x < 0.95
+        # keep the image as is if hardly any ink columns would remain after the cut
+        if not_space.sum() <= 3:
+            return image
+        return image[:, not_space]
+
+    def __get_base_line_image(self, image: np.ndarray) -> np.ndarray:
+        # crop the bbox to the band between the two sharpest jumps of the row means,
+        # which approximately corresponds to the x-height band of the text line
+        h = image.shape[0]
+        if h < self.permissible_h_bbox:
+            return image
+        mean_ = image.mean(1)
+        delta_mean = abs(mean_[:-1] - mean_[1:])
+
+        max1 = 0
+        max2 = 0
+        argmax1 = 0
+        argmax2 = 0
+        for i, delta_mean_i in enumerate(delta_mean):
+            if delta_mean_i <= max2:
+                continue
+            if delta_mean_i > max1:
+                max2 = max1
+                argmax2 = argmax1
+                max1 = delta_mean_i
+                argmax1 = i
+            else:
+                max2 = delta_mean_i
+                argmax2 = i
+        h_min = min(argmax1, argmax2)
+        h_max = min(max(argmax1, argmax2) + 1, h)
+        if h_max - h_min < self.permissible_h_bbox:
+            return image
+        return image[h_min:h_max, :]
+
+    def __is_correct_bbox_image(self, image: np.ndarray) -> bool:
+        h, w = image.shape[0:2]
+        return h > 3 and w > 3
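The core of the evaluation above is the ratio p / s: p counts white-to-black transitions along the rows (roughly one per stroke), s is the ink fraction (stroke width times strokes per row), so p / s behaves like an inverse mean stroke width and drops for bold text. A self-contained sketch of the same arithmetic; stroke_evaluation is a hypothetical helper that mirrors the private method above rather than calling it:

    import numpy as np

    def stroke_evaluation(image: np.ndarray) -> float:
        # mirror of __evaluation_one_bbox_image without the baseline crop;
        # image is a binarized crop: 0 for ink, 1 for background
        diffs = image[:, :-1].astype(int) - image[:, 1:].astype(int)
        p = (diffs > 0).mean()    # white-to-black edges: about one per stroke and row
        s = 1 - image.mean()      # ink fraction: stroke width times strokes per row
        return 1. if p > s or s == 0 else p / s

    thin = np.ones((10, 40), dtype=np.uint8)
    bold = np.ones((10, 40), dtype=np.uint8)
    for start in (5, 15, 25, 35):
        thin[:, start] = 0               # four 1-pixel strokes
        bold[:, start:start + 4] = 0     # four 4-pixel strokes

    print(stroke_evaluation(thin))  # 1.0: the thinnest strokes hit the p > s cap
    print(stroke_evaluation(bold))  # ~0.26, close to 1 / stroke_width = 0.25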
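A quick way to sanity-check the binarizer is a synthetic two-level image: the valley-emphasis threshold should land in the gap between the dark and light peaks of the histogram. A minimal usage sketch, assuming dedoc with this patch applied is importable:

    import numpy as np

    from dedoc.readers.pdf_reader.pdf_image_reader.line_metadata_extractor.bold_classifier.valley_emphasis_binarizer import \
        ValleyEmphasisBinarizer

    page = np.full((64, 64, 3), 220, dtype=np.uint8)  # light background
    page[20:40, 10:50] = 40                           # dark "text" block

    binary = ValleyEmphasisBinarizer().binarize(page)
    print(np.unique(binary))  # expected [0 1]: ink pixels become 0, background 1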
diff --git a/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/font_type_classifier.py b/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/font_type_classifier.py
index 86c94c99..bcabdbc9 100644
--- a/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/font_type_classifier.py
+++ b/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/font_type_classifier.py
@@ -1,67 +1,22 @@
-import os
-from collections import namedtuple
-from typing import Any
-import torch
-from torchvision.transforms import ToTensor
-
 from dedoc.data_structures.concrete_annotations.bold_annotation import BoldAnnotation
-from dedoc.download_models import download_from_hub
 from dedoc.readers.pdf_reader.data_classes.page_with_bboxes import PageWithBBox
-from dedoc.utils.image_utils import get_bbox_from_image
-
-FontType = namedtuple("FontType", ["bold", "other"])
+from dedoc.readers.pdf_reader.pdf_image_reader.line_metadata_extractor.bold_classifier.bold_classifier import BoldClassifier
 
 
 class FontTypeClassifier:
-    labels_list = ["bold", "OTHER"]
-
-    def __init__(self, model_path: str) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self._model = None
-        self.model_path = model_path
-
-    @property
-    def model(self) -> Any:
-        if self._model is not None:
-            return self._model
-
-        if not os.path.isfile(self.model_path):
-            out_dir, out_name = os.path.split(self.model_path)
-            download_from_hub(out_dir=out_dir, out_name=out_name, repo_name="font_classifier", hub_name="model.pth")
-
-        with open(self.model_path, "rb") as file:
-            self._model = torch.load(f=file).eval()
-            self._model.requires_grad_(False)
-
-        return self._model
+        self.bold_classifier = BoldClassifier()
 
     def predict_annotations(self, page: PageWithBBox) -> PageWithBBox:
         if len(page.bboxes) == 0:
             return page
 
-        tensor_predictions = self._get_model_predictions(page)
-        is_bold = ["bold" if p else "not_bold" for p in (tensor_predictions[:, 0] > 0.5)]
-        is_other = ["other" if p else "not_other" for p in (tensor_predictions[:, 1] > 0.5)]
-        font_type_predictions = [FontType(*_) for _ in zip(is_bold, is_other)]
-        boxes_fonts = zip(page.bboxes, font_type_predictions)
+        bboxes = [bbox.bbox for bbox in page.bboxes]
+        bold_probabilities = self.bold_classifier.classify(page.image, bboxes)
 
-        for bbox, font_type in boxes_fonts:
-            if font_type.bold == "bold":
+        for bbox, bold_probability in zip(page.bboxes, bold_probabilities):
+            if bold_probability > 0.5:
                 bbox.annotations.append(BoldAnnotation(start=0, end=len(bbox.text), value="True"))
 
         return page
-
-    @staticmethod
-    def _page2tensor(page: PageWithBBox) -> torch.Tensor:
-        if len(page.bboxes) == 0:
-            return torch.zeros()
-        to_tensor = ToTensor()
-        images = (get_bbox_from_image(image=page.image, bbox=bbox.bbox) for bbox in page.bboxes)
-        tensors = (to_tensor(image) for image in images)
-        tensors_reshaped = [tensor.unsqueeze(0) for tensor in tensors]
-        return torch.cat(tensors_reshaped, dim=0)
-
-    def _get_model_predictions(self, page: PageWithBBox) -> torch.Tensor:
-        tensor = self._page2tensor(page)
-        with torch.no_grad():
-            return self.model(tensor)
diff --git a/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/metadata_extractor.py b/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/metadata_extractor.py
index 3ce3ecd8..504e25aa 100644
--- a/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/metadata_extractor.py
+++ b/dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/metadata_extractor.py
@@ -1,14 +1,13 @@
-import os
 import re
 from typing import List, Optional
-from numpy import median
+
 import numpy as np
+from numpy import median
 
-from dedoc.config import get_config
+from dedoc.data_structures.concrete_annotations.color_annotation import ColorAnnotation
 from dedoc.data_structures.concrete_annotations.indentation_annotation import IndentationAnnotation
 from dedoc.data_structures.concrete_annotations.size_annotation import SizeAnnotation
 from dedoc.data_structures.concrete_annotations.spacing_annotation import SpacingAnnotation
-from dedoc.data_structures.concrete_annotations.color_annotation import ColorAnnotation
 from dedoc.data_structures.line_metadata import LineMetadata
 from 
dedoc.readers.pdf_reader.data_classes.line_with_location import LineWithLocation from dedoc.readers.pdf_reader.data_classes.page_with_bboxes import PageWithBBox @@ -21,8 +20,7 @@ class LineMetadataExtractor: def __init__(self, default_spacing: int = 50, *, config: dict) -> None: self.config = config - path_model = os.path.join(get_config()["resources_path"], "font_classifier.pth") - self.font_type_classifier = FontTypeClassifier(path_model) + self.font_type_classifier = FontTypeClassifier() self.default_spacing = default_spacing def predict_annotations(self, page_with_lines: PageWithBBox) -> PageWithBBox: diff --git a/dedoc/scripts/train/train_line_metadata_classifier.py b/dedoc/scripts/train/train_line_metadata_classifier.py deleted file mode 100644 index 78be1534..00000000 --- a/dedoc/scripts/train/train_line_metadata_classifier.py +++ /dev/null @@ -1,223 +0,0 @@ -import argparse -import json -import os -import random -import time -import warnings -from collections import defaultdict -from itertools import chain -from typing import List, Tuple - -import numpy as np -import torch -from PIL import Image -from joblib import Parallel, delayed -from numpy import mean -from sklearn.metrics import roc_auc_score -from sklearn.model_selection import train_test_split -from torch.nn import BCELoss -from torch.nn import Sequential, Linear, ReLU, Sigmoid, BatchNorm1d -from torch.nn.modules.loss import _Loss -from torch.optim import Adam, Optimizer -from torch.utils.data import Dataset, DataLoader -from torchvision.models import resnet18, ResNet -from torchvision.transforms import ToTensor -from tqdm import tqdm - - -from dedoc.data_structures.bbox import BBox -from dedoc.readers.pdf_reader.pdf_image_reader.line_metadata_extractor.font_type_classifier import FontTypeClassifier -from dedoc.utils.image_utils import get_bbox_from_image - -parser = argparse.ArgumentParser(add_help=True) -parser.add_argument("-l", "--labels_path", type=str, help="path to the json file with labeled bboxes", required=True) -parser.add_argument("-o", "--output_file", type=str, help="name of file with trained classifier", required=True) -args = parser.parse_args() - -print("GO") - -path = args.labels_path -path_out = args.output_file -seed = 42 - -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -np.random.seed(seed) -random.seed(seed) - -device = "cuda" if torch.cuda.is_available() else "cpu" - -print(device) - - -def get_model() -> ResNet: - model = resnet18(pretrained=True) - model.fc = Sequential( - Linear(in_features=512, out_features=256), - ReLU(), - BatchNorm1d(256), - Linear(256, out_features=2), - Sigmoid(), - ) - return model - - -class FontTypeDataset(Dataset): - - def __init__(self, path: str, items: List[dict]) -> None: - super().__init__() - self.labels_list = FontTypeClassifier.labels_list - self.to_tensor = ToTensor() - - self.images = Parallel(n_jobs=8)(delayed(self._image2cropped)(path, i) for i in tqdm(items)) - self.images = [self.to_tensor(image) for image in self.images] - labels = [] - for item in items: - labels.append(self._encode_labels(item)) - self.labels = torch.tensor(labels).float() - - def _image2cropped(self, path: str, item: dict) -> Image: - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - image_path = os.path.join(path, "original_documents", item["data"]["original_document_name"]) - image = Image.open(image_path) - bbox_dict = item["data"]["bbox"]["bbox"] - bbox = BBox(x_top_left=bbox_dict["x_upper_left"], - y_top_left=bbox_dict["y_upper_left"], - 
height=bbox_dict["height"], - width=bbox_dict["width"] - ) - return get_bbox_from_image(image=image, bbox=bbox) - - def _encode_labels(self, item: dict) -> List[int]: - labels_item = [] - for label in self.labels_list: - if label in item["labeled"]: - labels_item.append(1) - else: - labels_item.append(0) - assert len(labels_item) == len(self.labels_list) - return labels_item - - def __getitem__(self, index: int) -> Tuple[Image, torch.Tensor]: - return self.images[index], self.labels[index] - - def __len__(self) -> int: - return len(self.labels) - - -def get_data(path: str) -> Tuple[List, List]: - grouped_tasks = defaultdict(list) - - with open(os.path.join(path, "labeled_tasks.json")) as file: - data = json.load(file) - for item in data.values(): - image = item["data"]["original_document_name"] - if os.path.isfile(os.path.join(path, "original_documents", image)): - grouped_tasks[image].append(item) - - train_group, val_group = train_test_split(list(grouped_tasks.values()), train_size=0.8, ) - train_group = list(chain(*train_group)) - val_group = list(chain(*val_group)) - return train_group, val_group - - -def one_batch_train(model: torch.nn.Module, - data_loader: DataLoader, - optimizer: Optimizer, - criterion: _Loss) -> List[float]: - epoch_losses = [] - for data_input, labels in data_loader: - optimizer.zero_grad() - - data_input = data_input.to(device) - labels = labels.float().to(device) - predictions = model(data_input) - loss = criterion(predictions, labels) - loss.backward() - optimizer.step() - epoch_losses.append(float(loss)) - return epoch_losses - - -def one_batch_val(model: torch.nn.Module, - data_loader: DataLoader, - criterion: _Loss) -> Tuple[List[float], torch.Tensor, torch.Tensor]: - epoch_losses = [] - predictions_all = [] - labels_all = [] - with torch.no_grad(): - for data_input, labels in data_loader: - data_input = data_input.to(device) - labels = labels.float().to(device) - predictions = model(data_input) - loss = criterion(predictions, labels) - epoch_losses.append(float(loss)) - predictions_all.append(predictions.cpu()) - labels_all.append(labels.cpu()) - return epoch_losses, torch.cat(predictions_all, dim=0), torch.cat(labels_all, dim=0) - - -def train_model(model: torch.nn.Module, - criterion: _Loss, - optimizer: Optimizer, - dataloaders: DataLoader, - epoch_start: int = 0, - epoch_end: int = 15) -> None: - res = [] - for epoch in range(epoch_start, epoch_end): - epoch_losses_train = one_batch_train(model, dataloaders["train"], optimizer, criterion) - epoch_losses_val, predictions_all, labels_all = one_batch_val(model, dataloaders["val"], criterion) - - roc_bold = roc_auc_score(y_score=predictions_all[:, 0], y_true=labels_all[:, 0]) - roc_other = roc_auc_score(y_score=predictions_all[:, 1], y_true=labels_all[:, 1]) - epoch_losses_train = mean(epoch_losses_train) - epoch_losses_val = mean(epoch_losses_val) - res.append((epoch, epoch_losses_train, epoch_losses_val, roc_bold, roc_other)) - report_template = "{:011d} epoch={:06d} train {:01.4f} val {:01.4f} bold {:01.4f} other {:01.4f}" - print(report_template.format(int(time.time()), *res[-1])) - return - - -def main() -> None: - train_group, val_group = get_data(path) - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - dataset_val = FontTypeDataset(path, val_group) - dataset_train = FontTypeDataset(path, train_group) - - dataloaders = { - "val": DataLoader(dataset_val, batch_size=16, drop_last=True), - "train": DataLoader(dataset_train, batch_size=16, shuffle=True, drop_last=True) - } - print("GET 
DATA") - - font_classifier = get_model() - print("GET MODEL") - - font_classifier.requires_grad_(False) - font_classifier.fc.requires_grad_(True) - font_classifier = font_classifier.to(device) - optimizer = Adam(params=font_classifier.fc.parameters(), lr=1e-5) - train_model(model=font_classifier, - criterion=BCELoss(), - dataloaders=dataloaders, - optimizer=optimizer, - epoch_start=0, - epoch_end=15) - - font_classifier.requires_grad_(True) - optimizer = Adam(params=font_classifier.parameters(), lr=1e-4) - train_model(model=font_classifier, - dataloaders=dataloaders, - criterion=BCELoss(), - optimizer=optimizer, - epoch_start=15, - epoch_end=35) - with open(path_out, "wb") as file_out: - font_classifier = font_classifier.cpu() - torch.save(obj=font_classifier, f=file_out) - - -if __name__ == '__main__': - main() diff --git a/tests/unit_tests/test_font_classifier.py b/tests/unit_tests/test_font_classifier.py index d721f6c4..2ef3eb41 100644 --- a/tests/unit_tests/test_font_classifier.py +++ b/tests/unit_tests/test_font_classifier.py @@ -1,42 +1,46 @@ import os import unittest -from PIL import Image +import cv2 +from dedoc.data_structures import BoldAnnotation from dedoc.data_structures.bbox import BBox from dedoc.readers.pdf_reader.data_classes.page_with_bboxes import PageWithBBox from dedoc.readers.pdf_reader.data_classes.text_with_bbox import TextWithBBox from dedoc.readers.pdf_reader.pdf_image_reader.line_metadata_extractor.font_type_classifier import FontTypeClassifier -from tests.test_utils import get_test_config class TestFontClassifier(unittest.TestCase): - data_directory_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "scanned")) - dirname = os.path.dirname(__file__) - path_model = os.path.abspath(os.path.join(get_test_config()["resources_path"], "font_classifier.pth")) - classifier = FontTypeClassifier(path_model) + classifier = FontTypeClassifier() def get_page(self) -> PageWithBBox: - image = Image.open(os.path.join(self.data_directory_path, "orient_1.png")) - - bbox_1 = TextWithBBox(bbox=BBox(10, 20, 11, 23), page_num=0, text="str", line_num=0) - bbox_2 = TextWithBBox(bbox=BBox(20, 30, 11, 23), page_num=0, text="rts", line_num=1) - bboxes = [bbox_1, bbox_2] + image = cv2.imread(os.path.join(self.data_directory_path, "example.png")) + + bboxes = [ + TextWithBBox(bbox=BBox(79, 86, 214, 21), page_num=0, text="Пример документа", line_num=0), + TextWithBBox(bbox=BBox(79, 113, 627, 20), page_num=0, text="Глава 1 с таким длинным названием которое даже не влазит в", line_num=0), + TextWithBBox(bbox=BBox(80, 142, 132, 16), page_num=0, text="одну строчку.", line_num=0), + TextWithBBox(bbox=BBox(80, 163, 154, 15), page_num=0, text="Какие то определения", line_num=0), + TextWithBBox(bbox=BBox(79, 182, 65, 11), page_num=0, text="Статья 1", line_num=0), + TextWithBBox(bbox=BBox(79, 201, 166, 15), page_num=0, text="опрделения", line_num=0), + TextWithBBox(bbox=BBox(79, 220, 66, 11), page_num=0, text="Статья 2", line_num=0), + TextWithBBox(bbox=BBox(79, 239, 124, 14), page_num=0, text="Дадим пояснения", line_num=0), + TextWithBBox(bbox=BBox(81, 259, 203, 11), page_num=0, text="1.2.1 Поясним за непонятное", line_num=0), + TextWithBBox(bbox=BBox(81, 278, 191, 11), page_num=0, text="1.2.2. 
Поясним за понятное", line_num=0), + TextWithBBox(bbox=BBox(129, 297, 171, 15), page_num=0, text="а) это даже ежу понятно", line_num=0), + TextWithBBox(bbox=BBox(129, 315, 153, 16), page_num=0, text="6) это ежу не понятно", line_num=0), + TextWithBBox(bbox=BBox(81, 335, 30, 11), page_num=0, text="123", line_num=0), + ] return PageWithBBox(image=image, bboxes=bboxes, page_num=0) - def test__page2tensor(self) -> None: - page = self.get_page() - tensor = FontTypeClassifier._page2tensor(page=page) - bbox_num, channels, height, width = tensor.shape - self.assertEqual(2, bbox_num) - self.assertEqual(3, channels) - self.assertEqual(15, height) - self.assertEqual(300, width) - - def test__get_model_predictions(self) -> None: + def test_bold_classification(self) -> None: page = self.get_page() - predictions = self.classifier._get_model_predictions(page) - self.assertEqual(predictions.shape[0], 2) - self.assertEqual(len(predictions.shape), 2) + self.classifier.predict_annotations(page) + + for bbox in page.bboxes[:3]: + self.assertIn(BoldAnnotation.name, [annotation.name for annotation in bbox.annotations]) + + for bbox in page.bboxes[3:]: + self.assertNotIn(BoldAnnotation.name, [annotation.name for annotation in bbox.annotations])
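Taken together, the patch replaces the downloaded torch font model with a training-free pipeline: binarize the page, estimate a stroke-width evaluation per bbox, and cluster the evaluations into bold and regular groups. A minimal end-to-end sketch of the new public entry point (the image path and the bbox coordinates are hypothetical placeholders; the coordinates follow the positional BBox usage in the test above):

    import cv2

    from dedoc.data_structures.bbox import BBox
    from dedoc.readers.pdf_reader.pdf_image_reader.line_metadata_extractor.bold_classifier.bold_classifier import \
        BoldClassifier

    image = cv2.imread("scan.png")  # hypothetical page scan
    bboxes = [BBox(79, 86, 214, 21), BBox(79, 239, 124, 14)]  # word boxes from OCR

    probabilities = BoldClassifier().classify(image, bboxes)  # e.g. [1.0, 0.0]
    bold_boxes = [bbox for bbox, probability in zip(bboxes, probabilities) if probability > 0.5]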