Skip to content

Commit

Permalink
better syntax and simplified code
Browse files Browse the repository at this point in the history
  • Loading branch information
adbar committed Jun 17, 2024
1 parent 9508f09 commit dd07aff
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions py3langid/langid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from base64 import b64decode
from collections import Counter
from operator import itemgetter
from pathlib import Path
from urllib.parse import parse_qs

Expand All @@ -33,6 +34,9 @@
# affect the relative ordering of the predicted classes. It can be
# re-enabled at runtime - see the readme.

# quantization: faster but less precise
DATATYPE = "uint16"


def load_model(path=None):
"""
Expand Down Expand Up @@ -60,7 +64,7 @@ def set_languages(langs=None):
return IDENTIFIER.set_languages(langs)


def classify(instance, datatype='uint16'):
def classify(instance, datatype=DATATYPE):
"""
Convenience method using a global identifier instance with the default
model included in langid.py. Identifies the language that a string is
Expand Down Expand Up @@ -214,7 +218,7 @@ def set_languages(self, langs=None):
self.nb_ptc = nb_ptc[:, subset_mask]
self.nb_pc = nb_pc[subset_mask]

def instance2fv(self, text, datatype='uint16'):
def instance2fv(self, text, datatype=DATATYPE):
"""
Map an instance into the feature space of the trained model.
Expand All @@ -227,11 +231,12 @@ def instance2fv(self, text, datatype='uint16'):

# Convert the text to a sequence of ascii values and
# count the number of times we enter each state
state = 0
indexes = []
for letter in list(text):
state, indexes = 0, []
extend = indexes.extend

for letter in text:
state = self.tk_nextmove[(state << 8) + letter]
indexes.extend(self.tk_output.get(state, []))
extend(self.tk_output.get(state, []))

# datatype: consider that fewer feature counts are going to be needed
arr = np.zeros(self.nb_numfeats, dtype=datatype)
Expand All @@ -247,7 +252,7 @@ def nb_classprobs(self, fv):
# compute the partial log-probability of the document in each class
return pdc + self.nb_pc

def classify(self, text, datatype='uint16'):
def classify(self, text, datatype=DATATYPE):
"""
Classify an instance.
"""
Expand All @@ -262,7 +267,7 @@ def rank(self, text):
"""
fv = self.instance2fv(text)
probs = self.norm_probs(self.nb_classprobs(fv))
return [(str(k), float(v)) for (v, k) in sorted(zip(probs, self.nb_classes), reverse=True)]
return sorted(zip(self.nb_classes, probs), key=itemgetter(1), reverse=True)

def cl_path(self, path):
"""
Expand Down

0 comments on commit dd07aff

Please sign in to comment.