Skip to content

Commit

Permalink
Add single language mode statistics to scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
pemistahl committed Sep 20, 2024
1 parent ef426d7 commit 3a431e0
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
31 changes: 25 additions & 6 deletions scripts/accuracy_plot_drawer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import pandas as pd
import seaborn as sns
from lingua import Language
from math import floor
from matplotlib.patches import Patch
from pathlib import Path
Expand All @@ -43,11 +44,12 @@ class AccuracyPlotDrawer:
"fasttext": "FastText 0.9.2",
"langdetect": "Langdetect 1.0.9",
"langid": "Langid 1.1.6",
"lingua-low-accuracy": "Lingua 1.3.5\nlow accuracy mode",
"lingua-high-accuracy": "Lingua 1.3.5\nhigh accuracy mode",
"lingua-low-accuracy": "Lingua 1.4.0\nlow accuracy mode",
"lingua-high-accuracy": "Lingua 1.4.0\nhigh accuracy mode",
"lingua-single-language-detector": "Lingua 1.4.0\nsingle language mode",
"simplemma": "Simplemma 0.9.1",
}
_hatches = ("|", "-", "/", "x", "+", "\\", "o", ".", "*", "O")
_hatches = ("|", "-", "/", "x", "+", "\\", "o", "oo", ".", "*", "O")
_palette = (
"#39d7e6",
"#6bbcff",
Expand All @@ -57,6 +59,7 @@ class AccuracyPlotDrawer:
"#ff8800",
"#ffb866",
"#ffc400",
"#fff480",
"#8edca7",
"#41c46b",
)
Expand All @@ -68,8 +71,22 @@ def __init__(self, plot_title: str, report_file_path: Path):

def _read_into_dataframe(self, report_file_path: Path) -> pd.DataFrame:
df = pd.read_csv(report_file_path, index_col="language")

single_language_mode_columns = [
f"lingua-{language.name.lower()}-detector" for language in Language
]
merged_single_language_mode_column = {
"lingua-single-language-detector": df[single_language_mode_columns].sum(
axis="columns"
)
}
df = df.assign(**merged_single_language_mode_column).drop(
single_language_mode_columns, axis="columns"
)

# Sort classifier columns by their mean value
df = df.reindex(df.mean().sort_values().index, axis=1)
df = df.reindex(df.mean().sort_values().index, axis="columns")

return pd.melt(
frame=df.reset_index(),
id_vars="language",
Expand All @@ -78,10 +95,12 @@ def _read_into_dataframe(self, report_file_path: Path) -> pd.DataFrame:
)

def draw_barplot(self, file_path: Path):
row_filter = self._dataframe[self._hue].isin(self._column_labels.keys())
column_labels = self._column_labels.copy()
del column_labels["lingua-single-language-detector"]
row_filter = self._dataframe[self._hue].isin(column_labels.keys())
data = self._dataframe[row_filter]
classifiers = data[self._hue].unique()
labels = [self._column_labels[classifier] for classifier in classifiers]
labels = [column_labels[classifier] for classifier in classifiers]
handles = [
Patch(facecolor=color, edgecolor="black", label=label, hatch=hatch)
for color, label, hatch in zip(self._palette, labels, self._hatches)
Expand Down
18 changes: 17 additions & 1 deletion scripts/accuracy_table_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import math
import pandas as pd

from lingua import Language
from pathlib import Path


Expand All @@ -30,6 +31,7 @@ class AccuracyTableWriter:
"langid": "Langid",
"lingua-low-accuracy": "Lingua<br>(low accuracy mode)",
"lingua-high-accuracy": "Lingua<br>(high accuracy mode)",
"lingua-single-language-detector": "Lingua<br>(single language mode)",
"simplemma": "Simplemma",
}

Expand Down Expand Up @@ -104,7 +106,21 @@ def write_accuracy_table(self, file_path: Path):
accuracy_table_file.write(table)

def _read_into_dataframe(self, report_file_path: Path) -> pd.DataFrame:
return pd.read_csv(report_file_path, index_col="language")
df = pd.read_csv(report_file_path, index_col="language")

single_language_mode_columns = [
f"lingua-{language.name.lower()}-detector" for language in Language
]
merged_single_language_mode_column = {
"lingua-single-language-detector": df[single_language_mode_columns].sum(
axis="columns"
)
}
df = df.assign(**merged_single_language_mode_column).drop(
single_language_mode_columns, axis="columns"
)

return df

def _get_square_color(self, accuracy_value: float) -> str:
if math.isnan(accuracy_value):
Expand Down

0 comments on commit 3a431e0

Please sign in to comment.