Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alex lemmatizer classifier 2 #1422

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions stanza/models/lemma/attach_lemma_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import argparse

from stanza.models.lemma.trainer import Trainer
from stanza.models.lemma_classifier.base_model import LemmaClassifier

def attach_classifier(input_filename, output_filename, classifiers):
trainer = Trainer(model_file=input_filename)

for classifier in classifiers:
classifier = LemmaClassifier.load(classifier)
trainer.contextual_lemmatizers.append(classifier)

trainer.save(output_filename)

def main(args=None):
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, required=True, help='Which lemmatizer to start from')
parser.add_argument('--output', type=str, required=True, help='Where to save the lemmatizer')
parser.add_argument('--classifier', type=str, required=True, nargs='+', help='Lemma classifier to attach')
args = parser.parse_args(args)

attach_classifier(args.input, args.output, args.classifier)

if __name__ == '__main__':
main()
10 changes: 9 additions & 1 deletion stanza/models/lemma/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from stanza.models.common import utils, loss
from stanza.models.lemma import edit
from stanza.models.lemma.vocab import MultiVocab
from stanza.models.lemma_classifier.base_model import LemmaClassifier

logger = logging.getLogger('stanza')

Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, devi
# dict-based components
self.word_dict = dict()
self.composite_dict = dict()
self.contextual_lemmatizers = []

self.caseless = self.args.get('caseless', False)

Expand Down Expand Up @@ -228,8 +230,11 @@ def save(self, filename, skip_modules=True):
'model': model_state,
'dicts': (self.word_dict, self.composite_dict),
'vocab': self.vocab.state_dict(),
'config': self.args
'config': self.args,
'contextual': [],
}
for contextual in self.contextual_lemmatizers:
params['contextual'].append(contextual.get_save_dict())
os.makedirs(os.path.split(filename)[0], exist_ok=True)
torch.save(params, filename, _use_new_zipfile_serialization=False)
logger.info("Model saved to {}".format(filename))
Expand All @@ -253,3 +258,6 @@ def load(self, filename, args, foundation_cache):
else:
self.model = None
self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
self.contextual_lemmatizers = []
for contextual in checkpoint.get('contextual', []):
self.contextual_lemmatizers.append(LemmaClassifier.from_checkpoint(contextual))
Empty file.
120 changes: 120 additions & 0 deletions stanza/models/lemma_classifier/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""
Base class for the LemmaClassifier types.

