From 8a388cc6c35e70eafcd065197ffcb4215e1f2a6f Mon Sep 17 00:00:00 2001
From: Alex Shan
Date: Thu, 26 Oct 2023 20:45:21 -0700
Subject: [PATCH 01/10] Binary (or n-way) classifier on top of the standard
 lemmatizer for the English 's token, or for other lemmas with ambiguous
 resolutions

Includes a data processing class for extracting sentences of interest.
Has evaluation functions for single-example and multi-example evaluation.
Adds utility functions for loading the dataset from a file and for handling unknown tokens during embedding lookup.
Can use charlm models for training.

Includes a baseline which uses a transformer to compare against the LSTM model.
Uses AutoTokenizer and AutoModel to load the transformer - a specific model name can be provided with the --bert_model flag.

Includes a feature to drop certain lemmas - more precisely, to only accept lemmas that match a regex. This will be particularly useful for a language like Farsi, where the training data only has 6 and 1 examples of the 3rd and 4th most common expansions.

Automatically extract the label information from the dataset. Save the label_decoder in the regular model and in the transformer baseline model.

Word vectors are trainable in the LSTM model. The word vectors used are the ones shipped with Stanza for whichever language is in use, not specifically GloVe.

Model selection during the training loop is done using eval set performance, for both the baseline and the LSTM model.
Training/testing is done via batch processing for speed.

Include UPOS tags in data processing/loading from files. UPOS embeddings are then used in the LSTM model as an additional signal for the query word.

Implement a multihead attention option for the LSTM model.
Add positional encodings to the MultiHeadAttention layer of the LSTM model.

The common train() method from the two trainer classes is moved into a shared parent class. This should make it easier to update pieces and keep them in sync.

Keep the dataset in a single object rather than a bunch of lists. This makes it easier to shuffle and keeps everything in one place.

Don't save the transformer, charlm, or original word vector file in the model files. Word vectors are finetuned and the deltas are saved.
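For reference, a rough usage sketch for the new classes (the paths are placeholders, not files shipped with this change):

    from stanza.models.lemma_classifier.base_model import LemmaClassifier
    from stanza.models.lemma_classifier.evaluate_models import evaluate_model

    # load a trained classifier and score it against a processed dev file
    model = LemmaClassifier.load("saved_models/lemma_classifier_model.pt")
    mcc_results, confusion, accuracy, weighted_f1 = evaluate_model(model, "combined_dev.txt")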
import full path --- stanza/models/lemma_classifier/__init__.py | 0 stanza/models/lemma_classifier/base_model.py | 111 +++++++++ .../models/lemma_classifier/base_trainer.py | 113 +++++++++ .../models/lemma_classifier/baseline_model.py | 54 +++++ stanza/models/lemma_classifier/constants.py | 14 ++ .../lemma_classifier/evaluate_models.py | 226 ++++++++++++++++++ stanza/models/lemma_classifier/lstm_model.py | 211 ++++++++++++++++ .../lemma_classifier/prepare_dataset.py | 136 +++++++++++ .../lemma_classifier/train_lstm_model.py | 146 +++++++++++ .../train_transformer_model.py | 129 ++++++++++ .../lemma_classifier/transformer_model.py | 83 +++++++ stanza/models/lemma_classifier/utils.py | 167 +++++++++++++ 12 files changed, 1390 insertions(+) create mode 100644 stanza/models/lemma_classifier/__init__.py create mode 100644 stanza/models/lemma_classifier/base_model.py create mode 100644 stanza/models/lemma_classifier/base_trainer.py create mode 100644 stanza/models/lemma_classifier/baseline_model.py create mode 100644 stanza/models/lemma_classifier/constants.py create mode 100644 stanza/models/lemma_classifier/evaluate_models.py create mode 100644 stanza/models/lemma_classifier/lstm_model.py create mode 100644 stanza/models/lemma_classifier/prepare_dataset.py create mode 100644 stanza/models/lemma_classifier/train_lstm_model.py create mode 100644 stanza/models/lemma_classifier/train_transformer_model.py create mode 100644 stanza/models/lemma_classifier/transformer_model.py create mode 100644 stanza/models/lemma_classifier/utils.py diff --git a/stanza/models/lemma_classifier/__init__.py b/stanza/models/lemma_classifier/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/stanza/models/lemma_classifier/base_model.py b/stanza/models/lemma_classifier/base_model.py new file mode 100644 index 0000000000..fbdba5be1f --- /dev/null +++ b/stanza/models/lemma_classifier/base_model.py @@ -0,0 +1,111 @@ +""" +Base class for the LemmaClassifier types. 
+ +Versions include LSTM and Transformer varieties +""" + +import logging + +from abc import ABC, abstractmethod + +import os + +import torch +import torch.nn as nn + +from stanza.models.common.foundation_cache import load_pretrain +from stanza.models.lemma_classifier.constants import ModelType + +logger = logging.getLogger('stanza.lemmaclassifier') + +class LemmaClassifier(ABC, nn.Module): + def __init__(self, label_decoder, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.label_decoder = label_decoder + self.unsaved_modules = [] + + def add_unsaved_module(self, name, module): + self.unsaved_modules += [name] + setattr(self, name, module) + + def is_unsaved_module(self, name): + return name.split('.')[0] in self.unsaved_modules + + def save(self, save_name): + """ + Save the model to the given path, possibly with some args + """ + save_dir = os.path.split(save_name)[0] + if save_dir: + os.makedirs(save_dir, exist_ok=True) + save_dict = self.get_save_dict() + torch.save(save_dict, save_name) + return save_dict + + @abstractmethod + def model_type(self): + """ + return a ModelType + """ + + @staticmethod + def from_checkpoint(checkpoint, args=None): + model_type = checkpoint['model_type'] + if model_type is ModelType.LSTM: + # TODO: if anyone can suggest a way to avoid this circular import + # (or better yet, avoid the load method knowing about subclasses) + # please do so + # maybe the subclassing is not necessary and we just put + # save & load in the trainer + from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM + + saved_args = checkpoint['args'] + # other model args are part of the model and cannot be changed for evaluation or pipeline + # the file paths might be relevant, though + keep_args = ['wordvec_pretrain_file', 'charlm_forward_file', 'charlm_backward_file'] + for arg in keep_args: + if args is not None and args.get(arg, None) is not None: + saved_args[arg] = args[arg] + + # TODO: refactor loading the pretrain (also done in the trainer) + pt = load_pretrain(saved_args['wordvec_pretrain_file']) + + use_charlm = saved_args['use_charlm'] + charlm_forward_file = saved_args.get('charlm_forward_file', None) + charlm_backward_file = saved_args.get('charlm_backward_file', None) + + model = LemmaClassifierLSTM(model_args=saved_args, + output_dim=len(checkpoint['label_decoder']), + pt_embedding=pt, + label_decoder=checkpoint['label_decoder'], + upos_to_id=checkpoint['upos_to_id'], + known_words=checkpoint['known_words'], + use_charlm=use_charlm, + charlm_forward_file=charlm_forward_file, + charlm_backward_file=charlm_backward_file) + elif model_type is ModelType.TRANSFORMER: + from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer + + output_dim = len(checkpoint['label_decoder']) + saved_args = checkpoint['args'] + bert_model = saved_args['bert_model'] + model = LemmaClassifierWithTransformer(model_args = saved_args, output_dim=output_dim, transformer_name=bert_model, label_decoder=checkpoint['label_decoder']) + else: + raise ValueError("Unknown model type %s" % model_type) + + # strict=False to accommodate missing parameters from the transformer or charlm + model.load_state_dict(checkpoint['params'], strict=False) + return model + + @staticmethod + def load(filename, args=None): + try: + checkpoint = torch.load(filename, lambda storage, loc: storage) + except BaseException: + logger.exception("Cannot load model from %s", filename) + raise + + logger.debug("Loading LemmaClassifier model from %s", filename) + + return 
LemmaClassifier.from_checkpoint(checkpoint, args)
diff --git a/stanza/models/lemma_classifier/base_trainer.py b/stanza/models/lemma_classifier/base_trainer.py
new file mode 100644
index 0000000000..4c7d0f183e
--- /dev/null
+++ b/stanza/models/lemma_classifier/base_trainer.py
@@ -0,0 +1,113 @@
+
+from abc import ABC, abstractmethod
+import logging
+import os
+from typing import List, Tuple, Any, Mapping
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from stanza.models.common.utils import default_device
+from stanza.models.lemma_classifier import utils
+from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE
+from stanza.models.lemma_classifier.evaluate_models import evaluate_model
+from stanza.utils.get_tqdm import get_tqdm
+
+tqdm = get_tqdm()
+logger = logging.getLogger('stanza.lemmaclassifier')
+
+class BaseLemmaClassifierTrainer(ABC):
+    def configure_weighted_loss(self, label_decoder: Mapping, counts: Mapping):
+        """
+        If applicable, this function updates the trainer's loss function to BCEWithLogitsLoss.
+        The weights are determined by the counts of the classes in the dataset. The weights are inversely proportional to the
+        frequency of the class in the set, i.e. classes with lower frequency will have higher weight.
+        """
+        weights = [0 for _ in label_decoder.keys()] # each key in the label decoder is one class, we have one weight per class
+        total_samples = sum(counts.values())
+        for class_idx in counts:
+            weights[class_idx] = total_samples / (counts[class_idx] * len(counts)) # weight_i = total / (# examples in class i * num classes)
+        weights = torch.tensor(weights)
+        logger.info(f"Using weights {weights} for weighted loss.")
+        self.criterion = nn.BCEWithLogitsLoss(weight=weights)
+
+    @abstractmethod
+    def build_model(self, label_decoder, upos_to_id, known_words):
+        """
+        Build a model using pieces of the dataset to determine some of the model shape
+        """
+
+    def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, train_file: str) -> None:
+        """
+        Trains a model on batches of texts, position indices of the target token, and labels (lemma annotation) for the target token.
+
+        Args:
+            num_epochs (int): Number of training epochs
+            save_name (str): Path to file where trained model should be saved.
+            eval_file (str): Path to the dev set file for evaluating model checkpoints each epoch.
+            train_file (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line.
+        """
+        # Put model on GPU (if possible)
+        device = default_device()
+
+        if not train_file:
+            raise ValueError("Cannot train model - no train_file supplied!")
+
+        dataset = utils.Dataset(train_file, get_counts=self.weighted_loss, batch_size=args.get("batch_size", DEFAULT_BATCH_SIZE))
+        label_decoder = dataset.label_decoder
+        upos_to_id = dataset.upos_to_id
+        self.output_dim = len(label_decoder)
+        logger.info(f"Loaded dataset successfully from {train_file}")
+        logger.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}")
+
+        self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words)
+        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
+
+        self.model.to(device)
+        logger.info(f"Training model on device: {device}. {next(self.model.parameters()).device}")
+
+        if os.path.exists(save_name):
+            raise FileExistsError(f"Save name {save_name} already exists; training would overwrite previous file contents. 
Aborting...") + + if self.weighted_loss: + self.configure_weighted_loss(label_decoder, dataset.counts) + + # Put the criterion on GPU too + logger.debug(f"Criterion on {next(self.model.parameters()).device}") + self.criterion = self.criterion.to(next(self.model.parameters()).device) + + best_model, best_f1 = None, float("-inf") # Used for saving checkpoints of the model + for epoch in range(num_epochs): + # go over entire dataset with each epoch + for sentences, positions, upos_tags, labels in tqdm(dataset): + assert len(sentences) == len(positions) == len(labels), f"Input sentences, positions, and labels are of unequal length ({len(sentences), len(positions), len(labels)})" + + self.optimizer.zero_grad() + outputs = self.model(positions, sentences, upos_tags) + + # Compute loss, which is different if using CE or BCEWithLogitsLoss + if self.weighted_loss: # BCEWithLogitsLoss requires a vector for target where probability is 1 on the true label class, and 0 on others. + # TODO: three classes? + targets = torch.stack([torch.tensor([1, 0]) if label == 0 else torch.tensor([0, 1]) for label in labels]).to(dtype=torch.float32).to(device) + # should be shape size (batch_size, 2) + else: # CELoss accepts target as just raw label + targets = labels.to(device) + + loss = self.criterion(outputs, targets) + + loss.backward() + self.optimizer.step() + + logger.info(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}") + if eval_file: + # Evaluate model on dev set to see if it should be saved. + _, _, _, f1 = evaluate_model(self.model, eval_file, is_training=True) + logger.info(f"Weighted f1 for model: {f1}") + if f1 > best_f1: + best_f1 = f1 + self.model.save(save_name) + logger.info(f"New best model: weighted f1 score of {f1}.") + else: + self.model.save(save_name) + diff --git a/stanza/models/lemma_classifier/baseline_model.py b/stanza/models/lemma_classifier/baseline_model.py new file mode 100644 index 0000000000..de4f0f1a64 --- /dev/null +++ b/stanza/models/lemma_classifier/baseline_model.py @@ -0,0 +1,54 @@ +""" +Baseline model for the existing lemmatizer which always predicts "be" and never "have" on the "'s" token. + +The BaselineModel class can be updated to any arbitrary token and predicton lemma, not just "be" on the "s" token. +""" + +import stanza +import os +from stanza.models.lemma_classifier.evaluate_models import evaluate_sequences +from stanza.models.lemma_classifier.prepare_dataset import load_doc_from_conll_file + +class BaselineModel: + + def __init__(self, token_to_lemmatize, prediction_lemma, prediction_upos): + self.token_to_lemmatize = token_to_lemmatize + self.prediction_lemma = prediction_lemma + self.prediction_upos = prediction_upos + + def predict(self, token): + if token == self.token_to_lemmatize: + return self.prediction_lemma + + def evaluate(self, conll_path): + """ + Evaluates the baseline model against the test set defined in conll_path. + + Returns a map where the keys are each class and the values are another map including the precision, recall and f1 scores + for that class. + + Also returns confusion matrix. 
Keys are gold tags and inner keys are predicted tags + """ + doc = load_doc_from_conll_file(conll_path) + gold_tag_sequences, pred_tag_sequences = [], [] + for sentence in doc.sentences: + gold_tags, pred_tags = [], [] + for word in sentence.words: + if word.upos in self.prediction_upos and word.text == self.token_to_lemmatize: + pred = self.prediction_lemma + gold = word.lemma + gold_tags.append(gold) + pred_tags.append(pred) + gold_tag_sequences.append(gold_tags) + pred_tag_sequences.append(pred_tags) + + multiclass_result, confusion_mtx, weighted_f1 = evaluate_sequences(gold_tag_sequences, pred_tag_sequences) + return multiclass_result, confusion_mtx + + +if __name__ == "__main__": + + bl_model = BaselineModel("'s", "be", ["AUX"]) + coNLL_path = os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu") + bl_model.evaluate(coNLL_path) + diff --git a/stanza/models/lemma_classifier/constants.py b/stanza/models/lemma_classifier/constants.py new file mode 100644 index 0000000000..09fa9044cd --- /dev/null +++ b/stanza/models/lemma_classifier/constants.py @@ -0,0 +1,14 @@ +from enum import Enum + +UNKNOWN_TOKEN = "unk" # token name for unknown tokens +UNKNOWN_TOKEN_IDX = -1 # custom index we apply to unknown tokens + +# TODO: ModelType could just be LSTM and TRANSFORMER +# and then the transformer baseline would have the transformer as another argument +class ModelType(Enum): + LSTM = 1 + TRANSFORMER = 2 + BERT = 3 + ROBERTA = 4 + +DEFAULT_BATCH_SIZE = 16 \ No newline at end of file diff --git a/stanza/models/lemma_classifier/evaluate_models.py b/stanza/models/lemma_classifier/evaluate_models.py new file mode 100644 index 0000000000..9c9e4ffa4b --- /dev/null +++ b/stanza/models/lemma_classifier/evaluate_models.py @@ -0,0 +1,226 @@ +import os +import sys + +parentdir = os.path.dirname(__file__) +parentdir = os.path.dirname(parentdir) +parentdir = os.path.dirname(parentdir) +sys.path.append(parentdir) + +import logging +import argparse +import os + +from typing import Any, List, Tuple, Mapping +from collections import defaultdict +from numpy import random + +import torch +import torch.nn as nn + +import stanza + +from stanza.models.common.utils import default_device +from stanza.models.lemma_classifier import utils +from stanza.models.lemma_classifier.base_model import LemmaClassifier +from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM +from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer +from stanza.utils.confusion import format_confusion +from stanza.utils.get_tqdm import get_tqdm + +tqdm = get_tqdm() + +logger = logging.getLogger('stanza.lemmaclassifier') + + +def get_weighted_f1(mcc_results: Mapping[int, Mapping[str, float]], confusion: Mapping[int, Mapping[int, int]]) -> float: + """ + Computes the weighted F1 score across an evaluation set. + + The weight of a class's F1 score is equal to the number of examples in evaluation. This makes classes that have more + examples in the evaluation more impactful to the weighted f1. 
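+
+    For example (illustrative numbers only, not from any dataset): if class 0 has an F1 of 0.95 with 90 gold examples
+    and class 1 has an F1 of 0.50 with 10 gold examples, the weighted F1 is (0.95 * 90 + 0.50 * 10) / 100 = 0.905.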
+ """ + num_total_examples = 0 + weighted_f1 = 0 + + for class_id in mcc_results: + class_f1 = mcc_results.get(class_id).get("f1") + num_class_examples = sum(confusion.get(class_id).values()) + weighted_f1 += class_f1 * num_class_examples + num_total_examples += num_class_examples + + return weighted_f1 / num_total_examples + + +def evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[Any], label_decoder: Mapping, verbose=True): + """ + Evaluates a model's predicted tags against a set of gold tags. Computes precision, recall, and f1 for all classes. + + Precision = true positives / true positives + false positives + Recall = true positives / true positives + false negatives + F1 = 2 * (Precision * Recall) / (Precision + Recall) + + Returns: + 1. Multi class result dictionary, where each class is a key and maps to another map of its F1, precision, and recall scores. + e.g. multiclass_results[0]["precision"] would give class 0's precision. + 2. Confusion matrix, where each key is a gold tag and its value is another map with a key of the predicted tag with value of that (gold, pred) count. + e.g. confusion[0][1] = 6 would mean that for gold tag 0, the model predicted tag 1 a total of 6 times. + """ + assert len(gold_tag_sequences) == len(pred_tag_sequences), \ + f"Length of gold tag sequences is {len(gold_tag_sequences)}, while length of predicted tag sequence is {len(pred_tag_sequences)}" + + confusion = defaultdict(lambda: defaultdict(int)) + + reverse_label_decoder = {y: x for x, y in label_decoder.items()} + for gold, pred in zip(gold_tag_sequences, pred_tag_sequences): + confusion[reverse_label_decoder[gold]][reverse_label_decoder[pred]] += 1 + + multi_class_result = defaultdict(lambda: defaultdict(float)) + # compute precision, recall and f1 for each class and store inside of `multi_class_result` + for gold_tag in confusion.keys(): + + try: + prec = confusion.get(gold_tag, {}).get(gold_tag, 0) / sum([confusion.get(k, {}).get(gold_tag, 0) for k in confusion.keys()]) + except ZeroDivisionError: + prec = 0.0 + + try: + recall = confusion.get(gold_tag, {}).get(gold_tag, 0) / sum(confusion.get(gold_tag, {}).values()) + except ZeroDivisionError: + recall = 0.0 + + try: + f1 = 2 * (prec * recall) / (prec + recall) + except ZeroDivisionError: + f1 = 0.0 + + multi_class_result[gold_tag] = { + "precision": prec, + "recall": recall, + "f1": f1 + } + + if verbose: + for lemma in multi_class_result: + logger.info(f"Lemma '{lemma}' had precision {100 * multi_class_result[lemma]['precision']}, recall {100 * multi_class_result[lemma]['recall']} and F1 score of {100 * multi_class_result[lemma]['f1']}") + + weighted_f1 = get_weighted_f1(multi_class_result, confusion) + + return multi_class_result, confusion, weighted_f1 + + +def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[int]]=[]) -> torch.Tensor: + """ + A LemmaClassifierLSTM or LemmaClassifierWithTransformer is used to predict on a single text example, given the position index of the target token. + + Args: + model (LemmaClassifier): A trained LemmaClassifier that is able to predict on a target token. + position_indices (Tensor[int]): A tensor of the (zero-indexed) position of the target token in `text` for each example in the batch. + sentences (List[List[str]]): A list of lists of the tokenized strings of the input sentences. + + Returns: + (int): The index of the predicted class in `model`'s output. 
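+
+    For example (shapes only, as an illustration): with a batch of 4 sentences, `logits` has shape (4, output_size)
+    and the returned tensor of predicted class indices has shape (4,).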
+ """ + with torch.no_grad(): + logits = model(position_indices, sentences, upos_tags) # should be size (batch_size, output_size) + predicted_class = torch.argmax(logits, dim=1) # should be size (batch_size, 1) + + return predicted_class + + +def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_training: bool = False) -> Tuple[Mapping, Mapping, float, float]: + """ + Helper function for model evaluation + + Args: + model (LemmaClassifierLSTM or LemmaClassifierWithTransformer): An instance of the LemmaClassifier class that has architecture initialized which matches the model saved in `model_path`. + model_path (str): Path to the saved model weights that will be loaded into `model`. + eval_path (str): Path to the saved evaluation dataset. + verbose (bool, optional): True if `evaluate_sequences()` should print the F1, Precision, and Recall for each class. Defaults to True. + is_training (bool, optional): Whether the model is in training mode. If the model is training, we do not change it to eval mode. + + Returns: + 1. Multi-class results (Mapping[int, Mapping[str, float]]): first map has keys as the classes (lemma indices) and value is + another map with key of "f1", "precision", or "recall" with corresponding values. + 2. Confusion Matrix (Mapping[int, Mapping[int, int]]): A confusion matrix with keys equal to the index of the gold tag, and a value of the + map with the key as the predicted tag and corresponding count of that (gold, pred) pair. + 3. Accuracy (float): the total accuracy (num correct / total examples) across the evaluation set. + """ + # load model + device = default_device() + model.to(device) + + if not is_training: + model.eval() # set to eval mode + + # load in eval data + dataset = utils.Dataset(eval_path, label_decoder=model.label_decoder, shuffle=False) + + logger.info(f"Evaluating on evaluation file {eval_path}") + + correct, total = 0, 0 + gold_tags, pred_tags = dataset.labels, [] + + # run eval on each example from dataset + for sentences, pos_indices, upos_tags, labels in tqdm(dataset, "Evaluating examples from data file"): + pred = model_predict(model, pos_indices, sentences, upos_tags) # Pred should be size (batch_size, ) + correct_preds = pred == labels.to(device) + correct += torch.sum(correct_preds) + total += len(correct_preds) + pred_tags += pred.tolist() + + logger.info("Finished evaluating on dataset. Computing scores...") + accuracy = correct / total + + mc_results, confusion, weighted_f1 = evaluate_sequences(gold_tags, pred_tags, dataset.label_decoder, verbose=verbose) + # add brackets around batches of gold and pred tags because each batch is an element within the sequences in this helper + if verbose: + logger.info(f"Accuracy: {accuracy} ({correct}/{total})") + logger.info(f"Label decoder: {dataset.label_decoder}") + + return mc_results, confusion, accuracy, weighted_f1 + + +def main(args=None): + + # TODO: can unify this script with train_lstm_model.py? 
+ # TODO: can save the model type in the model .pt, then + # automatically figure out what type of model we are using by + # looking in the file + parser = argparse.ArgumentParser() + parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab") + parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)") + parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer") + parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') + parser.add_argument("--charlm", action='store_true', default=False, help="Whether not to use the charlm embeddings") + parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") + parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file") + parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file") + parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model.pt"), help="Path to model save file") + parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta' or 'lstm')") + parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta") + parser.add_argument("--eval_file", type=str, help="path to evaluation file") + + args = parser.parse_args(args) + + logger.info("Running training script with the following args:") + args = vars(args) + for arg in args: + logger.info(f"{arg}: {args[arg]}") + logger.info("------------------------------------------------------------") + + logger.info(f"Attempting evaluation of model from {args['save_name']} on file {args['eval_file']}") + model = LemmaClassifier.load(args['save_name'], args) + + mcc_results, confusion, acc, weighted_f1 = evaluate_model(model, args['eval_file']) + + logger.info(f"MCC Results: {dict(mcc_results)}") + logger.info("______________________________________________") + logger.info(f"Confusion:\n%s", format_confusion(confusion)) + logger.info("______________________________________________") + logger.info(f"Accuracy: {acc}") + logger.info("______________________________________________") + logger.info(f"Weighted f1: {weighted_f1}") + + +if __name__ == "__main__": + main() diff --git a/stanza/models/lemma_classifier/lstm_model.py b/stanza/models/lemma_classifier/lstm_model.py new file mode 100644 index 0000000000..5cd20c63f8 --- /dev/null +++ b/stanza/models/lemma_classifier/lstm_model.py @@ -0,0 +1,211 @@ +import torch +import torch.nn as nn +import os +import logging +import math +from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence +from stanza.models.common.char_model import CharacterModel, CharacterLanguageModel +from typing import List, Tuple + +from stanza.models.common.vocab import UNK_ID +from stanza.models.lemma_classifier import utils +from stanza.models.lemma_classifier.base_model import LemmaClassifier +from stanza.models.lemma_classifier.constants import ModelType + +logger = logging.getLogger('stanza.lemmaclassifier') + +class LemmaClassifierLSTM(LemmaClassifier): 
+ """ + Model architecture: + Extracts word embeddings over the sentence, passes embeddings into a bi-LSTM to get a sentence encoding. + From the LSTM output, we get the embedding of the specific token that we classify on. That embedding + is fed into an MLP for classification. + """ + def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words, + use_charlm=False, charlm_forward_file=None, charlm_backward_file=None): + """ + Args: + vocab_size (int): Size of the vocab being used (if custom vocab) + output_dim (int): Size of output vector from MLP layer + upos_to_id (Mapping[str, int]): A dictionary mapping UPOS tag strings to their respective IDs + pt_embedding (Pretrain): pretrained embeddings + known_words (list(str)): Words which are in the training data + use_charlm (bool): Whether or not to use the charlm embeddings + charlm_forward_file (str): The path to the forward pass model for the character language model + charlm_backward_file (str): The path to the forward pass model for the character language model. + + Kwargs: + upos_emb_dim (int): The size of the UPOS tag embeddings + num_heads (int): The number of heads to use for attention. If there are more than 0 heads, attention will be used instead of the LSTM. + + Raises: + FileNotFoundError: if the forward or backward charlm file cannot be found. + """ + super(LemmaClassifierLSTM, self).__init__(label_decoder) + self.model_args = model_args + + self.hidden_dim = model_args['hidden_dim'] + self.input_size = 0 + self.num_heads = self.model_args['num_heads'] + + emb_matrix = pt_embedding.emb + self.add_unsaved_module("embeddings", nn.Embedding.from_pretrained(torch.from_numpy(emb_matrix), freeze=True)) + self.vocab_map = { word.replace('\xa0', ' '): i for i, word in enumerate(pt_embedding.vocab) } + self.vocab_size = emb_matrix.shape[0] + self.embedding_dim = emb_matrix.shape[1] + + self.known_words = known_words + self.known_word_map = {word: idx for idx, word in enumerate(known_words)} + self.delta_embedding = nn.Embedding(num_embeddings=len(known_words)+1, + embedding_dim=self.embedding_dim, + padding_idx=0) + nn.init.normal_(self.delta_embedding.weight, std=0.01) + + self.input_size += self.embedding_dim + + # Optionally, include charlm embeddings + self.use_charlm = use_charlm + + if self.use_charlm: + if charlm_forward_file is None or not os.path.exists(charlm_forward_file): + raise FileNotFoundError(f'Could not find forward character model: {charlm_forward_file}') + if charlm_backward_file is None or not os.path.exists(charlm_backward_file): + raise FileNotFoundError(f'Could not find backward character model: {charlm_backward_file}') + self.add_unsaved_module('charmodel_forward', CharacterLanguageModel.load(charlm_forward_file, finetune=False)) + self.add_unsaved_module('charmodel_backward', CharacterLanguageModel.load(charlm_backward_file, finetune=False)) + + self.input_size += self.charmodel_forward.hidden_dim() + self.charmodel_backward.hidden_dim() + + self.upos_emb_dim = self.model_args["upos_emb_dim"] + self.upos_to_id = upos_to_id + if self.upos_emb_dim > 0 and self.upos_to_id is not None: + # TODO: should leave space for unknown POS? 
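+            # (as written, upos_to_id assigns ids starting from 0, so the first UPOS tag seen in training
+            #  shares id 0 with the padding index below and keeps an all-zero embedding)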
+ self.upos_emb = nn.Embedding(num_embeddings=len(self.upos_to_id), + embedding_dim=self.upos_emb_dim, + padding_idx=0) + self.input_size += self.upos_emb_dim + + device = next(self.parameters()).device + # Determine if attn or LSTM should be used + if self.num_heads > 0: + self.input_size = utils.round_up_to_multiple(self.input_size, self.num_heads) + self.multihead_attn = nn.MultiheadAttention(embed_dim=self.input_size, num_heads=self.num_heads, batch_first=True).to(device) + logger.debug(f"Using attention mechanism with embed dim {self.input_size} and {self.num_heads} attention heads.") + else: + self.lstm = nn.LSTM(self.input_size, + self.hidden_dim, + batch_first=True, + bidirectional=True) + logger.debug(f"Using LSTM mechanism.") + + mlp_input_size = self.hidden_dim * 2 if self.num_heads == 0 else self.input_size + self.mlp = nn.Sequential( + nn.Linear(mlp_input_size, 64), + nn.ReLU(), + nn.Linear(64, output_dim) + ) + + def get_save_dict(self): + save_dict = { + "params": self.state_dict(), + "label_decoder": self.label_decoder, + "model_type": self.model_type(), + "args": self.model_args, + "upos_to_id": self.upos_to_id, + "known_words": self.known_words, + } + skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)] + for k in skipped: + del save_dict["params"][k] + return save_dict + + def forward(self, pos_indices: List[int], sentences: List[List[str]], upos_tags: List[List[int]]): + """ + Computes the forward pass of the neural net + + Args: + pos_indices (List[int]): A list of the position index of the target token for lemmatization classification in each sentence. + sentences (List[List[str]]): A list of the token-split sentences of the input data. + upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence. + + Returns: + torch.tensor: Output logits of the neural network, where the shape is (n, output_size) where n is the number of sentences. 
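+
+        Example (illustrative values only; UPOS ids depend on the upos_to_id mapping built from the training data):
+            pos_indices = [1, 1]
+            sentences = [["That", "'s", "cool", "."], ["He", "'s", "got", "it", "."]]
+            upos_tags = [[3, 1, 4, 2], [5, 1, 6, 3, 2]]
+        would produce output logits of shape (2, output_size).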
+ """ + device = next(self.parameters()).device + batch_size = len(sentences) + token_ids = [] + delta_token_ids = [] + for words in sentences: + sentence_token_ids = [self.vocab_map.get(word.lower(), UNK_ID) for word in words] + sentence_token_ids = torch.tensor(sentence_token_ids, device=device) + token_ids.append(sentence_token_ids) + + sentence_delta_token_ids = [self.known_word_map.get(word.lower(), 0) for word in words] + sentence_delta_token_ids = torch.tensor(sentence_delta_token_ids, device=device) + delta_token_ids.append(sentence_delta_token_ids) + + token_ids = pad_sequence(token_ids, batch_first=True) + delta_token_ids = pad_sequence(delta_token_ids, batch_first=True) + embedded = self.embeddings(token_ids) + self.delta_embedding(delta_token_ids) + + if self.upos_emb_dim > 0: + upos_tags = [torch.tensor(sentence_tags) for sentence_tags in upos_tags] # convert internal lists to tensors + upos_tags = pad_sequence(upos_tags, batch_first=True, padding_value=0).to(device) + pos_emb = self.upos_emb(upos_tags) + embedded = torch.cat((embedded, pos_emb), 2).to(device) + + if self.use_charlm: + char_reps_forward = self.charmodel_forward.build_char_representation(sentences) # takes [[str]] + char_reps_backward = self.charmodel_backward.build_char_representation(sentences) + + char_reps_forward = pad_sequence(char_reps_forward, batch_first=True) + char_reps_backward = pad_sequence(char_reps_backward, batch_first=True) + + embedded = torch.cat((embedded, char_reps_forward, char_reps_backward), 2) + + if self.num_heads > 0: + + def positional_encoding(seq_len, d_model, device): + encoding = torch.zeros(seq_len, d_model, device=device) + position = torch.arange(0, seq_len, dtype=torch.float, device=device).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)).to(device) + + encoding[:, 0::2] = torch.sin(position * div_term) + encoding[:, 1::2] = torch.cos(position * div_term) + + # Add a new dimension to fit the batch size + encoding = encoding.unsqueeze(0) + return encoding + + seq_len, d_model = embedded.shape[1], embedded.shape[2] + pos_enc = positional_encoding(seq_len, d_model, device=device) + + embedded += pos_enc.expand_as(embedded) + + padded_sequences = pad_sequence(embedded, batch_first=True) + lengths = torch.tensor([len(seq) for seq in embedded]) + + if self.num_heads > 0: + target_seq_length, src_seq_length = padded_sequences.size(1), padded_sequences.size(1) + attn_mask = torch.triu(torch.ones(batch_size * self.num_heads, target_seq_length, src_seq_length, dtype=torch.bool), diagonal=1) + + attn_mask = attn_mask.view(batch_size, self.num_heads, target_seq_length, src_seq_length) + attn_mask = attn_mask.repeat(1, 1, 1, 1).view(batch_size * self.num_heads, target_seq_length, src_seq_length).to(device) + + attn_output, attn_weights = self.multihead_attn(padded_sequences, padded_sequences, padded_sequences, attn_mask=attn_mask) + # Extract the hidden state at the index of the token to classify + token_reps = attn_output[torch.arange(attn_output.size(0)), pos_indices] + + else: + packed_sequences = pack_padded_sequence(padded_sequences, lengths, batch_first=True) + lstm_out, (hidden, _) = self.lstm(packed_sequences) + # Extract the hidden state at the index of the token to classify + unpacked_lstm_outputs, _ = pad_packed_sequence(lstm_out, batch_first=True) + token_reps = unpacked_lstm_outputs[torch.arange(unpacked_lstm_outputs.size(0)), pos_indices] + + # MLP forward pass + output = self.mlp(token_reps) + return output + + def 
model_type(self): + return ModelType.LSTM diff --git a/stanza/models/lemma_classifier/prepare_dataset.py b/stanza/models/lemma_classifier/prepare_dataset.py new file mode 100644 index 0000000000..02bdd29f86 --- /dev/null +++ b/stanza/models/lemma_classifier/prepare_dataset.py @@ -0,0 +1,136 @@ +import argparse +import os +import re + +import stanza +from stanza.models.lemma_classifier import utils + +from typing import List, Tuple, Any + +""" +The code in this file processes a CoNLL dataset by taking its sentences and filtering out all sentences that do not contain the target token. +Furthermore, it will store tuples of the Stanza document object, the position index of the target token, and its lemma. +""" + + +def load_doc_from_conll_file(path: str): + """" + loads in a Stanza document object from a path to a CoNLL file containing annotated sentences. + """ + return stanza.utils.conll.CoNLL.conll2doc(path) + + +class DataProcessor(): + + def __init__(self, target_word: str, target_upos: List[str], allowed_lemmas: str): + self.target_word = target_word + self.target_upos = target_upos + self.allowed_lemmas = re.compile(allowed_lemmas) + + def find_all_occurrences(self, sentence) -> List[int]: + """ + Finds all occurrences of self.target_word in tokens and returns the index(es) of such occurrences. + """ + occurrences = [] + for idx, token in enumerate(sentence.words): + if token.text == self.target_word and token.upos in self.target_upos: + occurrences.append(idx) + return occurrences + + def process_document(self, doc, keep_condition: callable, save_name: str) -> None: + """ + Takes any sentence from `doc` that meets the condition of `keep_condition` and writes its tokens, index of target word, and lemma to `save_name` + + Sentences that meet `keep_condition` and contain `self.target_word` multiple times have each instance in a different example in the output file. + + Args: + doc (Stanza.doc): Document object that represents the file to be analyzed + keep_condition (callable): A function that outputs a boolean representing whether to analyze (True) or not analyze the sentence for a target word. + save_name (str): Path to the file for storing output + """ + with open(save_name, "w+", encoding="utf-8") as output_f: + for sentence in doc.sentences: + # for each sentence, we need to determine if it should be added to the output file. + # if the sentence fulfills the keep_condition, then we will save it along with the target word's index and its corresponding lemma + if keep_condition(sentence): + tokens = [token.text for token in sentence.words] + indexes = self.find_all_occurrences(sentence) + for idx in indexes: + if self.allowed_lemmas.fullmatch(sentence.words[idx].lemma): + # for each example found, we write the tokens, their respective upos tags, the target token index, lemma, and the number of tokens in the sentence + upos_tags = [sentence.words[i].upos for i in range(len(sentence.words))] + num_tokens = len(upos_tags) + # TODO maybe this should just be done in JSON to avoid lengthy and non-extendable data processing. + output_f.write(f'{" ".join(tokens)} {" ".join(upos_tags)} {idx} {sentence.words[idx].lemma} {num_tokens}\n') + + def read_processed_data(self, file_name: str) -> List[dict]: + """ + Reads the output file from `process_document()` and outputs a list that contains the sentences of interest. 
Each object within the list + contains a map with three (key, val) pairs: + + "words" is a list that contains the tokens of the sentence + "index" is an integer representing which token in "words" the lemma annotation corresponds to + "upos" is a string that corresponds to the target token's UPOS tag + "lemma" is a string that is the lemma of the target word in the sentence. + + """ + output = [] + with open(file_name, "r", encoding="utf-8") as f: + for line in f.readlines(): + if not line: + continue + + obj = {} + split = line.split() + num_tokens = int(split[-1]) + + # Extract data fields + words = split[: num_tokens] + upos_tags = split[num_tokens: 2 * num_tokens] + index = int(split[-3]) + lemma = split[-2] + + obj["words"] = words + obj["index"] = index + obj["upos_tags"] = upos_tags + obj["lemma"] = lemma + + output.append(obj) + + return output + + +def main(args=None): + parser = argparse.ArgumentParser() + + parser.add_argument("--conll_path", type=str, default=os.path.join(os.path.dirname(__file__), "en_gum-ud-train.conllu"), help="path to the conll file to translate") + parser.add_argument("--target_word", type=str, default="'s", help="Token to classify on, e.g. 's.") + parser.add_argument("--target_upos", type=str, default="AUX", help="upos on target token") + parser.add_argument("--output_path", type=str, default="test_output.txt", help="Path for output file") + parser.add_argument("--allowed_lemmas", type=str, default=".*", help="A regex for allowed lemmas. If not set, all lemmas are allowed") + + args = parser.parse_args(args) + + conll_path = args.conll_path + target_word = args.target_word + target_upos = args.target_upos + output_path = args.output_path + allowed_lemmas = args.allowed_lemmas + + args = vars(args) + for arg in args: + print(f"{arg}: {args[arg]}") + + doc = load_doc_from_conll_file(conll_path) + processor = DataProcessor(target_word=target_word, target_upos=[target_upos], allowed_lemmas=allowed_lemmas) + + def keep_sentence(sentence): + for word in sentence.words: + if word.text == target_word and word.upos == target_upos: + return True + return False + + processor.process_document(doc, keep_sentence, output_path) + +if __name__ == "__main__": + main() diff --git a/stanza/models/lemma_classifier/train_lstm_model.py b/stanza/models/lemma_classifier/train_lstm_model.py new file mode 100644 index 0000000000..400702bc45 --- /dev/null +++ b/stanza/models/lemma_classifier/train_lstm_model.py @@ -0,0 +1,146 @@ +""" +The code in this file works to train a lemma classifier for 's +""" + +import argparse +import logging +import os + +import torch +import torch.nn as nn + +from stanza.models.common.foundation_cache import load_pretrain +from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer +from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE +from stanza.models.lemma_classifier.lstm_model import LemmaClassifierLSTM + +logger = logging.getLogger('stanza.lemmaclassifier') + +class LemmaClassifierTrainer(BaseLemmaClassifierTrainer): + """ + Class to assist with training a LemmaClassifierLSTM + """ + + def __init__(self, model_args: dict, embedding_file: str, use_charlm: bool = False, charlm_forward_file: str = None, charlm_backward_file: str = None, lr: float = 0.001, loss_func: str = None): + """ + Initializes the LemmaClassifierTrainer class. + + Args: + model_args (dict): Various model shape parameters + embedding_file (str): What word embeddings file to use. 
Use a Stanza pretrain .pt + use_charlm (bool, optional): Whether to use charlm embeddings as well. Defaults to False. + charlm_forward_file (str): Path to the forward pass embeddings for the charlm + charlm_backward_file (str): Path to the backward pass embeddings for the charlm + upos_emb_dim (int): The dimension size of UPOS tag embeddings + num_heads (int): The number of attention heads to use. + lr (float): Learning rate, defaults to 0.001. + loss_func (str): Which loss function to use (either 'ce' or 'weighted_bce') + + Raises: + FileNotFoundError: If the forward charlm file is not present + FileNotFoundError: If the backward charlm file is not present + """ + super().__init__() + + self.model_args = model_args + + # Load word embeddings + pt = load_pretrain(embedding_file) + self.pt_embedding = pt + + # Load CharLM embeddings + if use_charlm and charlm_forward_file is not None and not os.path.exists(charlm_forward_file): + raise FileNotFoundError(f"Could not find forward charlm file: {charlm_forward_file}") + if use_charlm and charlm_backward_file is not None and not os.path.exists(charlm_backward_file): + raise FileNotFoundError(f"Could not find backward charlm file: {charlm_backward_file}") + + # TODO: just pass around the args instead + self.use_charlm = use_charlm + self.charlm_forward_file = charlm_forward_file + self.charlm_backward_file = charlm_backward_file + self.lr = lr + + # Find loss function + if loss_func == "ce": + self.criterion = nn.CrossEntropyLoss() + self.weighted_loss = False + logger.debug("Using CE loss") + elif loss_func == "weighted_bce": + self.criterion = nn.BCEWithLogitsLoss() + self.weighted_loss = True # used to add weights during train time. + logger.debug("Using Weighted BCE loss") + else: + raise ValueError("Must enter a valid loss function (e.g. 
'ce' or 'weighted_bce')") + + def build_model(self, label_decoder, upos_to_id, known_words): + return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words, + use_charlm=self.use_charlm, charlm_forward_file=self.charlm_forward_file, charlm_backward_file=self.charlm_backward_file) + +def build_argparse(): + parser = argparse.ArgumentParser() + parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer") + parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), "pretrain", "glove.pt"), help='Exact name of the pretrain file to read') + parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether not to use the charlm embeddings") + parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") + parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file") + parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file") + parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.") + parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.') + parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.") + parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file") + parser.add_argument("--lr", type=float, default=0.001, help="learning rate") + parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs") + parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch") + parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file") + parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.") + parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves") + return parser + +def main(args=None): + parser = build_argparse() + args = parser.parse_args(args) + + wordvec_pretrain_file = args.wordvec_pretrain_file + use_charlm = args.use_charlm + charlm_forward_file = args.charlm_forward_file + charlm_backward_file = args.charlm_backward_file + upos_emb_dim = args.upos_emb_dim + use_attention = args.attn + num_heads = args.num_heads + save_name = args.save_name + lr = args.lr + num_epochs = args.num_epochs + train_file = args.train_file + weighted_loss = args.weighted_loss + eval_file = args.eval_file + + args = vars(args) + + if os.path.exists(save_name): + raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...") + if not os.path.exists(train_file): + raise FileNotFoundError(f"Training file {train_file} not found. 
Try again with a valid path.") + + logger.info("Running training script with the following args:") + for arg in args: + logger.info(f"{arg}: {args[arg]}") + logger.info("------------------------------------------------------------") + + trainer = LemmaClassifierTrainer(model_args=args, + embedding_file=wordvec_pretrain_file, + use_charlm=use_charlm, + charlm_forward_file=charlm_forward_file, + charlm_backward_file=charlm_backward_file, + lr=lr, + loss_func="weighted_bce" if weighted_loss else "ce", + ) + + trainer.train( + num_epochs=num_epochs, save_name=save_name, args=args, eval_file=eval_file, train_file=train_file + ) + + return trainer + +if __name__ == "__main__": + main() + diff --git a/stanza/models/lemma_classifier/train_transformer_model.py b/stanza/models/lemma_classifier/train_transformer_model.py new file mode 100644 index 0000000000..f9129e5738 --- /dev/null +++ b/stanza/models/lemma_classifier/train_transformer_model.py @@ -0,0 +1,129 @@ +""" +This file contains code used to train a baseline transformer model to classify on a lemma of a particular token. +""" + +import argparse +import os +import sys +import logging + +import torch +import torch.nn as nn +import torch.optim as optim + +from stanza.models.lemma_classifier.base_trainer import BaseLemmaClassifierTrainer +from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE +from stanza.models.lemma_classifier.transformer_model import LemmaClassifierWithTransformer +from stanza.models.common.utils import default_device + +logger = logging.getLogger('stanza.lemmaclassifier') + +class TransformerBaselineTrainer(BaseLemmaClassifierTrainer): + """ + Class to assist with training a baseline transformer model to classify on token lemmas. + To find the model spec, refer to `model.py` in this directory. + """ + + def __init__(self, model_args: dict, transformer_name: str = "roberta", loss_func: str = "ce", lr: int = 0.001): + """ + Creates the Trainer object + + Args: + transformer_name (str, optional): What kind of transformer to use for embeddings. Defaults to "roberta". + loss_func (str, optional): Which loss function to use (either 'ce' or 'weighted_bce'). Defaults to "ce". + lr (int, optional): learning rate for the optimizer. Defaults to 0.001. + """ + super().__init__() + + self.model_args = model_args + + # Find loss function + if loss_func == "ce": + self.criterion = nn.CrossEntropyLoss() + self.weighted_loss = False + elif loss_func == "weighted_bce": + self.criterion = nn.BCEWithLogitsLoss() + self.weighted_loss = True # used to add weights during train time. + else: + raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')") + + self.transformer_name = transformer_name + self.lr = lr + + def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> torch.optim: + """ + Sets learning rates for each layer of the model. + Currently, the model has the transformer layer and the MLP layer, so these are tweakable. + + Returns (torch.optim): An Adam optimizer with the learning rates adjusted per layer. 
+ + Currently unused - could be refactored into the parent class's train method, + or the parent class could call a build_optimizer and this subclass would use the optimizer + """ + transformer_params, mlp_params = [], [] + for name, param in self.model.named_parameters(): + if 'transformer' in name: + transformer_params.append(param) + elif 'mlp' in name: + mlp_params.append(param) + optimizer = optim.Adam([ + {"params": transformer_params, "lr": transformer_lr}, + {"params": mlp_params, "lr": mlp_lr} + ]) + return optimizer + + def build_model(self, label_decoder, upos_to_id, known_words): + return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder) + + +def main(args=None): + parser = argparse.ArgumentParser() + + parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "saved_models", "big_model_roberta_weighted_loss.pt"), help="Path to model save file") + parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs") + parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_train.txt"), help="Full path to training file") + parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')") + parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta") + parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 'ce' or 'weighted_bce')") + parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch") + parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves") + parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the optimizer.") + + args = parser.parse_args(args) + + save_name = args.save_name + num_epochs = args.num_epochs + train_file = args.train_file + loss_fn = args.loss_fn + eval_file = args.eval_file + lr = args.lr + + args = vars(args) + + if args['model_type'] == 'bert': + args['bert_model'] = 'bert-base-uncased' + elif args['model_type'] == 'roberta': + args['bert_model'] = 'roberta-base' + elif args['model_type'] == 'transformer': + if args['bert_model'] is None: + raise ValueError("Need to specify a bert_model for model_type transformer!") + else: + raise ValueError("Unknown model type " + args['model_type']) + + if os.path.exists(save_name): + raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...") + if not os.path.exists(train_file): + raise FileNotFoundError(f"Training file {train_file} not found. 
Try again with a valid path.") + + logger.info("Running training script with the following args:") + for arg in args: + logger.info(f"{arg}: {args[arg]}") + logger.info("------------------------------------------------------------") + + trainer = TransformerBaselineTrainer(model_args=args, transformer_name=args['bert_model'], loss_func=loss_fn, lr=lr) + + trainer.train(num_epochs=num_epochs, save_name=save_name, train_file=train_file, args=args, eval_file=eval_file) + return trainer + +if __name__ == "__main__": + main() diff --git a/stanza/models/lemma_classifier/transformer_model.py b/stanza/models/lemma_classifier/transformer_model.py new file mode 100644 index 0000000000..5f32151191 --- /dev/null +++ b/stanza/models/lemma_classifier/transformer_model.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +import os +import sys +import logging + +from transformers import AutoTokenizer, AutoModel +from typing import Mapping, List, Tuple, Any +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence +from stanza.models.common.bert_embedding import extract_bert_embeddings +from stanza.models.lemma_classifier.base_model import LemmaClassifier +from stanza.models.lemma_classifier.constants import ModelType + +logger = logging.getLogger('stanza.lemmaclassifier') + +class LemmaClassifierWithTransformer(LemmaClassifier): + def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping): + """ + Model architecture: + + Use a transformer (BERT or RoBERTa) to extract contextual embedding over a sentence. + Get the embedding for the word that is to be classified on, and feed the embedding + as input to an MLP classifier that has 2 linear layers, and a prediction head. + + Args: + model_args (dict): args for the model + output_dim (int): Dimension of the output from the MLP + transformer_name (str): name of the HF transformer to use + label_decoder (dict): a map of the labels available to the model + """ + super(LemmaClassifierWithTransformer, self).__init__(label_decoder) + self.model_args = model_args + + # Choose transformer + self.transformer_name = transformer_name + self.tokenizer = AutoTokenizer.from_pretrained(transformer_name, use_fast=True, add_prefix_space=True) + self.add_unsaved_module("transformer", AutoModel.from_pretrained(transformer_name)) + config = self.transformer.config + + embedding_size = config.hidden_size + + # define an MLP layer + self.mlp = nn.Sequential( + nn.Linear(embedding_size, 64), + nn.ReLU(), + nn.Linear(64, output_dim) + ) + + def get_save_dict(self): + save_dict = { + "params": self.state_dict(), + "label_decoder": self.label_decoder, + "model_type": self.model_type(), + "args": self.model_args, + } + skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)] + for k in skipped: + del save_dict["params"][k] + return save_dict + + def forward(self, idx_positions: List[int], sentences: List[List[str]], upos_tags: List[List[int]]): + """ + Computes the forward pass of the transformer baselines + + Args: + idx_positions (List[int]): A list of the position index of the target token for lemmatization classification in each sentence. + sentences (List[List[str]]): A list of the token-split sentences of the input data. 
+ upos_tags (List[List[int]]): A list of the upos tags for each token in every sentence - not used in this model, here for compatibility + + Returns: + torch.tensor: Output logits of the neural network, where the shape is (n, output_size) where n is the number of sentences. + """ + device = next(self.transformer.parameters()).device + bert_embeddings = extract_bert_embeddings(self.transformer_name, self.tokenizer, self.transformer, sentences, device, + keep_endpoints=False, num_layers=1, detach=True) + embeddings = [emb[idx] for idx, emb in zip(idx_positions, bert_embeddings)] + embeddings = torch.stack(embeddings, dim=0)[:, :, 0] + # pass to the MLP + output = self.mlp(embeddings) + return output + + def model_type(self): + return ModelType.TRANSFORMER diff --git a/stanza/models/lemma_classifier/utils.py b/stanza/models/lemma_classifier/utils.py new file mode 100644 index 0000000000..fa8ec3b216 --- /dev/null +++ b/stanza/models/lemma_classifier/utils.py @@ -0,0 +1,167 @@ +from collections import Counter, defaultdict +import logging +import os +import random +from typing import List, Tuple, Any, Mapping + +import stanza +import torch + +from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE +import stanza.models.lemma_classifier.prepare_dataset as prepare_dataset + +logger = logging.getLogger('stanza.lemmaclassifier') + +class Dataset: + def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_counts: bool = False, label_decoder: dict = None, shuffle: bool = True): + """ + Loads a data file into data batches for tokenized text sentences, token indices, and true labels for each sentence. + + Args: + data_path (str): Path to data file, containing tokenized text sentences, token index and true label for token lemma on each line. + batch_size (int): Size of each batch of examples + get_counts (optional, bool): Whether there should be a map of the label index to counts + + Returns: + 1. List[List[List[str]]]: Batches of sentences, where each token is a separate entry in each sentence + 2. List[torch.tensor[int]]: A batch of indexes for the target token corresponding to its sentence + 3. List[torch.tensor[int]]: A batch of labels for the target token's lemma + 4. List[List[int]]: A batch of UPOS IDs for the target token (this is a List of Lists, not a tensor. It should be padded later.) + 5 (Optional): A mapping of label ID to counts in the dataset. + 6. Mapping[str, int]: A map between the labels and their indexes + 7. 
Mapping[str, int]: A map between the UPOS tags and their corresponding IDs found in the UPOS batches + """ + + if data_path is None or not os.path.exists(data_path): + raise FileNotFoundError(f"Data file {data_path} could not be found.") + + if label_decoder is None: + label_decoder = {} + else: + # if labels in the test set aren't in the original model, + # the model will never predict those labels, + # but we can still use those labels in a confusion matrix + label_decoder = dict(label_decoder) + + logger.debug("Final label decoder: %s Should be strings to ints", label_decoder) + + known_words = set() + + with open(data_path, "r+", encoding="utf-8") as f: + sentences, indices, labels, upos_ids, counts, upos_to_id = [], [], [], [], Counter(), defaultdict(str) + + data_processor = prepare_dataset.DataProcessor("", [], "") + sentences_data = data_processor.read_processed_data(data_path) + + for idx, sentence in enumerate(sentences_data): + # TODO Could replace this with sentence.values(), but need to know if Stanza requires Python 3.7 or later for backward compatability reasons + words, target_idx, upos_tags, label = sentence.get("words"), sentence.get("index"), sentence.get("upos_tags"), sentence.get("lemma") + if None in [words, target_idx, upos_tags, label]: + raise ValueError(f"Expected data to be complete but found a null value in sentence {idx}: {sentence}") + + label_id = label_decoder.get(label, None) + if label_id is None: + label_decoder[label] = len(label_decoder) # create a new ID for the unknown label + + converted_upos_tags = [] # convert upos tags to upos IDs + for upos_tag in upos_tags: + upos_id = upos_to_id.get(upos_tag, None) + if upos_id is None: + upos_to_id[upos_tag] = len(upos_to_id) # create a new ID for the unknown UPOS tag + converted_upos_tags.append(upos_to_id[upos_tag]) + + sentences.append(words) + indices.append(target_idx) + upos_ids.append(converted_upos_tags) + labels.append(label_decoder[label]) + + if get_counts: + counts[label_decoder[label]] += 1 + + known_words.update(words) + + self.sentences = sentences + self.indices = indices + self.upos_ids = upos_ids + self.labels = labels + + self.counts = counts + self.label_decoder = label_decoder + self.upos_to_id = upos_to_id + + self.batch_size = batch_size + self.shuffle = shuffle + + self.known_words = [x.lower() for x in sorted(known_words)] + + def __len__(self): + """ + Number of batches, rounded up to nearest batch + """ + return len(self.sentences) // self.batch_size + (len(self.sentences) % self.batch_size > 0) + + def __iter__(self): + num_sentences = len(self.sentences) + indices = list(range(num_sentences)) + if self.shuffle: + random.shuffle(indices) + for i in range(self.__len__()): + batch_start = self.batch_size * i + batch_end = min(batch_start + self.batch_size, num_sentences) + + batch_sentences = [self.sentences[x] for x in indices[batch_start:batch_end]] + batch_indices = torch.tensor([self.indices[x] for x in indices[batch_start:batch_end]]) + batch_upos_ids = [self.upos_ids[x] for x in indices[batch_start:batch_end]] + batch_labels = torch.tensor([self.labels[x] for x in indices[batch_start:batch_end]]) + yield batch_sentences, batch_indices, batch_upos_ids, batch_labels + +def extract_unknown_token_indices(tokenized_indices: torch.tensor, unknown_token_idx: int) -> List[int]: + """ + Extracts the indices within `tokenized_indices` which match `unknown_token_idx` + + Args: + tokenized_indices (torch.tensor): A tensor filled with tokenized indices of words that have been mapped to 
vector indices.
+        unknown_token_idx (int): The special index with which unknown tokens are marked in the word vectors.
+
+    Returns:
+        List[int]: A list of indices in `tokenized_indices` which match `unknown_token_idx`
+    """
+    return [idx for idx, token_index in enumerate(tokenized_indices) if token_index == unknown_token_idx]
+
+
+def get_device():
+    """
+    Get the device to run computations on
+    """
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+
+    return device
+
+
+def round_up_to_multiple(number, multiple):
+    if multiple == 0:
+        raise ValueError("The second number (multiple) cannot be zero.")
+
+    # Calculate the remainder when dividing the number by the multiple
+    remainder = number % multiple
+
+    # If remainder is non-zero, round up to the next multiple
+    if remainder != 0:
+        rounded_number = number + (multiple - remainder)
+    else:
+        rounded_number = number  # No rounding needed
+
+    return rounded_number
+
+
+def main():
+    default_test_path = os.path.join(os.path.dirname(__file__), "test_sets", "processed_ud_en", "combined_dev.txt") # get the GUM stuff
+    dataset = Dataset(default_test_path, get_counts=True)
+
+if __name__ == "__main__":
+    main()
From d61096033419e366a5bab35cfda64b72b7c4d1bf Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sun, 15 Sep 2024 20:26:56 -0700 Subject: [PATCH 02/10] Wrapper file for training the lemma classifier. Will download missing charlms if they exist run_lemma_classifier.py now automatically tries to pick a save name and training filename appropriate for the dataset being trained. Still need to calculate the lemmas to predict and use a language-appropriate wordvec file before we can do other languages, though Add the ability to use run_lemma_classifier.py in --score_dev mode Add --score_test to the lemma_classifier as well Connects the transformer baseline to the run_lemma_classifier script Reports the dev & test scores when running in TRAIN mode --- stanza/utils/training/run_lemma_classifier.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 stanza/utils/training/run_lemma_classifier.py diff --git a/stanza/utils/training/run_lemma_classifier.py b/stanza/utils/training/run_lemma_classifier.py new file mode 100644 index 0000000000..5f68f07048 --- /dev/null +++ b/stanza/utils/training/run_lemma_classifier.py @@ -0,0 +1,83 @@ +import os
+
+from stanza.models.lemma_classifier import evaluate_models
+from stanza.models.lemma_classifier import train_lstm_model
+from stanza.models.lemma_classifier import train_transformer_model
+from stanza.models.lemma_classifier.constants import ModelType
+
+from stanza.resources.default_packages import default_pretrains, TRANSFORMERS
+from stanza.utils.training import common
+from stanza.utils.training.common import Mode, add_charlm_args, build_lemma_charlm_args, choose_lemma_charlm, find_wordvec_pretrain
+
+def add_lemma_args(parser):
+    add_charlm_args(parser)
+
+    parser.add_argument('--model_type', default=ModelType.LSTM, type=lambda x: ModelType[x.upper()],
+                        help='Model type to use. 
{}'.format(", ".join(x.name for x in ModelType))) + +def build_model_filename(paths, short_name, command_args, extra_args): + return os.path.join("saved_models", "lemma_classifier", short_name + "_lemma_classifier.pt") + +def run_treebank(mode, paths, treebank, short_name, + temp_output_file, command_args, extra_args): + short_language, dataset = short_name.split("_", 1) + + base_args = [] + if '--save_name' not in extra_args: + base_args += ['--save_name', build_model_filename(paths, short_name, command_args, extra_args)] + + embedding_args = build_lemma_charlm_args(short_language, dataset, command_args.charlm) + if '--wordvec_pretrain_file' not in extra_args: + wordvec_pretrain = find_wordvec_pretrain(short_language, default_pretrains, {}, dataset) + embedding_args += ["--wordvec_pretrain_file", wordvec_pretrain] + + bert_args = [] + if command_args.model_type is ModelType.TRANSFORMER: + if '--bert_model' not in extra_args: + if short_language in TRANSFORMERS: + bert_args = ['--bert_model', TRANSFORMERS.get(short_language)] + else: + raise ValueError("--bert_model not specified, so cannot figure out which transformer to use for language %s" % short_language) + + if mode == Mode.TRAIN: + train_args = [] + if "--train_file" not in extra_args: + train_file = os.path.join("data", "lemma_classifier", "%s.train.lemma" % short_name) + train_args += ['--train_file', train_file] + if "--eval_file" not in extra_args: + eval_file = os.path.join("data", "lemma_classifier", "%s.dev.lemma" % short_name) + train_args += ['--eval_file', eval_file] + train_args = base_args + train_args + extra_args + + if command_args.model_type == ModelType.LSTM: + train_args = embedding_args + train_args + train_lstm_model.main(train_args) + else: + model_type_args = ["--model_type", command_args.model_type.name.lower()] + train_args = bert_args + model_type_args + train_args + train_transformer_model.main(train_args) + + if mode == Mode.SCORE_DEV or mode == Mode.TRAIN: + eval_args = [] + if "--eval_file" not in extra_args: + eval_file = os.path.join("data", "lemma_classifier", "%s.dev.lemma" % short_name) + eval_args += ['--eval_file', eval_file] + model_type_args = ["--model_type", command_args.model_type.name.lower()] + eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args + evaluate_models.main(eval_args) + + if mode == Mode.SCORE_TEST or mode == Mode.TRAIN: + eval_args = [] + if "--eval_file" not in extra_args: + eval_file = os.path.join("data", "lemma_classifier", "%s.test.lemma" % short_name) + eval_args += ['--eval_file', eval_file] + model_type_args = ["--model_type", command_args.model_type.name.lower()] + eval_args = bert_args + model_type_args + base_args + eval_args + embedding_args + extra_args + evaluate_models.main(eval_args) + +def main(): + common.main(run_treebank, "lemma_classifier", "lemma_classifier", add_lemma_args, sub_argparse=train_lstm_model.build_argparse(), build_model_filename=build_model_filename, choose_charlm_method=choose_lemma_charlm) + + +if __name__ == '__main__': + main() From ccdea16384fc97c4c5f876f704078a609c4e78ef Mon Sep 17 00:00:00 2001 From: John Bauer Date: Mon, 18 Dec 2023 01:03:49 -0800 Subject: [PATCH 03/10] Add a script to convert the various datasets to a lemma classifier dataset fa_perdt, ja_gsd, AR, HI as current options for the lemma classifier --- .../datasets/prepare_lemma_classifier.py | 90 +++++++++++++++++++ stanza/utils/default_paths.py | 1 + 2 files changed, 91 insertions(+) create mode 100644 
stanza/utils/datasets/prepare_lemma_classifier.py diff --git a/stanza/utils/datasets/prepare_lemma_classifier.py b/stanza/utils/datasets/prepare_lemma_classifier.py new file mode 100644 index 0000000000..630091dd21 --- /dev/null +++ b/stanza/utils/datasets/prepare_lemma_classifier.py @@ -0,0 +1,90 @@ +import os +import sys + +from stanza.utils.datasets.common import find_treebank_dataset_file, UnknownDatasetError +from stanza.utils.default_paths import get_default_paths +from stanza.models.lemma_classifier import prepare_dataset +from stanza.models.common.short_name_to_treebank import short_name_to_treebank + +SECTIONS = ("train", "dev", "test") + +def process_treebank(paths, short_name, word, upos, allowed_lemmas, sections=SECTIONS): + treebank = short_name_to_treebank(short_name) + udbase_dir = paths["UDBASE"] + + output_dir = paths["LEMMA_CLASSIFIER_DATA_DIR"] + os.makedirs(output_dir, exist_ok=True) + + output_filenames = [] + + for section in sections: + filename = find_treebank_dataset_file(treebank, udbase_dir, section, "conllu", fail=True) + output_filename = os.path.join(output_dir, "%s.%s.lemma" % (short_name, section)) + args = ["--conll_path", filename, + "--target_word", word, + "--target_upos", upos, + "--output_path", output_filename] + if allowed_lemmas is not None: + args.extend(["--allowed_lemmas", allowed_lemmas]) + prepare_dataset.main(args) + output_filenames.append(output_filename) + + return output_filenames + +def process_ja_gsd(paths, short_name): + # this one looked promising, but only has 10 total dev & test cases + # 行っ VERB Counter({'行う': 60, '行く': 38}) + # could possibly do + # ない AUX Counter({'ない': 383, '無い': 99}) + # なく AUX Counter({'無い': 53, 'ない': 42}) + # currently this one has enough in the dev & test data + # and functions well + # だ AUX Counter({'だ': 237, 'た': 67}) + word = "だ" + upos = "AUX" + allowed_lemmas = None + + process_treebank(paths, short_name, word, upos, allowed_lemmas) + +def process_fa_perdt(paths, short_name): + word = "شد" + upos = "VERB" + allowed_lemmas = "کرد|شد" + + process_treebank(paths, short_name, word, upos, allowed_lemmas) + +def process_hi_hdtb(paths, short_name): + word = "के" + upos = "ADP" + allowed_lemmas = "का|के" + + process_treebank(paths, short_name, word, upos, allowed_lemmas) + +def process_ar_padt(paths, short_name): + word = "أن" + upos = "SCONJ" + allowed_lemmas = "أَن|أَنَّ" + + process_treebank(paths, short_name, word, upos, allowed_lemmas) + +DATASET_MAPPING = { + "ar_padt": process_ar_padt, + "fa_perdt": process_fa_perdt, + "hi_hdtb": process_hi_hdtb, + "ja_gsd": process_ja_gsd, +} + + +def main(dataset_name): + paths = get_default_paths() + print("Processing %s" % dataset_name) + + # obviously will want to multiplex to multiple languages / datasets + if dataset_name in DATASET_MAPPING: + DATASET_MAPPING[dataset_name](paths, dataset_name) + else: + raise UnknownDatasetError(dataset_name, f"dataset {dataset_name} currently not handled by prepare_ner_dataset") + print("Done processing %s" % dataset_name) + +if __name__ == '__main__': + main(sys.argv[1]) diff --git a/stanza/utils/default_paths.py b/stanza/utils/default_paths.py index 0618d8b634..ef87cc14f6 100644 --- a/stanza/utils/default_paths.py +++ b/stanza/utils/default_paths.py @@ -21,6 +21,7 @@ def get_default_paths(): "SENTIMENT_DATA_DIR": DATA_ROOT + "/sentiment", "CONSTITUENCY_DATA_DIR": DATA_ROOT + "/constituency", "COREF_DATA_DIR": DATA_ROOT + "/coref", + "LEMMA_CLASSIFIER_DATA_DIR": DATA_ROOT + "/lemma_classifier", # Set directories to store 
external word vector data "WORDVEC_DIR": "extern_data/wordvec", From 356f76b0426205cefef7686ac5e3f1e6dc3a4303 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sun, 15 Sep 2024 21:07:22 -0700 Subject: [PATCH 04/10] Add Greek as an option to the lemma_classifier data preparation This requires using a target regex instead of target word to make it simpler to match multiple words at once in the data preparation code --- .../models/lemma_classifier/prepare_dataset.py | 8 ++++---- .../utils/datasets/prepare_lemma_classifier.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/stanza/models/lemma_classifier/prepare_dataset.py b/stanza/models/lemma_classifier/prepare_dataset.py index 02bdd29f86..d9855732a6 100644 --- a/stanza/models/lemma_classifier/prepare_dataset.py +++ b/stanza/models/lemma_classifier/prepare_dataset.py @@ -33,7 +33,7 @@ def find_all_occurrences(self, sentence) -> List[int]: """ occurrences = [] for idx, token in enumerate(sentence.words): - if token.text == self.target_word and token.upos in self.target_upos: + if self.target_word.fullmatch(token.text) and token.upos in self.target_upos: occurrences.append(idx) return occurrences @@ -112,7 +112,7 @@ def main(args=None): args = parser.parse_args(args) conll_path = args.conll_path - target_word = args.target_word + target_word = re.compile(args.target_word) target_upos = args.target_upos output_path = args.output_path allowed_lemmas = args.allowed_lemmas @@ -126,8 +126,8 @@ def main(args=None): def keep_sentence(sentence): for word in sentence.words: - if word.text == target_word and word.upos == target_upos: - return True + if target_word.fullmatch(word.text) and word.upos == target_upos: + return True return False processor.process_document(doc, keep_sentence, output_path) diff --git a/stanza/utils/datasets/prepare_lemma_classifier.py b/stanza/utils/datasets/prepare_lemma_classifier.py index 630091dd21..b18692ed6b 100644 --- a/stanza/utils/datasets/prepare_lemma_classifier.py +++ b/stanza/utils/datasets/prepare_lemma_classifier.py @@ -67,8 +67,26 @@ def process_ar_padt(paths, short_name): process_treebank(paths, short_name, word, upos, allowed_lemmas) +def process_el_gdt(paths, short_name): + """ + All of the Greek lemmas for these words are εγώ or μου + + τους PRON Counter({'μου': 118, 'εγώ': 32}) + μας PRON Counter({'μου': 89, 'εγώ': 32}) + του PRON Counter({'μου': 82, 'εγώ': 8}) + της PRON Counter({'μου': 80, 'εγώ': 2}) + σας PRON Counter({'μου': 34, 'εγώ': 24}) + μου PRON Counter({'μου': 45, 'εγώ': 10}) + """ + word = "τους|μας|του|της|σας|μου" + upos = "PRON" + allowed_lemmas = None + + process_treebank(paths, short_name, word, upos, allowed_lemmas) + DATASET_MAPPING = { "ar_padt": process_ar_padt, + "el_gdt": process_el_gdt, "fa_perdt": process_fa_perdt, "hi_hdtb": process_hi_hdtb, "ja_gsd": process_ja_gsd, From 4412d5a18ca71eae4b65216554972d7b0379fb68 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Sun, 15 Sep 2024 21:17:28 -0700 Subject: [PATCH 05/10] Add a test of the lemma_classifier data preparation code Add a sample 9/2/2 dataset and test that it gets read in a way we might like --- stanza/tests/lemma_classifier/__init__.py | 0 .../lemma_classifier/test_data_preparation.py | 256 ++++++++++++++++++ 2 files changed, 256 insertions(+) create mode 100644 stanza/tests/lemma_classifier/__init__.py create mode 100644 stanza/tests/lemma_classifier/test_data_preparation.py diff --git a/stanza/tests/lemma_classifier/__init__.py b/stanza/tests/lemma_classifier/__init__.py new file mode 
100644 index 0000000000..e69de29bb2 diff --git a/stanza/tests/lemma_classifier/test_data_preparation.py b/stanza/tests/lemma_classifier/test_data_preparation.py new file mode 100644 index 0000000000..93efa8aaff --- /dev/null +++ b/stanza/tests/lemma_classifier/test_data_preparation.py @@ -0,0 +1,256 @@ +import os + +import pytest + +import stanza.models.lemma_classifier.utils as utils +import stanza.utils.datasets.prepare_lemma_classifier as prepare_lemma_classifier + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +EWT_ONE_SENTENCE = """ +# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002 +# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002 +# text = Here's a Miami Herald interview +1-2 Here's _ _ _ _ _ _ _ _ +1 Here here ADV RB PronType=Dem 0 root 0:root _ +2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 1 cop 1:cop _ +3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _ +4 Miami Miami PROPN NNP Number=Sing 5 compound 5:compound _ +5 Herald Herald PROPN NNP Number=Sing 6 compound 6:compound _ +6 interview interview NOUN NN Number=Sing 1 nsubj 1:nsubj _ +""".lstrip() + + +EWT_TRAIN_SENTENCES = """ +# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0002 +# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0002 +# text = Here's a Miami Herald interview +1-2 Here's _ _ _ _ _ _ _ _ +1 Here here ADV RB PronType=Dem 0 root 0:root _ +2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 1 cop 1:cop _ +3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _ +4 Miami Miami PROPN NNP Number=Sing 5 compound 5:compound _ +5 Herald Herald PROPN NNP Number=Sing 6 compound 6:compound _ +6 interview interview NOUN NN Number=Sing 1 nsubj 1:nsubj _ + +# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0027 +# text = But Posada's nearly 80 years old +1 But but CCONJ CC _ 7 cc 7:cc _ +2-3 Posada's _ _ _ _ _ _ _ _ +2 Posada Posada PROPN NNP Number=Sing 7 nsubj 7:nsubj _ +3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 cop 7:cop _ +4 nearly nearly ADV RB _ 5 advmod 5:advmod _ +5 80 80 NUM CD NumForm=Digit|NumType=Card 6 nummod 6:nummod _ +6 years year NOUN NNS Number=Plur 7 obl:npmod 7:obl:npmod _ +7 old old ADJ JJ Degree=Pos 0 root 0:root SpaceAfter=No + +# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0067 +# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0011 +# text = Now that's a post I can relate to. +1 Now now ADV RB _ 5 advmod 5:advmod _ +2-3 that's _ _ _ _ _ _ _ _ +2 that that PRON DT Number=Sing|PronType=Dem 5 nsubj 5:nsubj _ +3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _ +4 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _ +5 post post NOUN NN Number=Sing 0 root 0:root _ +6 I I PRON PRP Case=Nom|Number=Sing|Person=1|PronType=Prs 8 nsubj 8:nsubj _ +7 can can AUX MD VerbForm=Fin 8 aux 8:aux _ +8 relate relate VERB VB VerbForm=Inf 5 acl:relcl 5:acl:relcl _ +9 to to ADP IN _ 8 obl 8:obl SpaceAfter=No +10 . . PUNCT . 
_ 5 punct 5:punct _ + +# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0073 +# newpar id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-p0012 +# text = hey that's a great blog +1 hey hey INTJ UH _ 6 discourse 6:discourse _ +2-3 that's _ _ _ _ _ _ _ _ +2 that that PRON DT Number=Sing|PronType=Dem 6 nsubj 6:nsubj _ +3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _ +4 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _ +5 great great ADJ JJ Degree=Pos 6 amod 6:amod _ +6 blog blog NOUN NN Number=Sing 0 root 0:root SpaceAfter=No + +# sent_id = weblog-blogspot.com_rigorousintuition_20050518101500_ENG_20050518_101500-0089 +# text = And It's Not Hard To Do +1 And and CCONJ CC _ 5 cc 5:cc _ +2-3 It's _ _ _ _ _ _ _ _ +2 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 5 expl 5:expl _ +3 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _ +4 Not not PART RB _ 5 advmod 5:advmod _ +5 Hard hard ADJ JJ Degree=Pos 0 root 0:root _ +6 To to PART TO _ 7 mark 7:mark _ +7 Do do VERB VB VerbForm=Inf 5 csubj 5:csubj SpaceAfter=No + +# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0029 +# text = Meanwhile, a decision's been reached +1 Meanwhile meanwhile ADV RB _ 7 advmod 7:advmod SpaceAfter=No +2 , , PUNCT , _ 1 punct 1:punct _ +3 a a DET DT Definite=Ind|PronType=Art 4 det 4:det _ +4-5 decision's _ _ _ _ _ _ _ _ +4 decision decision NOUN NN Number=Sing 7 nsubj:pass 7:nsubj:pass _ +5 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 7 aux 7:aux _ +6 been be AUX VBN Tense=Past|VerbForm=Part 7 aux:pass 7:aux:pass _ +7 reached reach VERB VBN Tense=Past|VerbForm=Part|Voice=Pass 0 root 0:root _ + +# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0138 +# text = It's become a guardian of morality +1-2 It's _ _ _ _ _ _ _ _ +1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 3 nsubj 3:nsubj|5:nsubj:xsubj _ +2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _ +3 become become VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _ +4 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _ +5 guardian guardian NOUN NN Number=Sing 3 xcomp 3:xcomp _ +6 of of ADP IN _ 7 case 7:case _ +7 morality morality NOUN NN Number=Sing 5 nmod 5:nmod:of _ + +# sent_id = email-enronsent15_01-0018 +# text = It's got its own bathroom and tv +1-2 It's _ _ _ _ _ _ _ _ +1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 3 nsubj 3:nsubj|13:nsubj _ +2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 3 aux 3:aux _ +3 got get VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _ +4 its its PRON PRP$ Case=Gen|Gender=Neut|Number=Sing|Person=3|Poss=Yes|PronType=Prs 6 nmod:poss 6:nmod:poss _ +5 own own ADJ JJ Degree=Pos 6 amod 6:amod _ +6 bathroom bathroom NOUN NN Number=Sing 3 obj 3:obj _ +7 and and CCONJ CC _ 8 cc 8:cc _ +8 tv TV NOUN NN Number=Sing 6 conj 3:obj|6:conj:and SpaceAfter=No + +# sent_id = newsgroup-groups.google.com_alt.animals.cat_01ff709c4bf2c60c_ENG_20040418_040100-0022 +# text = It's also got the website +1-2 It's _ _ _ _ _ _ _ _ +1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 4 nsubj 4:nsubj _ +2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 aux 4:aux _ +3 also also ADV RB _ 4 advmod 4:advmod _ +4 got get VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _ +5 the the 
DET DT Definite=Def|PronType=Art 6 det 6:det _ +6 website website NOUN NN Number=Sing 4 obj 4:obj|12:obl _ +""".lstrip() + + +# from the train set, actually +EWT_DEV_SENTENCES = """ +# sent_id = answers-20111108104724AAuBUR7_ans-0044 +# text = He's only exhibited weight loss and some muscle atrophy +1-2 He's _ _ _ _ _ _ _ _ +1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 4 nsubj 4:nsubj _ +2 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 4 aux 4:aux _ +3 only only ADV RB _ 4 advmod 4:advmod _ +4 exhibited exhibit VERB VBN Tense=Past|VerbForm=Part 0 root 0:root _ +5 weight weight NOUN NN Number=Sing 6 compound 6:compound _ +6 loss loss NOUN NN Number=Sing 4 obj 4:obj _ +7 and and CCONJ CC _ 10 cc 10:cc _ +8 some some DET DT PronType=Ind 10 det 10:det _ +9 muscle muscle NOUN NN Number=Sing 10 compound 10:compound _ +10 atrophy atrophy NOUN NN Number=Sing 6 conj 4:obj|6:conj:and SpaceAfter=No + +# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0097 +# text = It's a good thing too. +1-2 It's _ _ _ _ _ _ _ _ +1 It it PRON PRP Case=Nom|Gender=Neut|Number=Sing|Person=3|PronType=Prs 5 nsubj 5:nsubj _ +2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 cop 5:cop _ +3 a a DET DT Definite=Ind|PronType=Art 5 det 5:det _ +4 good good ADJ JJ Degree=Pos 5 amod 5:amod _ +5 thing thing NOUN NN Number=Sing 0 root 0:root _ +6 too too ADV RB _ 5 advmod 5:advmod SpaceAfter=No +7 . . PUNCT . _ 5 punct 5:punct _ +""".lstrip() + +# from the train set, actually +EWT_TEST_SENTENCES = """ +# sent_id = reviews-162422-0015 +# text = He said he's had a long and bad day. +1 He he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 2 nsubj 2:nsubj _ +2 said say VERB VBD Mood=Ind|Number=Sing|Person=3|Tense=Past|VerbForm=Fin 0 root 0:root _ +3-4 he's _ _ _ _ _ _ _ _ +3 he he PRON PRP Case=Nom|Gender=Masc|Number=Sing|Person=3|PronType=Prs 5 nsubj 5:nsubj _ +4 's have AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 5 aux 5:aux _ +5 had have VERB VBN Tense=Past|VerbForm=Part 2 ccomp 2:ccomp _ +6 a a DET DT Definite=Ind|PronType=Art 10 det 10:det _ +7 long long ADJ JJ Degree=Pos 10 amod 10:amod _ +8 and and CCONJ CC _ 9 cc 9:cc _ +9 bad bad ADJ JJ Degree=Pos 7 conj 7:conj:and|10:amod _ +10 day day NOUN NN Number=Sing 5 obj 5:obj SpaceAfter=No +11 . . PUNCT . 
_ 2 punct 2:punct _ + +# sent_id = weblog-blogspot.com_rigorousintuition_20060511134300_ENG_20060511_134300-0100 +# text = What's a few dead soldiers +1-2 What's _ _ _ _ _ _ _ _ +1 What what PRON WP PronType=Int 6 nsubj 6:nsubj _ +2 's be AUX VBZ Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin 6 cop 6:cop _ +3 a a DET DT Definite=Ind|PronType=Art 6 det 6:det _ +4 few few ADJ JJ Degree=Pos 6 amod 6:amod _ +5 dead dead ADJ JJ Degree=Pos 6 amod 6:amod _ +6 soldiers soldier NOUN NNS Number=Plur 0 root 0:root _ +""" + +def write_test_dataset(tmp_path, texts, datasets): + ud_path = tmp_path / "ud" + input_path = ud_path / "UD_English-EWT" + output_path = tmp_path / "data" / "lemma_classifier" + + os.makedirs(input_path, exist_ok=True) + + for text, dataset in zip(texts, datasets): + sample_file = input_path / ("en_ewt-ud-%s.conllu" % dataset) + with open(sample_file, "w", encoding="utf-8") as fout: + fout.write(text) + + paths = {"UDBASE": ud_path, + "LEMMA_CLASSIFIER_DATA_DIR": output_path} + + return paths + +def write_english_test_dataset(tmp_path): + texts = (EWT_TRAIN_SENTENCES, EWT_DEV_SENTENCES, EWT_TEST_SENTENCES) + datasets = prepare_lemma_classifier.SECTIONS + return write_test_dataset(tmp_path, texts, datasets) + +def convert_english_dataset(tmp_path): + paths = write_english_test_dataset(tmp_path) + converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have") + assert len(converted_files) == 3 + + return converted_files + +def test_convert_one_sentence(tmp_path): + texts = [EWT_ONE_SENTENCE] + datasets = ["train"] + paths = write_test_dataset(tmp_path, texts, datasets) + + converted_files = prepare_lemma_classifier.process_treebank(paths, "en_ewt", "'s", "AUX", "be|have", ["train"]) + assert len(converted_files) == 1 + + dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False) + + assert len(dataset) == 1 + assert dataset.label_decoder == {'be': 0} + id_to_upos = {y: x for x, y in dataset.upos_to_id.items()} + + for text_batches, _, upos_batches, _ in dataset: + assert text_batches == [['Here', "'s", 'a', 'Miami', 'Herald', 'interview']] + upos = [id_to_upos[x] for x in upos_batches[0]] + assert upos == ['ADV', 'AUX', 'DET', 'PROPN', 'PROPN', 'NOUN'] + +def test_convert_dataset(tmp_path): + converted_files = convert_english_dataset(tmp_path) + + dataset = utils.Dataset(converted_files[0], get_counts=True, batch_size=10, shuffle=False) + + assert len(dataset) == 1 + label_decoder = dataset.label_decoder + assert len(label_decoder) == 2 + assert "be" in label_decoder + assert "have" in label_decoder + for text_batches, _, _, _ in dataset: + assert len(text_batches) == 9 + + dataset = utils.Dataset(converted_files[1], get_counts=True, batch_size=10, shuffle=False) + assert len(dataset) == 1 + for text_batches, _, _, _ in dataset: + assert len(text_batches) == 2 + + dataset = utils.Dataset(converted_files[2], get_counts=True, batch_size=10, shuffle=False) + assert len(dataset) == 1 + for text_batches, _, _, _ in dataset: + assert len(text_batches) == 2 + From c01a009c4b52a2ec9e1af56b7d10895b5914b191 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 10 Jan 2024 23:45:11 -0800 Subject: [PATCH 06/10] Add a test which iterates the LSTM and transformer versions of the LemmaClassifier model Call evaluate_model just in case, although the expectation is that the F1 isn't going to be great --- .../tests/lemma_classifier/test_training.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 
stanza/tests/lemma_classifier/test_training.py diff --git a/stanza/tests/lemma_classifier/test_training.py b/stanza/tests/lemma_classifier/test_training.py new file mode 100644 index 0000000000..c12a3e107c --- /dev/null +++ b/stanza/tests/lemma_classifier/test_training.py @@ -0,0 +1,53 @@ +import glob +import os + +import pytest + +pytestmark = [pytest.mark.pipeline, pytest.mark.travis] + +from stanza.models.lemma_classifier import train_lstm_model +from stanza.models.lemma_classifier import train_transformer_model +from stanza.models.lemma_classifier.base_model import LemmaClassifier +from stanza.models.lemma_classifier.evaluate_models import evaluate_model + +from stanza.tests import TEST_WORKING_DIR +from stanza.tests.lemma_classifier.test_data_preparation import convert_english_dataset + +@pytest.fixture(scope="module") +def pretrain_file(): + return f'{TEST_WORKING_DIR}/in/tiny_emb.pt' + +def test_train_lstm(tmp_path, pretrain_file): + converted_files = convert_english_dataset(tmp_path) + + save_name = str(tmp_path / 'lemma.pt') + + train_file = converted_files[0] + eval_file = converted_files[1] + train_args = ['--wordvec_pretrain_file', pretrain_file, + '--save_name', save_name, + '--train_file', train_file, + '--eval_file', eval_file] + trainer = train_lstm_model.main(train_args) + + evaluate_model(trainer.model, eval_file) + # test that loading the model works + model = LemmaClassifier.load(save_name, None) + +def test_train_transformer(tmp_path, pretrain_file): + converted_files = convert_english_dataset(tmp_path) + + save_name = str(tmp_path / 'lemma.pt') + + train_file = converted_files[0] + eval_file = converted_files[1] + train_args = ['--bert_model', 'hf-internal-testing/tiny-bert', + '--save_name', save_name, + '--train_file', train_file, + '--eval_file', eval_file] + trainer = train_transformer_model.main(train_args) + + evaluate_model(trainer.model, eval_file) + + # test that loading the model works + model = LemmaClassifier.load(save_name, None) From 8573ffbd94ffca4735541d379c0a14494a8b7ca6 Mon Sep 17 00:00:00 2001 From: Alex Shan Date: Thu, 18 Jan 2024 14:26:19 -0800 Subject: [PATCH 07/10] Add utility to train multiple file variants at the same time --- .../models/lemma_classifier/evaluate_many.py | 68 ++++++++ .../lemma_classifier/evaluate_models.py | 44 ++--- .../lemma_classifier/train_lstm_model.py | 4 +- stanza/models/lemma_classifier/train_many.py | 155 ++++++++++++++++++ .../train_transformer_model.py | 4 +- 5 files changed, 250 insertions(+), 25 deletions(-) create mode 100644 stanza/models/lemma_classifier/evaluate_many.py create mode 100644 stanza/models/lemma_classifier/train_many.py diff --git a/stanza/models/lemma_classifier/evaluate_many.py b/stanza/models/lemma_classifier/evaluate_many.py new file mode 100644 index 0000000000..a0ab2c662c --- /dev/null +++ b/stanza/models/lemma_classifier/evaluate_many.py @@ -0,0 +1,68 @@ +""" +Utils to evaluate many models of the same type at once +""" +import argparse +import os +import logging + +from stanza.models.lemma_classifier.evaluate_models import main as evaluate_main + + +logger = logging.getLogger('stanza.lemmaclassifier') + +def evaluate_n_models(path_to_models_dir, args): + + total_results = { + "be": 0.0, + "have": 0.0, + "accuracy": 0.0, + "weighted_f1": 0.0 + } + paths = os.listdir(path_to_models_dir) + num_models = len(paths) + for model_path in paths: + full_path = os.path.join(path_to_models_dir, model_path) + args.save_name = full_path + mcc_results, confusion, acc, weighted_f1 = 
evaluate_main(predefined_args=args) + + for lemma in mcc_results: + + lemma_f1 = mcc_results.get(lemma, None).get("f1") * 100 + total_results[lemma] += lemma_f1 + + total_results["accuracy"] += acc + total_results["weighted_f1"] += weighted_f1 + + total_results["be"] /= num_models + total_results["have"] /= num_models + total_results["accuracy"] /= num_models + total_results["weighted_f1"] /= num_models + + logger.info(f"Models in {path_to_models_dir} had average weighted f1 of {100 * total_results['weighted_f1']}.\nLemma 'be' had f1: {total_results['be']}\nLemma 'have' had f1: {total_results['have']}.\nAccuracy: {100 * total_results['accuracy']}.\n ({num_models} models evaluated).") + return total_results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--vocab_size", type=int, default=10000, help="Number of tokens in vocab") + parser.add_argument("--embedding_dim", type=int, default=100, help="Number of dimensions in word embeddings (currently using GloVe)") + parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer") + parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') + parser.add_argument("--charlm", action='store_true', default=False, help="Whether not to use the charlm embeddings") + parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.") + parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_forward.pt"), help="Path to forward charlm file") + parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file") + parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model.pt"), help="Path to model save file") + parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta' or 'lstm')") + parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta") + parser.add_argument("--eval_file", type=str, help="path to evaluation file") + + # Args specific to several model eval + parser.add_argument("--base_path", type=str, default=None, help="path to dir for eval") + + args = parser.parse_args() + evaluate_n_models(args.base_path, args) + + +if __name__ == "__main__": + main() diff --git a/stanza/models/lemma_classifier/evaluate_models.py b/stanza/models/lemma_classifier/evaluate_models.py index 9c9e4ffa4b..9deb98fdf4 100644 --- a/stanza/models/lemma_classifier/evaluate_models.py +++ b/stanza/models/lemma_classifier/evaluate_models.py @@ -1,14 +1,14 @@ -import os -import sys +import os +import sys parentdir = os.path.dirname(__file__) parentdir = os.path.dirname(parentdir) parentdir = os.path.dirname(parentdir) sys.path.append(parentdir) -import logging +import logging import argparse -import os +import os from typing import Any, List, Tuple, Mapping from collections import defaultdict @@ -47,7 +47,7 @@ def get_weighted_f1(mcc_results: Mapping[int, Mapping[str, float]], confusion: M num_class_examples = sum(confusion.get(class_id).values()) weighted_f1 += class_f1 * num_class_examples num_total_examples += num_class_examples - + return weighted_f1 / num_total_examples @@ -66,8 +66,8 @@ def 
evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[A e.g. confusion[0][1] = 6 would mean that for gold tag 0, the model predicted tag 1 a total of 6 times. """ assert len(gold_tag_sequences) == len(pred_tag_sequences), \ - f"Length of gold tag sequences is {len(gold_tag_sequences)}, while length of predicted tag sequence is {len(pred_tag_sequences)}" - + f"Length of gold tag sequences is {len(gold_tag_sequences)}, while length of predicted tag sequence is {len(pred_tag_sequences)}" + confusion = defaultdict(lambda: defaultdict(int)) reverse_label_decoder = {y: x for x, y in label_decoder.items()} @@ -81,8 +81,8 @@ def evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[A try: prec = confusion.get(gold_tag, {}).get(gold_tag, 0) / sum([confusion.get(k, {}).get(gold_tag, 0) for k in confusion.keys()]) except ZeroDivisionError: - prec = 0.0 - + prec = 0.0 + try: recall = confusion.get(gold_tag, {}).get(gold_tag, 0) / sum(confusion.get(gold_tag, {}).values()) except ZeroDivisionError: @@ -91,21 +91,21 @@ def evaluate_sequences(gold_tag_sequences: List[Any], pred_tag_sequences: List[A try: f1 = 2 * (prec * recall) / (prec + recall) except ZeroDivisionError: - f1 = 0.0 + f1 = 0.0 multi_class_result[gold_tag] = { "precision": prec, "recall": recall, "f1": f1 } - + if verbose: for lemma in multi_class_result: logger.info(f"Lemma '{lemma}' had precision {100 * multi_class_result[lemma]['precision']}, recall {100 * multi_class_result[lemma]['recall']} and F1 score of {100 * multi_class_result[lemma]['f1']}") - + weighted_f1 = get_weighted_f1(multi_class_result, confusion) - return multi_class_result, confusion, weighted_f1 + return multi_class_result, confusion, weighted_f1 def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: List[List[str]], upos_tags: List[List[int]]=[]) -> torch.Tensor: @@ -116,14 +116,14 @@ def model_predict(model: nn.Module, position_indices: torch.Tensor, sentences: L model (LemmaClassifier): A trained LemmaClassifier that is able to predict on a target token. position_indices (Tensor[int]): A tensor of the (zero-indexed) position of the target token in `text` for each example in the batch. sentences (List[List[str]]): A list of lists of the tokenized strings of the input sentences. - + Returns: (int): The index of the predicted class in `model`'s output. """ with torch.no_grad(): logits = model(position_indices, sentences, upos_tags) # should be size (batch_size, output_size) predicted_class = torch.argmax(logits, dim=1) # should be size (batch_size, 1) - + return predicted_class @@ -139,9 +139,9 @@ def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_tr is_training (bool, optional): Whether the model is in training mode. If the model is training, we do not change it to eval mode. Returns: - 1. Multi-class results (Mapping[int, Mapping[str, float]]): first map has keys as the classes (lemma indices) and value is + 1. Multi-class results (Mapping[int, Mapping[str, float]]): first map has keys as the classes (lemma indices) and value is another map with key of "f1", "precision", or "recall" with corresponding values. - 2. Confusion Matrix (Mapping[int, Mapping[int, int]]): A confusion matrix with keys equal to the index of the gold tag, and a value of the + 2. Confusion Matrix (Mapping[int, Mapping[int, int]]): A confusion matrix with keys equal to the index of the gold tag, and a value of the map with the key as the predicted tag and corresponding count of that (gold, pred) pair. 
3. Accuracy (float): the total accuracy (num correct / total examples) across the evaluation set. """ @@ -154,7 +154,7 @@ def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_tr # load in eval data dataset = utils.Dataset(eval_path, label_decoder=model.label_decoder, shuffle=False) - + logger.info(f"Evaluating on evaluation file {eval_path}") correct, total = 0, 0 @@ -176,11 +176,11 @@ def evaluate_model(model: nn.Module, eval_path: str, verbose: bool = True, is_tr if verbose: logger.info(f"Accuracy: {accuracy} ({correct}/{total})") logger.info(f"Label decoder: {dataset.label_decoder}") - + return mc_results, confusion, accuracy, weighted_f1 -def main(args=None): +def main(args=None, predefined_args=None): # TODO: can unify this script with train_lstm_model.py? # TODO: can save the model type in the model .pt, then @@ -200,7 +200,7 @@ def main(args=None): parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta") parser.add_argument("--eval_file", type=str, help="path to evaluation file") - args = parser.parse_args(args) + args = parser.parse_args(args) if not predefined_args else predefined_args logger.info("Running training script with the following args:") args = vars(args) @@ -221,6 +221,8 @@ def main(args=None): logger.info("______________________________________________") logger.info(f"Weighted f1: {weighted_f1}") + return mcc_results, confusion, acc, weighted_f1 + if __name__ == "__main__": main() diff --git a/stanza/models/lemma_classifier/train_lstm_model.py b/stanza/models/lemma_classifier/train_lstm_model.py index 400702bc45..a8437818c5 100644 --- a/stanza/models/lemma_classifier/train_lstm_model.py +++ b/stanza/models/lemma_classifier/train_lstm_model.py @@ -96,9 +96,9 @@ def build_argparse(): parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves") return parser -def main(args=None): +def main(args=None, predefined_args=None): parser = build_argparse() - args = parser.parse_args(args) + args = parser.parse_args(args) if predefined_args is None else predefined_args wordvec_pretrain_file = args.wordvec_pretrain_file use_charlm = args.use_charlm diff --git a/stanza/models/lemma_classifier/train_many.py b/stanza/models/lemma_classifier/train_many.py new file mode 100644 index 0000000000..cefe7b93f6 --- /dev/null +++ b/stanza/models/lemma_classifier/train_many.py @@ -0,0 +1,155 @@ +""" +Utils for training and evaluating multiple models simultaneously +""" + +import argparse +import os + +from stanza.models.lemma_classifier.train_lstm_model import main as train_lstm_main +from stanza.models.lemma_classifier.train_transformer_model import main as train_tfmr_main +from stanza.models.lemma_classifier.constants import DEFAULT_BATCH_SIZE + + +change_params_map = { + "lstm_layer": [16, 32, 64, 128, 256, 512], + "upos_emb_dim": [5, 10, 20, 30], + "training_size": [150, 300, 450, 600, 'full'], +} # TODO: Add attention + +def train_n_models(num_models: int, base_path: str, args): + + if args.change_param == "lstm_layer": + for num_layers in change_params_map.get("lstm_layer", None): + for i in range(num_models): + new_save_name = os.path.join(base_path, f"{num_layers}_{i}.pt") + args.save_name = new_save_name + args.hidden_dim = num_layers + train_lstm_main(predefined_args=args) + + if args.change_param == "upos_emb_dim": + for upos_dim in 
change_params_map.get("upos_emb_dim", None):
+            for i in range(num_models):
+                new_save_name = os.path.join(base_path, f"dim_{upos_dim}_{i}.pt")
+                args.save_name = new_save_name
+                args.upos_emb_dim = upos_dim
+                train_lstm_main(predefined_args=args)
+
+    if args.change_param == "training_size":
+        for size in change_params_map.get("training_size", None):
+            for i in range(num_models):
+                new_save_name = os.path.join(base_path, f"{size}_examples_{i}.pt")
+                new_train_file = os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt")
+                args.save_name = new_save_name
+                args.train_file = new_train_file
+                train_lstm_main(predefined_args=args)
+
+    if args.change_param == "base":
+        for i in range(num_models):
+            new_save_name = os.path.join(base_path, f"lstm_model_{i}.pt")
+            args.save_name = new_save_name
+            args.weighted_loss = False
+            train_lstm_main(predefined_args=args)
+
+            if not args.weighted_loss:
+                args.weighted_loss = True
+                new_save_name = os.path.join(base_path, f"lstm_model_wloss_{i}.pt")
+                args.save_name = new_save_name
+                train_lstm_main(predefined_args=args)
+
+    if args.change_param == "base_charlm":
+        for i in range(num_models):
+            new_save_name = os.path.join(base_path, f"lstm_charlm_{i}.pt")
+            args.save_name = new_save_name
+            train_lstm_main(predefined_args=args)
+
+    if args.change_param == "base_charlm_upos":
+        for i in range(num_models):
+            new_save_name = os.path.join(base_path, f"lstm_charlm_upos_{i}.pt")
+            args.save_name = new_save_name
+            train_lstm_main(predefined_args=args)
+
+    if args.change_param == "base_upos":
+        for i in range(num_models):
+            new_save_name = os.path.join(base_path, f"lstm_upos_{i}.pt")
+            args.save_name = new_save_name
+            train_lstm_main(predefined_args=args)
+
+    if args.change_param == "attn_model":
+        for i in range(num_models):
+            new_save_name = os.path.join(base_path, f"attn_model_{args.num_heads}_heads_{i}.pt")
+            args.save_name = new_save_name
+            train_lstm_main(predefined_args=args)
+
+def train_n_tfmrs(num_models: int, base_path: str, args):
+
+    if args.multi_train_type == "tfmr":
+
+        for i in range(num_models):
+
+            if args.change_param == "bert":
+                new_save_name = os.path.join(base_path, f"bert_{i}.pt")
+                args.save_name = new_save_name
+                args.loss_fn = "ce"
+                train_tfmr_main(predefined_args=args)
+
+                new_save_name = os.path.join(base_path, f"bert_wloss_{i}.pt")
+                args.save_name = new_save_name
+                args.loss_fn = "weighted_bce"
+                train_tfmr_main(predefined_args=args)
+
+            elif args.change_param == "roberta":
+                new_save_name = os.path.join(base_path, f"roberta_{i}.pt")
+                args.save_name = new_save_name
+                args.loss_fn = "ce"
+                train_tfmr_main(predefined_args=args)
+
+                new_save_name = os.path.join(base_path, f"roberta_wloss_{i}.pt")
+                args.save_name = new_save_name
+                args.loss_fn = "weighted_bce"
+                train_tfmr_main(predefined_args=args)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--hidden_dim", type=int, default=256, help="Size of hidden layer")
+    parser.add_argument('--wordvec_pretrain_file', type=str, default=os.path.join(os.path.dirname(__file__), "pretrain", "glove.pt"), help='Exact name of the pretrain file to read')
+    parser.add_argument("--charlm", action='store_true', dest='use_charlm', default=False, help="Whether or not to use the charlm embeddings")
+    parser.add_argument('--charlm_shorthand', type=str, default=None, help="Shorthand for character-level language model training corpus.")
+    parser.add_argument("--charlm_forward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", 
"1billion_forward.pt"), help="Path to forward charlm file") + parser.add_argument("--charlm_backward_file", type=str, default=os.path.join(os.path.dirname(__file__), "charlm_files", "1billion_backwards.pt"), help="Path to backward charlm file") + parser.add_argument("--upos_emb_dim", type=int, default=20, help="Dimension size for UPOS tag embeddings.") + parser.add_argument("--use_attn", action='store_true', dest='attn', default=False, help='Whether to use multihead attention instead of LSTM.') + parser.add_argument("--num_heads", type=int, default=0, help="Number of heads to use for multihead attention.") + parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(__file__), "saved_models", "lemma_classifier_model_weighted_loss_charlm_new.pt"), help="Path to model save file") + parser.add_argument("--lr", type=float, default=0.001, help="learning rate") + parser.add_argument("--num_epochs", type=float, default=10, help="Number of training epochs") + parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch") + parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file") + parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.") + parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves") + # Tfmr-specific args + parser.add_argument("--model_type", type=str, default="roberta", help="Which transformer to use ('bert' or 'roberta')") + parser.add_argument("--bert_model", type=str, default=None, help="Use a specific transformer instead of the default bert/roberta") + parser.add_argument("--loss_fn", type=str, default="weighted_bce", help="Which loss function to train with (e.g. 
'ce' or 'weighted_bce')") + # Multi-model train args + parser.add_argument("--multi_train_type", type=str, default="lstm", help="Whether you are attempting to multi-train an LSTM or transformer") + parser.add_argument("--multi_train_count", type=int, default=5, help="Number of each model to build") + parser.add_argument("--base_path", type=str, default=None, help="Path to start generating model type for.") + parser.add_argument("--change_param", type=str, default=None, help="Which hyperparameter to change when training") + + + args = parser.parse_args() + + if args.multi_train_type == "lstm": + train_n_models(num_models=args.multi_train_count, + base_path=args.base_path, + args=args) + elif args.multi_train_type == "tfmr": + train_n_tfmrs(num_models=args.multi_train_count, + base_path=args.base_path, + args=args) + else: + raise ValueError(f"Improper input {args.multi_train_type}") + +if __name__ == "__main__": + main() diff --git a/stanza/models/lemma_classifier/train_transformer_model.py b/stanza/models/lemma_classifier/train_transformer_model.py index f9129e5738..2a36fb1731 100644 --- a/stanza/models/lemma_classifier/train_transformer_model.py +++ b/stanza/models/lemma_classifier/train_transformer_model.py @@ -76,7 +76,7 @@ def build_model(self, label_decoder, upos_to_id, known_words): return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder) -def main(args=None): +def main(args=None, predefined_args=None): parser = argparse.ArgumentParser() parser.add_argument("--save_name", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "saved_models", "big_model_roberta_weighted_loss.pt"), help="Path to model save file") @@ -89,7 +89,7 @@ def main(args=None): parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves") parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the optimizer.") - args = parser.parse_args(args) + args = parser.parse_args(args) if predefined_args is None else predefined_args save_name = args.save_name num_epochs = args.num_epochs From db429ce0304752e58cd02b3948301a3517b0692d Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 18 Sep 2024 13:01:11 -0700 Subject: [PATCH 08/10] Add a flag to the classifier training scripts to force overwriting an existing model --- stanza/models/lemma_classifier/base_trainer.py | 2 +- stanza/models/lemma_classifier/train_lstm_model.py | 3 ++- stanza/models/lemma_classifier/train_transformer_model.py | 3 ++- stanza/utils/training/run_lemma_classifier.py | 6 +++++- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/stanza/models/lemma_classifier/base_trainer.py b/stanza/models/lemma_classifier/base_trainer.py index 4c7d0f183e..1ff3b24a9b 100644 --- a/stanza/models/lemma_classifier/base_trainer.py +++ b/stanza/models/lemma_classifier/base_trainer.py @@ -67,7 +67,7 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, self.model.to(device) logger.info(f"Training model on device: {device}. {next(self.model.parameters()).device}") - if os.path.exists(save_name): + if os.path.exists(save_name) and not args.get('force', False): raise FileExistsError(f"Save name {save_name} already exists; training would overwrite previous file contents. 
Aborting...") if self.weighted_loss: diff --git a/stanza/models/lemma_classifier/train_lstm_model.py b/stanza/models/lemma_classifier/train_lstm_model.py index a8437818c5..64266baac0 100644 --- a/stanza/models/lemma_classifier/train_lstm_model.py +++ b/stanza/models/lemma_classifier/train_lstm_model.py @@ -94,6 +94,7 @@ def build_argparse(): parser.add_argument("--train_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_train.txt"), help="Full path to training file") parser.add_argument("--weighted_loss", action='store_true', dest='weighted_loss', default=False, help="Whether to use weighted loss during training.") parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(__file__), "data", "processed_ud_en", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves") + parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file') return parser def main(args=None, predefined_args=None): @@ -116,7 +117,7 @@ def main(args=None, predefined_args=None): args = vars(args) - if os.path.exists(save_name): + if os.path.exists(save_name) and not args.get('force', False): raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...") if not os.path.exists(train_file): raise FileNotFoundError(f"Training file {train_file} not found. Try again with a valid path.") diff --git a/stanza/models/lemma_classifier/train_transformer_model.py b/stanza/models/lemma_classifier/train_transformer_model.py index 2a36fb1731..77d8be6faa 100644 --- a/stanza/models/lemma_classifier/train_transformer_model.py +++ b/stanza/models/lemma_classifier/train_transformer_model.py @@ -88,6 +88,7 @@ def main(args=None, predefined_args=None): parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Number of examples to include in each batch") parser.add_argument("--eval_file", type=str, default=os.path.join(os.path.dirname(os.path.dirname(__file__)), "test_sets", "combined_dev.txt"), help="Path to dev file used to evaluate model for saves") parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the optimizer.") + parser.add_argument("--force", action='store_true', default=False, help='Whether or not to clobber an existing save file') args = parser.parse_args(args) if predefined_args is None else predefined_args @@ -110,7 +111,7 @@ def main(args=None, predefined_args=None): else: raise ValueError("Unknown model type " + args['model_type']) - if os.path.exists(save_name): + if os.path.exists(save_name) and not args.get('force', False): raise FileExistsError(f"Save name {save_name} already exists. Training would override existing data. Aborting...") if not os.path.exists(train_file): raise FileNotFoundError(f"Training file {train_file} not found. 
Try again with a valid path.") diff --git a/stanza/utils/training/run_lemma_classifier.py b/stanza/utils/training/run_lemma_classifier.py index 5f68f07048..1a8420814a 100644 --- a/stanza/utils/training/run_lemma_classifier.py +++ b/stanza/utils/training/run_lemma_classifier.py @@ -39,6 +39,10 @@ def run_treebank(mode, paths, treebank, short_name, else: raise ValueError("--bert_model not specified, so cannot figure out which transformer to use for language %s" % short_language) + extra_train_args = [] + if command_args.force: + extra_train_args.append('--force') + if mode == Mode.TRAIN: train_args = [] if "--train_file" not in extra_args: @@ -47,7 +51,7 @@ def run_treebank(mode, paths, treebank, short_name, if "--eval_file" not in extra_args: eval_file = os.path.join("data", "lemma_classifier", "%s.dev.lemma" % short_name) train_args += ['--eval_file', eval_file] - train_args = base_args + train_args + extra_args + train_args = base_args + train_args + extra_args + extra_train_args if command_args.model_type == ModelType.LSTM: train_args = embedding_args + train_args From 05f865772f2a82eb4379496b8e2e9ccf9d54a388 Mon Sep 17 00:00:00 2001 From: John Bauer Date: Wed, 18 Sep 2024 17:26:35 -0700 Subject: [PATCH 09/10] Keep the target words (words which we trained to recognize) in a set. Will be useful for integrating with the Pipeline --- stanza/models/lemma_classifier/base_model.py | 13 +++++++++++-- stanza/models/lemma_classifier/base_trainer.py | 3 ++- stanza/models/lemma_classifier/lstm_model.py | 6 ++++-- stanza/models/lemma_classifier/train_lstm_model.py | 4 ++-- .../lemma_classifier/train_transformer_model.py | 4 ++-- stanza/models/lemma_classifier/transformer_model.py | 6 ++++-- stanza/models/lemma_classifier/utils.py | 6 ++++++ 7 files changed, 31 insertions(+), 11 deletions(-) diff --git a/stanza/models/lemma_classifier/base_model.py b/stanza/models/lemma_classifier/base_model.py index fbdba5be1f..9bfc0fc9b8 100644 --- a/stanza/models/lemma_classifier/base_model.py +++ b/stanza/models/lemma_classifier/base_model.py @@ -19,10 +19,11 @@ logger = logging.getLogger('stanza.lemmaclassifier') class LemmaClassifier(ABC, nn.Module): - def __init__(self, label_decoder, *args, **kwargs): + def __init__(self, label_decoder, target_words, *args, **kwargs): super().__init__(*args, **kwargs) self.label_decoder = label_decoder + self.target_words = target_words self.unsaved_modules = [] def add_unsaved_module(self, name, module): @@ -49,6 +50,9 @@ def model_type(self): return a ModelType """ + def target_indices(self, sentence): + return [idx for idx, word in enumerate(sentence) if word.lower() in self.target_words] + @staticmethod def from_checkpoint(checkpoint, args=None): model_type = checkpoint['model_type'] @@ -81,6 +85,7 @@ def from_checkpoint(checkpoint, args=None): label_decoder=checkpoint['label_decoder'], upos_to_id=checkpoint['upos_to_id'], known_words=checkpoint['known_words'], + target_words=checkpoint['target_words'], use_charlm=use_charlm, charlm_forward_file=charlm_forward_file, charlm_backward_file=charlm_backward_file) @@ -90,7 +95,11 @@ def from_checkpoint(checkpoint, args=None): output_dim = len(checkpoint['label_decoder']) saved_args = checkpoint['args'] bert_model = saved_args['bert_model'] - model = LemmaClassifierWithTransformer(model_args = saved_args, output_dim=output_dim, transformer_name=bert_model, label_decoder=checkpoint['label_decoder']) + model = LemmaClassifierWithTransformer(model_args=saved_args, + output_dim=output_dim, + transformer_name=bert_model, +
label_decoder=checkpoint['label_decoder'], + target_words=checkpoint['target_words']) else: raise ValueError("Unknown model type %s" % model_type) diff --git a/stanza/models/lemma_classifier/base_trainer.py b/stanza/models/lemma_classifier/base_trainer.py index 1ff3b24a9b..01bef3d76d 100644 --- a/stanza/models/lemma_classifier/base_trainer.py +++ b/stanza/models/lemma_classifier/base_trainer.py @@ -60,8 +60,9 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str, self.output_dim = len(label_decoder) logger.info(f"Loaded dataset successfully from {train_file}") logger.info(f"Using label decoder: {label_decoder} Output dimension: {self.output_dim}") + logger.info(f"Target words: {dataset.target_words}") - self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words) + self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words, dataset.target_words) self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr) self.model.to(device) diff --git a/stanza/models/lemma_classifier/lstm_model.py b/stanza/models/lemma_classifier/lstm_model.py index 5cd20c63f8..c8cd829e8f 100644 --- a/stanza/models/lemma_classifier/lstm_model.py +++ b/stanza/models/lemma_classifier/lstm_model.py @@ -21,7 +21,7 @@ class LemmaClassifierLSTM(LemmaClassifier): From the LSTM output, we get the embedding of the specific token that we classify on. That embedding is fed into an MLP for classification. """ - def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words, + def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words, target_words, use_charlm=False, charlm_forward_file=None, charlm_backward_file=None): """ Args: @@ -30,6 +30,7 @@ def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_ upos_to_id (Mapping[str, int]): A dictionary mapping UPOS tag strings to their respective IDs pt_embedding (Pretrain): pretrained embeddings known_words (list(str)): Words which are in the training data + target_words (set(str)): a set of the words which might need lemmatization use_charlm (bool): Whether or not to use the charlm embeddings charlm_forward_file (str): The path to the forward pass model for the character language model charlm_backward_file (str): The path to the forward pass model for the character language model. @@ -41,7 +42,7 @@ def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_ Raises: FileNotFoundError: if the forward or backward charlm file cannot be found. """ - super(LemmaClassifierLSTM, self).__init__(label_decoder) + super(LemmaClassifierLSTM, self).__init__(label_decoder, target_words) self.model_args = model_args self.hidden_dim = model_args['hidden_dim'] @@ -113,6 +114,7 @@ def get_save_dict(self): "args": self.model_args, "upos_to_id": self.upos_to_id, "known_words": self.known_words, + "target_words": self.target_words, } skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)] for k in skipped: diff --git a/stanza/models/lemma_classifier/train_lstm_model.py b/stanza/models/lemma_classifier/train_lstm_model.py index 64266baac0..53b57d840e 100644 --- a/stanza/models/lemma_classifier/train_lstm_model.py +++ b/stanza/models/lemma_classifier/train_lstm_model.py @@ -72,8 +72,8 @@ def __init__(self, model_args: dict, embedding_file: str, use_charlm: bool = Fal else: raise ValueError("Must enter a valid loss function (e.g. 
'ce' or 'weighted_bce')") - def build_model(self, label_decoder, upos_to_id, known_words): - return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words, + def build_model(self, label_decoder, upos_to_id, known_words, target_words): + return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words, target_words, use_charlm=self.use_charlm, charlm_forward_file=self.charlm_forward_file, charlm_backward_file=self.charlm_backward_file) def build_argparse(): diff --git a/stanza/models/lemma_classifier/train_transformer_model.py b/stanza/models/lemma_classifier/train_transformer_model.py index 77d8be6faa..27115b5f41 100644 --- a/stanza/models/lemma_classifier/train_transformer_model.py +++ b/stanza/models/lemma_classifier/train_transformer_model.py @@ -72,8 +72,8 @@ def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> torc ]) return optimizer - def build_model(self, label_decoder, upos_to_id, known_words): - return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder) + def build_model(self, label_decoder, upos_to_id, known_words, target_words): + return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder, target_words=target_words) def main(args=None, predefined_args=None): diff --git a/stanza/models/lemma_classifier/transformer_model.py b/stanza/models/lemma_classifier/transformer_model.py index 5f32151191..bb78162523 100644 --- a/stanza/models/lemma_classifier/transformer_model.py +++ b/stanza/models/lemma_classifier/transformer_model.py @@ -14,7 +14,7 @@ logger = logging.getLogger('stanza.lemmaclassifier') class LemmaClassifierWithTransformer(LemmaClassifier): - def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping): + def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping, target_words: set): """ Model architecture: @@ -27,8 +27,9 @@ def __init__(self, model_args: dict, output_dim: int, transformer_name: str, lab output_dim (int): Dimension of the output from the MLP transformer_name (str): name of the HF transformer to use label_decoder (dict): a map of the labels available to the model + target_words (set(str)): a set of the words which might need lemmatization """ - super(LemmaClassifierWithTransformer, self).__init__(label_decoder) + super(LemmaClassifierWithTransformer, self).__init__(label_decoder, target_words) self.model_args = model_args # Choose transformer @@ -50,6 +51,7 @@ def get_save_dict(self): save_dict = { "params": self.state_dict(), "label_decoder": self.label_decoder, + "target_words": self.target_words, "model_type": self.model_type(), "args": self.model_args, } diff --git a/stanza/models/lemma_classifier/utils.py b/stanza/models/lemma_classifier/utils.py index fa8ec3b216..36996dbf75 100644 --- a/stanza/models/lemma_classifier/utils.py +++ b/stanza/models/lemma_classifier/utils.py @@ -45,6 +45,10 @@ def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_coun logger.debug("Final label decoder: %s Should be strings to ints", label_decoder) + # words which we are analyzing + target_words = set() + + # all known words in the dataset, not just target words known_words = set() with open(data_path, "r+", encoding="utf-8") as f: @@ 
-78,6 +82,7 @@ def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_coun if get_counts: counts[label_decoder[label]] += 1 + target_words.add(words[target_idx]) known_words.update(words) self.sentences = sentences @@ -93,6 +98,7 @@ def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_coun self.shuffle = shuffle self.known_words = [x.lower() for x in sorted(known_words)] + self.target_words = set(x.lower() for x in target_words) def __len__(self): """ From b7f63a46f0bed91371fb1b485d0b20d60843751c Mon Sep 17 00:00:00 2001 From: John Bauer Date: Thu, 19 Sep 2024 14:46:21 -0700 Subject: [PATCH 10/10] Add a short script to attach a LemmaClassifier to a Lemmatizer trainer --- .../models/lemma/attach_lemma_classifier.py | 25 +++++++++++++++++++ stanza/models/lemma/trainer.py | 10 +++++++- 2 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 stanza/models/lemma/attach_lemma_classifier.py diff --git a/stanza/models/lemma/attach_lemma_classifier.py b/stanza/models/lemma/attach_lemma_classifier.py new file mode 100644 index 0000000000..4f59782c89 --- /dev/null +++ b/stanza/models/lemma/attach_lemma_classifier.py @@ -0,0 +1,25 @@ +import argparse + +from stanza.models.lemma.trainer import Trainer +from stanza.models.lemma_classifier.base_model import LemmaClassifier + +def attach_classifier(input_filename, output_filename, classifiers): + trainer = Trainer(model_file=input_filename) + + for classifier in classifiers: + classifier = LemmaClassifier.load(classifier) + trainer.contextual_lemmatizers.append(classifier) + + trainer.save(output_filename) + +def main(args=None): + parser = argparse.ArgumentParser() + parser.add_argument('--input', type=str, required=True, help='Which lemmatizer to start from') + parser.add_argument('--output', type=str, required=True, help='Where to save the lemmatizer') + parser.add_argument('--classifier', type=str, required=True, nargs='+', help='Lemma classifier to attach') + args = parser.parse_args(args) + + attach_classifier(args.input, args.output, args.classifier) + +if __name__ == '__main__': + main() diff --git a/stanza/models/lemma/trainer.py b/stanza/models/lemma/trainer.py index 4b5e4a0b74..d7bf37daa6 100644 --- a/stanza/models/lemma/trainer.py +++ b/stanza/models/lemma/trainer.py @@ -18,6 +18,7 @@ from stanza.models.common import utils, loss from stanza.models.lemma import edit from stanza.models.lemma.vocab import MultiVocab +from stanza.models.lemma_classifier.base_model import LemmaClassifier logger = logging.getLogger('stanza') @@ -45,6 +46,7 @@ def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, devi # dict-based components self.word_dict = dict() self.composite_dict = dict() + self.contextual_lemmatizers = [] self.caseless = self.args.get('caseless', False) @@ -228,8 +230,11 @@ def save(self, filename, skip_modules=True): 'model': model_state, 'dicts': (self.word_dict, self.composite_dict), 'vocab': self.vocab.state_dict(), - 'config': self.args + 'config': self.args, + 'contextual': [], } + for contextual in self.contextual_lemmatizers: + params['contextual'].append(contextual.get_save_dict()) os.makedirs(os.path.split(filename)[0], exist_ok=True) torch.save(params, filename, _use_new_zipfile_serialization=False) logger.info("Model saved to {}".format(filename)) @@ -253,3 +258,6 @@ def load(self, filename, args, foundation_cache): else: self.model = None self.vocab = MultiVocab.load_state_dict(checkpoint['vocab']) + self.contextual_lemmatizers = [] + for contextual in 
checkpoint.get('contextual', []): + self.contextual_lemmatizers.append(LemmaClassifier.from_checkpoint(contextual))
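
A minimal usage sketch of how the pieces added in patches 09 and 10 might fit together, assuming trained model files already exist. The file paths and the sample sentence below are hypothetical placeholders; only the calls themselves (attach_lemma_classifier.main taking an argument list, LemmaClassifier.load with a single filename as in the attach script, and target_indices) are taken from the code in these patches.

    # Attach one or more trained LemmaClassifier models to an existing lemmatizer.
    # Paths are placeholders; --classifier accepts several files (nargs='+').
    from stanza.models.lemma import attach_lemma_classifier
    from stanza.models.lemma_classifier.base_model import LemmaClassifier

    attach_lemma_classifier.main([
        '--input', 'saved_models/lemma/en_combined_lemmatizer.pt',
        '--output', 'saved_models/lemma/en_combined_lemmatizer_with_classifier.pt',
        '--classifier', 'saved_models/lemma_classifier/en_lstm_model.pt',
    ])

    # Each attached classifier keeps its target_words, so a caller such as the
    # Pipeline can locate the tokens it should reclassify:
    classifier = LemmaClassifier.load('saved_models/lemma_classifier/en_lstm_model.pt')
    sentence = ["He", "'s", "already", "seen", "it"]
    print(classifier.target_indices(sentence))  # e.g. [1] if "'s" is a target word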