-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TLDR-379 update bold classifier (#294)
* add bold classifier based on rules
* replace nn classifier with ruled classifier
* remove downloading deleted model
* fix style tests
* review fixes
* add test for bold classifier
- Loading branch information
1 parent
e8da0c6
commit 8fb2907
Showing
9 changed files
with
258 additions
and
310 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
66 changes: 66 additions & 0 deletions
66
...der/pdf_image_reader/line_metadata_extractor/bold_classifier/agglomerative_clusterizer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import numpy as np | ||
from scipy.stats import norm | ||
from sklearn.cluster import AgglomerativeClustering | ||
|
||
|
||
class BoldAgglomerativeClusterizer:
    """
    Splits one-dimensional evaluations into two clusters and marks the cluster with the lower mean with 1.

    Every value is paired with the average of itself and its two neighbours, the pairs are clustered with
    AgglomerativeClustering (library defaults; the code below expects the labels 0 and 1), and the split is
    accepted only if it passes a homogeneity test. If the test fails, nothing is marked.
    """

    def __init__(self) -> None:
        # significance level of the homogeneity test, used as norm.ppf(1 - significance_level)
        self.significance_level = 0.2

    def clusterize(self, x: np.ndarray) -> np.ndarray:
        """
        :param x: one-dimensional array of evaluations
        :return: float array of the same length: 1.0 for members of the lower-mean cluster, 0.0 for the rest
        """
        x = np.asarray(x, dtype=float)
        if len(x) < 2:
            # two clusters cannot be built from fewer than two samples
            return np.zeros(x.shape, dtype=float)

        x_vectors = self.__get_prop_vectors(x)
        x_clusters = self.__get_clusters(x_vectors)
        return self.__get_indicator(x, x_clusters)

    def __get_prop_vectors(self, x: np.ndarray) -> np.ndarray:
        """
        Pair every value with the mean of the value and its two neighbours.
        At the borders the value itself stands in for the missing neighbour.

        :return: array of shape (len(x), 2) with columns (value, neighbourhood mean)
        """
        nearby_x = x.astype(float)  # astype returns a copy, x itself is not modified
        nearby_x[:-1] += x[1:]
        nearby_x[1:] += x[:-1]
        nearby_x[0] += x[0]
        nearby_x[-1] += x[-1]
        return np.stack((x, nearby_x / 3.), 1)

    def __get_clusters(self, x_vectors: np.ndarray) -> np.ndarray:
        """
        :return: cluster label of every row of x_vectors
        """
        return AgglomerativeClustering().fit(x_vectors).labels_

    def __get_indicator(self, x: np.ndarray, x_clusters: np.ndarray) -> np.ndarray:
        """
        Turn cluster labels into a 0/1 indicator, or into all zeros if the sample is considered homogeneous.

        :param x: one-dimensional array of evaluations
        :param x_clusters: cluster label (0 or 1) of every value, not modified
        :return: float array with 1.0 for members of the cluster with the lower mean
        """
        # https://www.tsi.lv/sites/default/files/editor/science/Research_journals/Tr_Tel/2003/V1/yatskiv_gousarova.pdf
        # https://www.svms.org/classification/DuHS95.pdf
        # Pattern Classification and Scene Analysis (2nd ed.)
        # Part 1: Pattern Classification
        # Richard O. Duda, Peter E. Hart and David G. Stork
        # February 27, 1995
        in_cluster0 = x_clusters == 0
        in_cluster1 = x_clusters == 1
        if not in_cluster0.any() or not in_cluster1.any():
            return np.zeros(x.shape, dtype=float)

        f1 = self.__get_f1_homogeneous(x, x_clusters)
        f_cr = self.__get_f_criterion_homogeneous(n=len(x))
        if f_cr < f1:
            # splitting does not reduce the spread enough: treat the sample as one homogeneous group
            return np.zeros(x.shape, dtype=float)

        # Build the indicator in a new array. Relabelling in place (0 -> 1, then 1 -> 0) would wipe out every
        # label whenever the lower-mean cluster happens to carry the label 0.
        marked_label = 1 if np.mean(x[in_cluster1]) < np.mean(x[in_cluster0]) else 0
        return (x_clusters == marked_label).astype(float)

    def __get_f1_homogeneous(self, x: np.ndarray, x_clusters: np.ndarray) -> float:
        """
        Ratio of the size-weighted standard deviations inside the two clusters to the one of the whole sample.

        :return: the ratio, or 1.0 when a cluster is empty or all values are identical
        """
        x_clust0 = x[x_clusters == 0]
        x_clust1 = x[x_clusters == 1]
        if len(x_clust0) == 0 or len(x_clust1) == 0:
            return 1.

        w1 = np.std(x) * len(x)
        if w1 == 0:
            # identical values: there is nothing to split, and dividing would give nan
            return 1.
        w2 = np.std(x_clust0) * len(x_clust0) + np.std(x_clust1) * len(x_clust1)
        return w2 / w1

    def __get_f_criterion_homogeneous(self, n: int, p: int = 2) -> float:
        """
        Threshold that the ratio from __get_f1_homogeneous is compared against.

        :param n: number of values in the sample, must be positive
        :param p: parameter of the threshold formula
        """
        za1 = norm.ppf(1 - self.significance_level, loc=0, scale=1)
        f_cr = 1 - 2 / (np.pi * p) - za1 * np.sqrt(2 * (1 - 8 / (np.pi ** 2 * p)) / (n * p))
        return f_cr
107 changes: 107 additions & 0 deletions
107
...rs/pdf_reader/pdf_image_reader/line_metadata_extractor/bold_classifier/bold_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from typing import List | ||
|
||
import numpy as np | ||
|
||
from dedoc.data_structures import BBox | ||
from dedoc.readers.pdf_reader.pdf_image_reader.line_metadata_extractor.bold_classifier.agglomerative_clusterizer import BoldAgglomerativeClusterizer | ||
from dedoc.readers.pdf_reader.pdf_image_reader.line_metadata_extractor.bold_classifier.valley_emphasis_binarizer import ValleyEmphasisBinarizer | ||
|
||
|
||
class BoldClassifier:
    """
    This class classifies words (or lines) in bboxes as bold or non-bold.
    Given a list of bboxes and an image, it returns a list of boldness probabilities (actually only 0 and 1 for now)
    """

    def __init__(self) -> None:
        # images lower than this, or whose cropped part would be lower than this, are not cropped
        self.permissible_h_bbox = 5
        self.binarizer = ValleyEmphasisBinarizer()
        self.clusterizer = BoldAgglomerativeClusterizer()

    def classify(self, image: np.ndarray, bboxes: List["BBox"]) -> List[float]:
        """
        :param image: page image, passed to the binarizer as is
        :param bboxes: bounding boxes on the image
        :return: one value per bbox, in the same order; empty input gives [] and a single bbox gives [0.0]
        """
        if len(bboxes) == 0:
            return []

        if len(bboxes) == 1:
            # a single bbox has nothing to be compared with
            return [0.0]

        bboxes_evaluation = self.__get_bboxes_evaluation(image, bboxes)
        return self.__clusterize(bboxes_evaluation)

    def __get_bboxes_evaluation(self, image: np.ndarray, bboxes: List["BBox"]) -> List[float]:
        """
        Binarize the image and evaluate every bbox on the binarized image.
        """
        processed_image = self.__preprocessing(image)
        return self.__get_evaluation_bboxes(processed_image, bboxes)

    def __preprocessing(self, image: np.ndarray) -> np.ndarray:
        return self.binarizer.binarize(image)

    def __get_evaluation_bboxes(self, image: np.ndarray, bboxes: List["BBox"]) -> List[float]:
        return [self.__evaluation_one_bbox(image, bbox) for bbox in bboxes]

    def __evaluation_one_bbox(self, image: np.ndarray, bbox: "BBox") -> float:
        """
        Evaluate the part of the image inside the bbox; parts that are too small get the evaluation 1.
        """
        bbox_image = image[bbox.y_top_left:bbox.y_bottom_right, bbox.x_top_left:bbox.x_bottom_right]
        if not self.__is_correct_bbox_image(bbox_image):
            return 1.
        return self.__evaluation_one_bbox_image(bbox_image)

    def __evaluation_one_bbox_image(self, image: np.ndarray) -> float:
        """
        :param image: two-dimensional image; the formula below treats 0 as ink and 1 as background
        :return: p / s, where p is the share of horizontally adjacent pixel pairs with different values
            and s is one minus the mean pixel value; 1.0 if p > s or s == 0
        """
        base_line_image = self.__get_base_line_image(image)
        base_line_image_without_spaces = self.__get_rid_spaces(base_line_image)

        # Comparing instead of subtracting gives the same 0/1 transition map without relying on unsigned
        # wraparound, and also works for boolean images.
        p = (base_line_image[:, :-1] != base_line_image[:, 1:]).mean()
        s = 1 - base_line_image_without_spaces.mean()

        if p > s or s == 0:
            return 1.
        return float(p / s)

    def __clusterize(self, bboxes_evaluation: List[float]) -> List[float]:
        """
        :return: indicators from the clusterizer as plain python floats
        """
        vector_bbox_indicators = self.clusterizer.clusterize(np.array(bboxes_evaluation))
        return [float(indicator) for indicator in vector_bbox_indicators]

    def __get_rid_spaces(self, image: np.ndarray) -> np.ndarray:
        """
        Drop the columns whose mean value is at least 0.95, but only for images of at most 3 columns.
        """
        x = image.mean(0)
        not_space = x < 0.95
        # NOTE(review): len(not_space) is the image width, so every image wider than 3 columns is returned
        # unchanged. __is_correct_bbox_image only accepts w > 3, so the filter below never runs on the classify
        # path. Possibly the number of non-space columns was meant; confirm before changing, because it would
        # shift every evaluation.
        if len(not_space) > 3:
            return image
        return image[:, not_space]

    def __get_base_line_image(self, image: np.ndarray) -> np.ndarray:
        """
        Crop the image to the rows between the two largest jumps of the row means.

        :return: the cropped image, or the image itself if it is lower than permissible_h_bbox
            or the crop would be lower than permissible_h_bbox
        """
        h = image.shape[0]
        if h < self.permissible_h_bbox:
            return image
        mean_ = image.mean(1)
        delta_mean = abs(mean_[:-1] - mean_[1:])

        # one pass that keeps the largest (max1) and the second largest (max2) jump with their row indexes;
        # the strict comparisons make the earliest row win on ties
        max1 = 0
        max2 = 0
        argmax1 = 0
        argmax2 = 0
        for i, delta_mean_i in enumerate(delta_mean):
            if delta_mean_i <= max2:
                continue
            if delta_mean_i > max1:
                max2 = max1
                argmax2 = argmax1
                max1 = delta_mean_i
                argmax1 = i
            else:
                max2 = delta_mean_i
                argmax2 = i

        h_min = min(argmax1, argmax2)
        h_max = min(max(argmax1, argmax2) + 1, h)
        if h_max - h_min < self.permissible_h_bbox:
            return image
        return image[h_min:h_max, :]

    def __is_correct_bbox_image(self, image: np.ndarray) -> bool:
        """
        :return: True if the image has more than 3 rows and more than 3 columns
        """
        h, w = image.shape[0:2]
        return h > 3 and w > 3
46 changes: 46 additions & 0 deletions
46
...der/pdf_image_reader/line_metadata_extractor/bold_classifier/valley_emphasis_binarizer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import cv2 | ||
import numpy as np | ||
|
||
|
||
class ValleyEmphasisBinarizer:
    """
    Binarizes an image with one global threshold.

    The threshold maximizes omega_1 * mu_1 ** 2 + omega_2 * mu_2 ** 2 (class weights and class means of the
    pixels below and above the threshold), weighted down by the share of pixels whose gray level lies
    close to the threshold.
    """

    def __init__(self, n: int = 5) -> None:
        # width (in gray levels) of the neighbourhood taken on each side of a candidate threshold
        self.n = n

    def binarize(self, image: np.ndarray) -> np.ndarray:
        """
        :param image: three-dimensional image (converted with cv2.COLOR_BGR2GRAY) or a two-dimensional image that
            is already gray; intensities are expected in the range 0..255
        :return: new two-dimensional array of the gray image's dtype, 0 for pixels <= threshold and 1 for the rest
        """
        gray_img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if image.ndim == 3 else image
        threshold = self.__get_threshold(gray_img)
        # build a new array instead of writing into gray_img, so a caller-owned gray image stays intact
        return (gray_img > threshold).astype(gray_img.dtype)

    def __get_threshold(self, gray_img: np.ndarray) -> int:
        """
        :param gray_img: two-dimensional gray image with intensities in the range 0..255
        :return: gray level t in 0..253; pixels with values <= t form the first class
        """
        h, w = gray_img.shape
        total = h * w
        if total == 0:
            return 0

        # The explicit range keeps bin t aligned with gray level t (level 255 falls into the last bin, 254).
        # Without it the bins span only min..max of the image, and the bin index is not a gray level.
        c, _ = np.histogram(gray_img, bins=255, range=(0, 255))

        sum_val = 0
        for t in range(255):
            sum_val = sum_val + (t * c[t] / total)

        var_max = 0
        threshold = 0

        omega_1 = 0
        mu_k = 0

        for t in range(254):
            omega_1 = omega_1 + c[t] / total
            mu_k = mu_k + t * (c[t] / total)
            omega_2 = 1 - omega_1
            if omega_1 <= 0 or omega_2 <= 0:
                # one of the classes is empty, its mean is undefined
                continue

            mu_1 = mu_k / omega_1
            mu_2 = (sum_val - mu_k) / omega_2
            # NOTE(review): the window is [t - n, t + n) and never includes bin 0 - confirm this is intended
            sum_of_neighbors = np.sum(c[max(1, t - self.n):min(255, t + self.n)])
            current_var = (1 - sum_of_neighbors / total) * (omega_1 * mu_1 ** 2 + omega_2 * mu_2 ** 2)

            if current_var > var_max:
                var_max = current_var
                threshold = t

        return threshold
59 changes: 7 additions & 52 deletions
59
dedoc/readers/pdf_reader/pdf_image_reader/line_metadata_extractor/font_type_classifier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,67 +1,22 @@ | ||
import os | ||
from collections import namedtuple | ||
from typing import Any | ||
import torch | ||
from torchvision.transforms import ToTensor | ||
|
||
from dedoc.data_structures.concrete_annotations.bold_annotation import BoldAnnotation | ||
from dedoc.download_models import download_from_hub | ||
from dedoc.readers.pdf_reader.data_classes.page_with_bboxes import PageWithBBox | ||
from dedoc.utils.image_utils import get_bbox_from_image | ||
|
||
FontType = namedtuple("FontType", ["bold", "other"]) | ||
from dedoc.readers.pdf_reader.pdf_image_reader.line_metadata_extractor.bold_classifier.bold_classifier import BoldClassifier | ||
|
||
|
||
class FontTypeClassifier:
    """
    Adds bold annotations to the bboxes of a page, using BoldClassifier to decide which bboxes are bold.
    """
    # NOTE(review): this span was a diff view with removed and added lines interleaved without +/- markers
    # (two __init__ headers in a row, torch-based code after the new logic). It is resolved here to the variant
    # built on BoldClassifier - confirm against the repository.
    labels_list = ["bold", "OTHER"]

    def __init__(self) -> None:
        super().__init__()
        self.bold_classifier = BoldClassifier()

    def predict_annotations(self, page: PageWithBBox) -> PageWithBBox:
        """
        Append a BoldAnnotation covering the whole text to every bbox of the page classified as bold.

        :param page: page with an image and bboxes, the annotations of its bboxes are extended in place
        :return: the same page object
        """
        if len(page.bboxes) == 0:
            return page

        bboxes = [bbox.bbox for bbox in page.bboxes]
        bold_probabilities = self.bold_classifier.classify(page.image, bboxes)

        for bbox, bold_probability in zip(page.bboxes, bold_probabilities):
            if bold_probability > 0.5:
                bbox.annotations.append(BoldAnnotation(start=0, end=len(bbox.text), value="True"))

        return page
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.