Versions include LSTM and Transformer varieties
"""

import logging

from abc import ABC, abstractmethod

import os

import torch
import torch.nn as nn

from stanza.models.common.foundation_cache import load_pretrain
from stanza.models.lemma_classifier.constants import ModelType

logger = logging.getLogger('stanza.lemmaclassifier')

class LemmaClassifier(ABC, nn.Module):
def __init__(self, label_decoder, target_words, *args, **kwargs):
super().__init__(*args, **kwargs)

self.label_decoder = label_decoder
self.target_words = target_words
self.unsaved_modules = []

def add_unsaved_module(self, name, module):
self.unsaved_modules += [name]
setattr(self, name, module)

def is_unsaved_module(self, name):
return name.split('.')[0] in self.unsaved_modules

def save(self, save_name):
"""
Save the model to the given path, possibly with some args
"""
save_dir = os.path.split(save_name)[0]
if save_dir:
os.makedirs(save_dir, exist_ok=True)
save_dict = self.get_save_dict()
torch.save(save_dict, save_name)
return save_dict

@abstractmethod
def model_type(self):
"""
return a ModelType
"""

def target_indices(self, sentence):
return [idx for idx, word in enumerate(sentence) if word.lower() in self.target_words]

@staticmethod
def from_checkpoint(checkpoint, args=None):
model_type = checkpoint['model_type']
if model_type is ModelType.LSTM:
# TODO: if anyone can suggest a way to avoid this circular import
# (or better yet, avoid the load method knowing about subclasses)
# please do so
# maybe the subclassing is not necessary and we just put
# save & load in the trainer
from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM

saved_args = checkpoint['args']
# other model args are part of the model and cannot be changed for evaluation or pipeline
# the file paths might be relevant, though
keep_args = ['wordvec_pretrain_file', 'charlm_forward_file', 'charlm_backward_file']
for arg in keep_args:
if args is not None and args.get(arg, None) is not None:
saved_args[arg] = args[arg]

# TODO: refactor loading the pretrain (also done in the trainer)
pt = load_pretrain(saved_args['wordvec_pretrain_file'])

use_charlm = saved_args['use_charlm']
charlm_forward_file = saved_args.get('charlm_forward_file', None)
charlm_backward_file = saved_args.get('charlm_backward_file', None)

model = LemmaClassifierLSTM(model_args=saved_args,
output_dim=len(checkpoint['label_decoder']),
pt_embedding=pt,
label_decoder=checkpoint['label_decoder'],
upos_to_id=checkpoint['upos_to_id'],
known_words=checkpoint['known_words'],
target_words=checkpoint['target_words'],
use_charlm=use_charlm,
charlm_forward_file=charlm_forward_file,
charlm_backward_file=charlm_backward_file)
elif model_type is ModelType.TRANSFORMER:
from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer

output_dim = len(checkpoint['label_decoder'])
saved_args = checkpoint['args']
bert_model = saved_args['bert_model']
model = LemmaClassifierWithTransformer(model_args=saved_args,
output_dim=output_dim,
transformer_name=bert_model,
label_decoder=checkpoint['label_decoder'],
target_words=checkpoint['target_words'])
else:
raise ValueError("Unknown model type %s" % model_type)

# strict=False to accommodate missing parameters from the transformer or charlm
model.load_state_dict(checkpoint['params'], strict=False)
return model

@staticmethod
def load(filename, args=None):
try:
checkpoint = torch.load(filename, lambda storage, loc: storage)
except BaseException:
logger.exception("Cannot load model from %s", filename)
raise

logger.debug("Loading LemmaClassifier model from %s", filename)

return LemmaClassifier.from_checkpoint(checkpoint)
114 changes: 114 additions & 0 deletions stanza/models/lemma_classifier/base_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@

from abc import ABC, abstractmethod
import logging
import os
from typing import List, Tuple, Any, Mapping

import torch
import torch.nn as nn
import torch.optim as optim

from stanza.models.common.utils import default_device
from stanza.models.lemma_classifier import utils
from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
from stanza.models.lemma_classifier.evaluate_models import evaluate_model
from stanza.utils.get_tqdm import get_tqdm

tqdm = get_tqdm()
logger = logging.getLogger('stanza.lemmaclassifier')

class BaseLemmaClassifierTrainer(ABC):
def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
"""
If applicable, this function will update the loss function of the LemmaClassifierLSTM model to become BCEWithLogitsLoss.
The weights are determined by the counts of the classes in the dataset. The weights are inversely proportional to the
frequency of the class in the set. E.g. classes with lower frequency will have higher weight.
"""
weights = [0 for _ in label_decoder.keys()] # each key in the label decoder is one class, we have one weight per class
total_samples = sum(counts.values())
for class_idx in counts:
weights[class_idx] = total_samples / (counts[class_idx] * len(counts)) # weight_i = total / (# examples in class i * num classes)
weights = torch.tensor(weights)
logger.info(f"Using weights {weights} for weighted loss.")
self.criterion = nn.BCEWithLogitsLoss(weight=weights)

@abstractmethod
def build_model(self, label_decoder, upos_to_id, known_words):
"""
Build a model using pieces of the dataset to determine some of the model shape
"""

def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, train_file: str) -> None:
"""
Trains a model on batches of texts, position indices of the target token, and labels (lemma annotation) for the target token.

