Skip to content

Commit

Permalink
TLDR-353 refactor text layer correctness detection (#290)
Browse files Browse the repository at this point in the history
* TLDR-353 start refactoring

* TLDR-353 pdf_auto_reader refactoring

* TLDR-353 tests fixed
  • Loading branch information
NastyBoget authored Jul 10, 2023
1 parent 44c3037 commit 055f8f7
Show file tree
Hide file tree
Showing 12 changed files with 349 additions and 396 deletions.
2 changes: 1 addition & 1 deletion dedoc/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def download_from_hub(out_dir: str, out_name: str, repo_name: str, hub_name: str

def download(resources_path: str) -> None:
download_from_hub(out_dir=resources_path,
out_name="catboost_detect_tl_correctness.pth",
out_name="catboost_detect_tl_correctness.pkl.gz",
repo_name="catboost_detect_tl_correctness",
hub_name="model.pkl.gz")

Expand Down
212 changes: 101 additions & 111 deletions dedoc/readers/pdf_reader/pdf_auto_reader/pdf_auto_reader.py

Large diffs are not rendered by default.

213 changes: 0 additions & 213 deletions dedoc/readers/pdf_reader/pdf_auto_reader/pdf_txtlayer_correctness.py

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,30 @@
import os
import pickle
from typing import List

import catboost.core
from dedoc.download_models import download_from_hub

from dedoc.config import get_config
from dedoc.download_models import download_from_hub
from dedoc.readers.pdf_reader.data_classes.text_with_bbox import TextWithBBox


class CatboostModelExtractor:
class TxtlayerClassifier:
"""
The CatboostModelExtractor class is used for detecting the correctness of the text layer in a PDF document
using a CatBoost model.
The TxtlayerClassifier class is used for classifying the correctness of the text layer in a PDF document.
"""
def __init__(self, *, config: dict) -> None:
self.config = config
self.logger = config.get("logger", logging.getLogger())
eng = list(map(chr, range(ord('a'), ord('z') + 1)))

rus = [chr(i) for i in range(ord('а'), ord('а') + 32)]
rus.append("ё")

eng = list(map(chr, range(ord('a'), ord('z') + 1)))
rus = [chr(i) for i in range(ord('а'), ord('а') + 32)] + ["ё"]
digits = [str(i) for i in range(10)]
special_symbols = [i for i in "<>~!@#$%^&*_+-/\"|?.,:;'`= "]
brackets = [i for i in "{}[]()"]
self.list_letters = eng + [i.upper() for i in eng] + rus + [i.upper() for i in rus]
self.list_symbols = digits + special_symbols + brackets

self.letters_list = eng + [i.upper() for i in eng] + rus + [i.upper() for i in rus]
self.symbols_list = digits + special_symbols + brackets

self.path = os.path.join(get_config()["resources_path"], "catboost_detect_tl_correctness.pkl.gz")
self.__model = None
Expand All @@ -47,33 +46,39 @@ def __get_model(self) -> catboost.core.CatBoostClassifier:

return self.__model

def detect_text_layer_correctness(self, text_layer_bboxes: List[TextWithBBox]) -> bool:
def predict(self, text_with_bboxes: List[TextWithBBox]) -> bool:
"""
Detect the correctness of the text layer in a PDF document.
:param text_layer_bboxes: List of text lines with bounding boxes.
Classifies the correctness of the text layer in a PDF document.
:param text_with_bboxes: List of text lines with bounding boxes.
:returns: True if the text layer is correct, False otherwise.
"""
text_layer = u"".join([pdf_line.text for pdf_line in text_layer_bboxes])
text_layer = u"".join([pdf_line.text for pdf_line in text_with_bboxes])
if not text_layer:
return False

features = self.__get_feature_for_predict(text_layer)
return True if self.__get_model.predict(features) == 1 else False
return self.__get_model.predict(features) == 1

def __get_feature_for_predict(self, text: str) -> List[float]:
list_of_sub = []
features = []
num_letters_in_data = self._count_letters(text)
num_other_symbol_in_data = self._count_other(text)
for symbol in self.list_letters:

for symbol in self.letters_list:
# share of each English and Russian letter among all letters found in the text
list_of_sub.append(round(text.count(symbol) / num_letters_in_data, 5) if num_letters_in_data != 0 else 0.0)
for symbol in self.list_symbols:
list_of_sub.append(text.count(symbol))
list_of_sub.append((num_letters_in_data + num_other_symbol_in_data) / len(text) if len(text) != 0 else 0)
return list_of_sub
features.append(round(text.count(symbol) / num_letters_in_data, 5) if num_letters_in_data != 0 else 0.0)

for symbol in self.symbols_list:
# number of occurrences of each digit, special symbol and bracket
features.append(text.count(symbol))

# share of known letters and symbols among all characters of the text
features.append((num_letters_in_data + num_other_symbol_in_data) / len(text) if len(text) != 0 else 0)
return features

def _count_letters(self, text: str) -> int:
return sum(1 for symbol in text if symbol in self.list_letters)
return sum(1 for symbol in text if symbol in self.letters_list)

def _count_other(self, text: str) -> int:
return sum(1 for symbol in text if symbol in self.list_symbols)
return sum(1 for symbol in text if symbol in self.symbols_list)
Loading

0 comments on commit 055f8f7

Please sign in to comment.