Skip to content

Commit

Permalink
Merge pull request #10 from R1j1t/dev
Browse files Browse the repository at this point in the history
updated extension data
  • Loading branch information
R1j1t authored Jun 13, 2020
2 parents 999e2d1 + 42ce2c7 commit 3499393
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 33 deletions.
87 changes: 69 additions & 18 deletions contextualSpellCheck/contextualSpellCheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import editdistance
import datetime
import os
import copy
import warnings

from spacy.tokens import Doc, Token, Span
from spacy.vocab import Vocab
Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(self, vocab_path="", debug=False, performance=False):
# {originalToken-1:[suggestedToken-1,suggestedToken-2,..],
# originalToken-2:[...]}
Doc.set_extension("suggestions_spellCheck", default={})
Doc.set_extension("outcome_spellCheck", getter=self.doc_outcome_spellCheck)
Doc.set_extension("outcome_spellCheck", default="")
Doc.set_extension("score_spellCheck", default=None)

Span.set_extension(
Expand Down Expand Up @@ -71,10 +73,10 @@ def __call__(self, doc):
if self.performance:
modelLoadTime = self.timeLog("Misspell identification: ", modelLodaded)
if len(misspellTokens) > 0:
candidate = self.candidateGenerator(doc, misspellTokens)
doc, candidate = self.candidateGenerator(doc, misspellTokens)
if self.performance:
modelLoadTime = self.timeLog("candidate Generator: ", modelLodaded)
answer = self.candidateRanking(candidate)
answer = self.candidateRanking(doc, candidate)
if self.performance:
modelLoadTime = self.timeLog("candidate ranking: ", modelLodaded)
return doc
Expand Down Expand Up @@ -132,9 +134,11 @@ def misspellIdentify(self, doc, query=""):
`tuple` -- returns `List[`Spacy.Token`]` and `Spacy.Doc`
"""

docCopy = copy.deepcopy(doc)

# doc = self.nlp(query)
misspell = []
for token in doc:
for token in docCopy:
if (
(token.text.lower() not in self.vocab)
and (token.ent_type_ != "PERSON")
Expand All @@ -150,7 +154,7 @@ def misspellIdentify(self, doc, query=""):
misspell.append(token)

if self.debug:
print(misspell)
print("misspell identified: ", misspell)
return (misspell, doc)

def candidateGenerator(self, doc, misspellings, top_n=10):
Expand Down Expand Up @@ -223,16 +227,19 @@ def candidateGenerator(self, doc, misspellings, top_n=10):

if self.debug:
print(
"response[token]: ", response[token], "score[token]: ", score[token]
"response[" + "`" + str(token) + "`" + "]: ",
response[token],
"score[" + "`" + str(token) + "`" + "]: ",
score[token],
)

if len(misspellings) != 0:
doc._.set("performed_spellCheck", True)
doc._.set("score_spellCheck", score)

return response
return (doc, response)

def candidateRanking(self, misspellingsDict):
def candidateRanking(self, doc, misspellingsDict):
"""Ranking the candidates based on edit Distance
At present using a library to calculate edit distance
Expand Down Expand Up @@ -265,10 +272,23 @@ def candidateRanking(self, misspellingsDict):
tempToken = misspell

if self.debug:
print("response[misspell]", response[misspell])
print("response[" + "`" + str(misspell) + "`" + "]", response[misspell])

if len(response) > 0:
tempToken.doc._.set("suggestions_spellCheck", response)
doc._.set("suggestions_spellCheck", response)
updatedQuery = ""
for i in doc:
updatedToken = i.text_with_ws
for misspell in response.keys():
if i.i == misspell.i:
updatedToken = response[misspell] + misspell.whitespace_
break
updatedQuery += updatedToken
doc._.set("outcome_spellCheck", updatedQuery)

if self.debug:
print("Final suggestions", doc._.suggestions_spellCheck)

return response

def timeLog(self, fnName, relativeTime):
Expand Down Expand Up @@ -297,7 +317,7 @@ def token_require_spellCheck(self, token):
"""
return any(
[
token.i == suggestion.i
token.i == suggestion.i and token.text == suggestion.text
for suggestion in token.doc._.suggestions_spellCheck.keys()
]
)
Expand All @@ -313,7 +333,12 @@ def token_suggestion_spellCheck(self, token):
"""
for suggestion in token.doc._.suggestions_spellCheck.keys():
if token.i == suggestion.i:
return token.doc._.suggestions_spellCheck[token]
if token.text_with_ws == suggestion.text_with_ws:
return token.doc._.suggestions_spellCheck[suggestion]
else:
warnings.warn(
"Position of tokens modified by downstream element in pipeline eg. merge_entities"
)
return ""

def token_score_spellCheck(self, token):
Expand All @@ -329,7 +354,12 @@ def token_score_spellCheck(self, token):
return []
for suggestion in token.doc._.score_spellCheck.keys():
if token.i == suggestion.i:
return token.doc._.score_spellCheck[token]
if token.text == suggestion.text:
return token.doc._.score_spellCheck[suggestion]
else:
warnings.warn(
"Position of tokens modified by downstream element in pipeline eg. merge_entities"
)
return []

def span_score_spellCheck(self, span):
Expand Down Expand Up @@ -389,10 +419,14 @@ def doc_outcome_spellCheck(self, doc):
suggestions = doc._.suggestions_spellCheck

for i in doc:
if i.i in [misspell.i for misspell in suggestions.keys()]:
updatedQuery += suggestions[i] + i.whitespace_
else:
updatedQuery += i.text_with_ws
updatedToken = i.text_with_ws
for misspell in suggestions.keys():
if misspell.text_with_ws in i.text_with_ws:
updatedToken = suggestions[misspell] + misspell.whitespace_
suggestions.remove(misspell)
break

updatedQuery += updatedToken

if self.debug:
print("Did you mean: ", updatedQuery)
Expand All @@ -403,10 +437,13 @@ def doc_outcome_spellCheck(self, doc):
if __name__ == "__main__":
print("Code running...")
nlp = spacy.load("en_core_web_sm")
# for issue #1
# merge_ents = nlp.create_pipe("merge_entities")
if "parser" not in nlp.pipe_names:
raise AttributeError("parser is required please enable it in nlp pipeline")
checker = ContextualSpellCheck(debug=False)
checker = ContextualSpellCheck(debug=True)
nlp.add_pipe(checker)
# nlp.add_pipe(merge_ents)

doc = nlp(u"Income was $9.4 milion compared to the prior year of $2.7 milion.")

Expand All @@ -417,3 +454,17 @@ def doc_outcome_spellCheck(self, doc):
print(doc._.performed_spellCheck)
print(doc._.suggestions_spellCheck)
print(doc._.score_spellCheck)

token_pos = 4
print("=" * 20, "Token Extention Test", "=" * 20)
print(doc[token_pos].text, doc[token_pos].i)
print(doc[token_pos]._.get_require_spellCheck)
print(doc[token_pos]._.get_suggestion_spellCheck)
print(doc[token_pos]._.score_spellCheck)

span_start = token_pos - 2
span_end = token_pos + 2
print("=" * 20, "Span Extention Test", "=" * 20)
print(doc[span_start:span_end].text)
print(doc[span_start:span_end]._.get_has_spellCheck)
print(doc[span_start:span_end]._.score_spellCheck)
93 changes: 78 additions & 15 deletions contextualSpellCheck/tests/test_contextualSpellCheck.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import spacy
from pytest import approx
import warnings

from contextualSpellCheck.contextualSpellCheck import ContextualSpellCheck

Expand Down Expand Up @@ -49,7 +50,13 @@ def test_type_misspellIdentify(inputSentence, misspell):
def test_identify_misspellIdentify(inputSentence, misspell):
    """misspellIdentify must flag exactly the tokens at the expected indices."""
    print("Start misspell word identifation test\n")
    doc = nlp(inputSentence)
    flagged = checker.misspellIdentify(doc)[0]
    assert type(flagged) == list
    # Changed approach after v0.1.0: compare token text and position
    # instead of the token objects themselves.
    expected = [doc[i] for i in misspell]
    assert [tok.text_with_ws for tok in flagged] == [
        tok.text_with_ws for tok in expected
    ]
    assert [tok.i for tok in flagged] == [tok.i for tok in expected]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -118,7 +125,9 @@ def test_skipURL_misspellIdentify(inputSentence, misspell):
def test_type_candidateGenerator(inputSentence, misspell):
    """candidateGenerator must return a (Doc, dict) tuple.

    Improvement over the previous version: the (model-backed, expensive)
    generator is invoked once and its single result is inspected, instead
    of being re-run for every type assertion.
    """
    doc = nlp(inputSentence)
    misspell, doc = checker.misspellIdentify(doc)
    result = checker.candidateGenerator(doc, misspell)
    assert type(result) == tuple
    new_doc, suggestions = result
    # First element is the (possibly updated) Doc, second the suggestion map.
    assert type(new_doc) == type(doc)
    assert type(suggestions) == dict


@pytest.mark.parametrize(
Expand Down Expand Up @@ -176,9 +185,15 @@ def test_identify_candidateGenerator(inputSentence, misspell):
print("Start misspell word identifation test\n")
doc = nlp(inputSentence)
(misspellings, doc) = checker.misspellIdentify(doc)
suggestions = checker.candidateGenerator(doc, misspellings)
gold_suggestions = {doc[key]: value for key, value in misspell.items()}
assert suggestions == gold_suggestions
doc, suggestions = checker.candidateGenerator(doc, misspellings)
## changed after v1.0 because of deepCopy creating issue with ==
# gold_suggestions = {doc[key]: value for key, value in misspell.items()}
assert [tok.i for tok in suggestions] == [key for key in misspell.keys()]
assert [suggString for suggString in suggestions.values()] == [
suggString for suggString in misspell.values()
]

# assert suggestions == gold_suggestions


@pytest.mark.parametrize(
Expand Down Expand Up @@ -249,11 +264,12 @@ def test_extension_candidateGenerator(inputSentence, misspell):
def test_extension2_candidateGenerator(inputSentence, misspell):
doc = nlp(inputSentence)
(misspellings, doc) = checker.misspellIdentify(doc)
suggestions = checker.candidateGenerator(doc, misspellings)
assert (
doc._.score_spellCheck.keys()
== {doc[key]: value for key, value in misspell.items()}.keys()
)
doc, suggestions = checker.candidateGenerator(doc, misspellings)

## changes after v0.1.0
assert [tokIndex.i for tokIndex in doc._.score_spellCheck.keys()] == [
tokIndex for tokIndex in misspell.keys()
]
assert [
word_score[0]
for value in doc._.score_spellCheck.values()
Expand Down Expand Up @@ -283,9 +299,14 @@ def test_extension2_candidateGenerator(inputSentence, misspell):
def test_ranking_candidateRanking(inputSentence, misspell):
    """candidateRanking should select the expected word for each misspell."""
    doc = nlp(inputSentence)
    (misspellings, doc) = checker.misspellIdentify(doc)
    doc, suggestions = checker.candidateGenerator(doc, misspellings)
    ranked = checker.candidateRanking(doc, suggestions)
    # Changes made after v0.1: token objects are no longer compared
    # directly; check token positions and chosen strings instead.
    # `ranked` maps Token -> chosen string; `misspell` maps index -> string.
    assert [token.i for token in ranked] == list(misspell.keys())
    assert list(ranked.values()) == list(misspell.values())


def test_compatible_spacyPipeline():
Expand Down Expand Up @@ -333,10 +354,22 @@ def test_doc_extensions():
}
assert doc._.contextual_spellCheck == True
assert doc._.performed_spellCheck == True
assert doc._.suggestions_spellCheck == gold_suggestion
## updated after v0.1
assert [tok.i for tok in doc._.suggestions_spellCheck.keys()] == [
tok.i for tok in gold_suggestion.keys()
]
assert [tokString for tokString in doc._.suggestions_spellCheck.values()] == [
tokString for tokString in gold_suggestion.values()
]
assert doc._.outcome_spellCheck == gold_outcome
# splitting components to make use of approx function
assert doc._.score_spellCheck.keys() == gold_score.keys()
assert [tok.i for tok in doc._.score_spellCheck.keys()] == [
tok.i for tok in gold_score.keys()
]
assert [tok.text_with_ws for tok in doc._.score_spellCheck.keys()] == [
tok.text_with_ws for tok in gold_score.keys()
]

assert [
word_score[0]
for value in doc._.score_spellCheck.values()
Expand Down Expand Up @@ -432,3 +465,33 @@ def test_token_extension():
[word_score[1] for word_score in gold_score], rel=1e-4, abs=1e-4
)
nlp.remove_pipe("contextual spellchecker")


def test_worning():
    """After merge_entities changes token positions, the token-level
    extensions should fall back to their defaults and emit a UserWarning.

    NOTE(review): the function name carries a typo ("worning"); it is kept
    so pytest keeps collecting the same test id.
    """
    if "contextual spellchecker" not in nlp.pipe_names:
        nlp.add_pipe(checker)
    merger = nlp.create_pipe("merge_entities")
    nlp.add_pipe(merger)
    doc = nlp("Income was $9.4 milion compared to the prior year of $2.7 milion.")

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning raised below

        # Token positions have shifted, so extension lookups return defaults...
        assert doc[4]._.get_require_spellCheck == False
        assert doc[4]._.get_suggestion_spellCheck == ""
        assert doc[4]._.score_spellCheck == []
        # ...and the mismatch is reported via a UserWarning.
        assert issubclass(caught[-1].category, UserWarning)
        assert (
            "Position of tokens modified by downstream element in pipeline eg. merge_entities"
            in str(caught[-1].message)
        )

    nlp.remove_pipe("contextual spellchecker")
    print(nlp.pipe_names)

    nlp.remove_pipe("merge_entities")
    print(nlp.pipe_names)
    warnings.simplefilter("default")

0 comments on commit 3499393

Please sign in to comment.