Args:
num_epochs (int): Number of training epochs
save_name (str): Path to file where trained model should be saved.
eval_file (str): Path to the dev set file for evaluating model checkpoints each epoch.
train_file (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.
"""
# Put model on GPU (if possible)
device = default_device()

if not train_file:
raise ValueError("Cannot train model - no train_file supplied!")

dataset = utils.Dataset(train_file, get_counts=self.weighted_loss, batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE))
label_decoder = dataset.label_decoder
upos_to_id = dataset.upos_to_id
self.output_dim = len(label_decoder)
logger.info(f"Loaded dataset successfully from {train_file}")
logger.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}")
logger.info(f"Target words: {dataset.target_words}")

self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words, dataset.target_words)
self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

self.model.to(device)
logger.info(f"Training model on device: {device}. {next(self.model.parameters()).device}")

if os.path.exists(save_name) and not args.get('force', False):
raise FileExistsError(f"Save name {save_name} already exists; training would overwrite previous file contents. Aborting...")

if self.weighted_loss:
self.configure_weighted_loss(label_decoder, dataset.counts)

# Put the criterion on GPU too
logger.debug(f"Criterion on {next(self.model.parameters()).device}")
self.criterion = self.criterion.to(next(self.model.parameters()).device)

best_model, best_f1 = None, float("-inf") # Used for saving checkpoints of the model
for epoch in range(num_epochs):
# go over entire dataset with each epoch
for sentences, positions, upos_tags, labels in tqdm(dataset):
assert len(sentences) == len(positions) == len(labels), f"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})"

self.optimizer.zero_grad()
outputs = self.model(positions, sentences, upos_tags)

# Compute loss, which is different if using CE or BCEWithLogitsLoss
if self.weighted_loss: # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others.
# TODO: three classes?
targets = torch.stack([torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device)
# should be shape size (batch_size, 2)
else: # CELoss accepts target as just raw label
targets = labels.to(device)

loss = self.criterion(outputs, targets)

loss.backward()
self.optimizer.step()

logger.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}")
if eval_file:
# Evaluate model on dev set to see if it should be saved.
_, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True)
logger.info(f"Weighted f1 for model: {f1}")
if f1 > best_f1:
best_f1 = f1
self.model.save(save_name)
logger.info(f"New best model: weighted f1 score of {f1}.")
else:
self.model.save(save_name)

54 changes: 54 additions & 0 deletions stanza/models/lemma_classifier/baseline_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Baseline model for the existing lemmatizer which always predicts "be" and never "have" on the "'s" token.

The BaselineModel class can be updated to any arbitrary token and predicton lemma, not just "be" on the "s" token.
"""

import stanza
import os
from stanza.models.lemma_classifier.evaluate_models import evaluate_sequences
from stanza.models.lemma_classifier.prepare_dataset import load_doc_from_conll_file

class BaselineModel:

def __init__(self, token_to_lemmatize, prediction_lemma, prediction_upos):
self.token_to_lemmatize = token_to_lemmatize
self.prediction_lemma = prediction_lemma
self.prediction_upos = prediction_upos

def predict(self, token):
if token == self.token_to_lemmatize:
return self.prediction_lemma

def evaluate(self, conll_path):
"""
Evaluates the baseline model against the test set defined in conll_path.

Returns a map where the keys are each class and the values are another map including the precision, recall and f1 scores
for that class.

Also returns confusion matrix. Keys are gold tags and inner keys are predicted tags
"""
doc = load_doc_from_conll_file(conll_path)
gold_tag_sequences, pred_tag_sequences = [], []
for sentence in doc.sentences:
gold_tags, pred_tags = [], []
for word in sentence.words:
if word.upos in self.prediction_upos and word.text == self.token_to_lemmatize:
pred = self.prediction_lemma
gold = word.lemma
gold_tags.append(gold)
pred_tags.append(pred)
gold_tag_sequences.append(gold_tags)
pred_tag_sequences.append(pred_tags)

multiclass_result, confusion_mtx, weighted_f1 = evaluate_sequences(gold_tag_sequences, pred_tag_sequences)
return multiclass_result, confusion_mtx


if __name__ == "__main__":

bl_model = BaselineModel("'s", "be", ["AUX"])
coNLL_path = os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu")
bl_model.evaluate(coNLL_path)

14 changes: 14 additions & 0 deletions stanza/models/lemma_classifier/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from enum import Enum

UNKNOWN_TOKEN = "unk" # token name for unknown tokens
UNKNOWN_TOKEN_IDX = -1 # custom index we apply to unknown tokens

# TODO: ModelType could just be LSTM and TRANSFORMER
# and then the transformer baseline would have the transformer as another argument
class ModelType(Enum):
LSTM = 1
TRANSFORMER = 2
BERT = 3
ROBERTA = 4

DEFAULT_BATCH_SIZE = 16
Loading
Loading