Skip to content

Commit

Permalink
TLDR-714 change orient classificator orientation
Browse files Browse the repository at this point in the history
  • Loading branch information
oksidgy committed Jun 17, 2024
1 parent ec37f49 commit 3984f2d
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,9 @@ def __init__(self) -> None:
def load_dataset(self, csv_path: str, image_path: str, batch_size: int = 4) -> DataLoader:
trainset = DatasetImageOrient(csv_file=csv_path, root_dir=image_path, transform=self.transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
self.amount = len(trainset)

return trainloader

def __len__(self) -> int:
return self.amount
25 changes: 25 additions & 0 deletions resources/benchmarks/orient_classifier_scores.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

Orientation predictions:
+-------+-----------+--------+-------+-------+
| Class | Precision | Recall | F1 | Count |
+=======+===========+========+=======+=======+
| 0 | 0.998 | 1 | 0.999 | 537 |
+-------+-----------+--------+-------+-------+
| 90 | 1 | 0.998 | 0.999 | 537 |
+-------+-----------+--------+-------+-------+
| 180 | 1 | 0.998 | 0.999 | 537 |
+-------+-----------+--------+-------+-------+
| 270 | 0.998 | 1 | 0.999 | 537 |
+-------+-----------+--------+-------+-------+
| AVG | 0.999 | 0.999 | 0.999 | None |
+-------+-----------+--------+-------+-------+
Column predictions:
+-------+-----------+--------+-------+-------+
| Class | Precision | Recall | F1 | Count |
+=======+===========+========+=======+=======+
| 1 | 1 | 0.999 | 0.999 | 1692 |
+-------+-----------+--------+-------+-------+
| 2 | 0.996 | 1 | 0.998 | 456 |
+-------+-----------+--------+-------+-------+
| AVG | 0.999 | 0.999 | 0.999 | None |
+-------+-----------+--------+-------+-------+
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
from time import time
from typing import List

import numpy as np
import torch
from sklearn.metrics import precision_recall_fscore_support
from texttable import Texttable
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from dedoc.config import get_config
from dedoc.readers.pdf_reader.pdf_image_reader.columns_orientation_classifier.columns_orientation_classifier import ColumnsOrientationClassifier
Expand All @@ -16,19 +20,27 @@
checkpoint_path_save = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "resources", "efficient_net_b0_fixed.pth"))
checkpoint_path_load = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "resources", "efficient_net_b0_fixed.pth"))
checkpoint_path = "../../resources"
output_dir = os.path.abspath(os.path.join(checkpoint_path, "benchmarks"))

parser.add_argument("-t", "--train", type=bool, help="run for train model", default=False)
parser.add_argument("-s", "--checkpoint_save", help="Path to checkpoint for save or load", default=checkpoint_path_save)
parser.add_argument("-l", "--checkpoint_load", help="Path to checkpoint for load", default=checkpoint_path_load)
parser.add_argument("-f", "--from_checkpoint", type=bool, help="run for train model", default=True)
parser.add_argument("-d", "--input_data_folder", help="Path to data with folders train or test")
parser.add_argument("-d", "--input_data_folder", help="Path to data with folders train or test",
default="/home/ox/work/datasets/generate_dataset_orient_cls/generate_dataset_orient_classifier")

args = parser.parse_args()
BATCH_SIZE = 1
ON_GPU = False
ON_GPU = True

"""
Input data are available from our confluence (closed data).
First, you need generate full train/test data (all orientation of src documents) using scripts/gen_dataset.py
Then, you can use this script.
"""

def accuracy_step(data_executor: DataLoaderImageOrient, net_executor: ColumnsOrientationClassifier) -> None:

def evaluation_step(data_executor: DataLoaderImageOrient, net_executor: ColumnsOrientationClassifier) -> None:
"""
Function calculates accuracy for the trained model
:param data_executor: Extractor Data from path
Expand All @@ -47,10 +59,22 @@ def accuracy_step(data_executor: DataLoaderImageOrient, net_executor: ColumnsOri

print(f"GroundTruth: orientation {orientation}, columns {columns}")

calc_accuracy_by_classes(testloader, data_executor.classes, net_executor, batch_size=1)
evaluation(testloader, data_executor.classes, net_executor)


def print_metrics(precision: np.array, recall: np.array, f1: np.array, cnt: np.array, avg: np.array, classes: List[str]) -> Texttable:
table = Texttable()

table.header(["Class", "Precision", "Recall", "F1", "Count"])
for i, name_class in enumerate(classes):
table.add_row([name_class, precision[i], recall[i], f1[i], cnt[i]])

table.add_row(["AVG", avg[0], avg[1], avg[2], "None"])

def calc_accuracy_by_classes(testloader: DataLoader, classes: List, classifier: ColumnsOrientationClassifier, batch_size: int = 1) -> None:
return table


def evaluation(testloader: DataLoader, classes: List, classifier: ColumnsOrientationClassifier) -> None:
"""
Function calculates accuracy ba each class
:param testloader: DataLoader
Expand All @@ -59,43 +83,47 @@ def calc_accuracy_by_classes(testloader: DataLoader, classes: List, classifier:
:param batch_size: size of batch
:return:
"""
class_correct = list(0. for _ in range(len(classes)))
class_total = list(0. for _ in range(len(classes)))
orientation_pred, orientation_true = [], []
column_pred, column_true = [], []

time_predict = 0
cnt_predict = 0
with torch.no_grad():
for data in testloader:
for data in tqdm(testloader):
images, orientation, columns = data["image"], data["orientation"], data["columns"]
time_begin = time()

time_begin = time()
outputs = classifier.net(images.float().to(classifier.device))
time_predict += time() - time_begin
cnt_predict += len(images)

# first 2 classes mean columns number
# last 4 classes mean orientation
columns_out, orientation_out = outputs[:, :2], outputs[:, 2:]
_, columns_predicted = torch.max(columns_out, 1)
_, orientation_predicted = torch.max(orientation_out, 1)

orientation_c = (orientation_predicted == orientation.to(classifier.device)).squeeze()
columns_c = (columns_predicted == columns.to(classifier.device)).squeeze()

for i in range(batch_size):
orientation_i = orientation[i]
columns_i = columns[i]
orientation_bool_predict = orientation_c.item() if batch_size == 1 else orientation_c[i].item()
columns_bool_predict = columns_c.item() if batch_size == 1 else columns_c[i].item()
class_correct[2 + orientation_i] += orientation_bool_predict
class_total[2 + orientation_i] += 1
class_correct[columns_i] += orientation_bool_predict
class_total[columns_i] += 1
if not orientation_bool_predict or not columns_bool_predict:
print(
f'{data["image_name"][i]} predict as \norientation: {classes[2 + orientation_predicted[i]]} \ncolumns: {classes[columns_predicted[i]]}'
)

for i in range(len(classes)):
print(f"Accuracy of {classes[i]:5s} : {100 * class_correct[i] / class_total[i] if class_total[i] != 0 else 0:2d} %")
orientation_pred.append(classes[2 + orientation_predicted.squeeze().item()])
orientation_true.append(classes[2 + orientation.to(classifier.device).squeeze().item()])

column_pred.append(classes[columns_predicted.squeeze().item()])
column_true.append(classes[columns.to(classifier.device).squeeze().item()])

with open(os.path.join(output_dir, "orient_classifier_scores.txt"), "w") as benchmark_file:
orient_metrics = precision_recall_fscore_support(orientation_true, orientation_pred, average=None, labels=classes[2:])
orient_avg = precision_recall_fscore_support(orientation_true, orientation_pred, average="weighted")
table = print_metrics(*orient_metrics, orient_avg, classes[2:])
print(table.draw())
benchmark_file.write("\nOrientation predictions:\n")
benchmark_file.write(table.draw())

column_metrics = precision_recall_fscore_support(column_true, column_pred, average=None, labels=classes[:2])
column_avg = precision_recall_fscore_support(column_true, column_pred, average="weighted")
table = print_metrics(*column_metrics, column_avg, classes[:2])
print(table.draw())
benchmark_file.write("\nColumn predictions:\n")
benchmark_file.write(table.draw())

print(f"=== AVG Time predict {time_predict / cnt_predict}")


Expand Down Expand Up @@ -171,4 +199,4 @@ def train_step(data_executor: DataLoaderImageOrient, classifier: ColumnsOrientat
if args.train:
train_step(data_executor, net)
else:
accuracy_step(data_executor, net)
evaluation_step(data_executor, net)

0 comments on commit 3984f2d

Please sign in to comment.