From 3064b33f27b1fbb069337ed655be174fccc491c6 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Wed, 18 Sep 2024 17:26:35 -0700
Subject: [PATCH] Keep the target words (words which we trained to recognize)
 in a set. Will be useful for integrating with the Pipeline

---
 stanza/models/lemma_classifier/base_model.py          | 13 +++++++++++--
 stanza/models/lemma_classifier/base_trainer.py        |  3 ++-
 stanza/models/lemma_classifier/lstm_model.py          |  6 ++++--
 stanza/models/lemma_classifier/train_lstm_model.py    |  4 ++--
 .../lemma_classifier/train_transformer_model.py       |  4 ++--
 stanza/models/lemma_classifier/transformer_model.py   |  6 ++++--
 stanza/models/lemma_classifier/utils.py               |  6 ++++++
 7 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/stanza/models/lemma_classifier/base_model.py b/stanza/models/lemma_classifier/base_model.py
index fbdba5be1..9bfc0fc9b 100644
--- a/stanza/models/lemma_classifier/base_model.py
+++ b/stanza/models/lemma_classifier/base_model.py
@@ -19,10 +19,11 @@ logger = logging.getLogger('stanza.lemmaclassifier')
 
 class LemmaClassifier(ABC, nn.Module):
-    def __init__(self, label_decoder, *args, **kwargs):
+    def __init__(self, label_decoder, target_words, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         self.label_decoder = label_decoder
+        self.target_words = target_words
         self.unsaved_modules = []
 
     def add_unsaved_module(self, name, module):
@@ -49,6 +50,9 @@ def model_type(self):
         return a ModelType
         """
 
+    def target_indices(self, sentence):
+        return [idx for idx, word in enumerate(sentence) if word.lower() in self.target_words]
+
     @staticmethod
     def from_checkpoint(checkpoint, args=None):
         model_type = checkpoint['model_type']
@@ -81,6 +85,7 @@ def from_checkpoint(checkpoint, args=None):
                                         label_decoder=checkpoint['label_decoder'],
                                         upos_to_id=checkpoint['upos_to_id'],
                                         known_words=checkpoint['known_words'],
+                                        target_words=checkpoint['target_words'],
                                         use_charlm=use_charlm,
                                         charlm_forward_file=charlm_forward_file,
                                         charlm_backward_file=charlm_backward_file)
@@ -90,7 +95,11 @@ def from_checkpoint(checkpoint, args=None):
             output_dim = len(checkpoint['label_decoder'])
             saved_args = checkpoint['args']
             bert_model = saved_args['bert_model']
-            model = LemmaClassifierWithTransformer(model_args = saved_args, output_dim=output_dim, transformer_name=bert_model, label_decoder=checkpoint['label_decoder'])
+            model = LemmaClassifierWithTransformer(model_args=saved_args,
+                                                   output_dim=output_dim,
+                                                   transformer_name=bert_model,
+                                                   label_decoder=checkpoint['label_decoder'],
+                                                   target_words=checkpoint['target_words'])
         else:
             raise ValueError("Unknown model type %s" % model_type)
 
diff --git a/stanza/models/lemma_classifier/base_trainer.py b/stanza/models/lemma_classifier/base_trainer.py
index 1ff3b24a9..01bef3d76 100644
--- a/stanza/models/lemma_classifier/base_trainer.py
+++ b/stanza/models/lemma_classifier/base_trainer.py
@@ -60,8 +60,9 @@ def train(self, num_epochs: int, save_name: str, args: Mapping, eval_file: str,
         self.output_dim = len(label_decoder)
         logger.info(f"Loaded dataset successfully from {train_file}")
         logger.info(f"Using label decoder: {label_decoder}  Output dimension: {self.output_dim}")
+        logger.info(f"Target words: {dataset.target_words}")
 
-        self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words)
+        self.model = self.build_model(label_decoder, upos_to_id, dataset.known_words, dataset.target_words)
         self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
         self.model.to(device)
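
As a quick illustration (this note and the snippet are not part of the patch), the new
LemmaClassifier.target_indices helper returns the positions of the tokens whose lowercased
form is in the model's target_words set; the tiny target set below is illustrative, not
taken from a real model:

    # standalone sketch of what target_indices computes; {"'s"} is a made-up target set
    target_words = {"'s"}

    def target_indices(sentence):
        return [idx for idx, word in enumerate(sentence) if word.lower() in target_words]

    print(target_indices(["The", "dog", "'s", "toy", "'S", "squeaks"]))
    # -> [2, 4]; both cased variants match because each word is lowercased before the lookup

In a Pipeline integration, those indices would mark the only tokens the lemma classifier
needs to re-examine; everything else would keep the lemma produced by the regular lemmatizer.
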
diff --git a/stanza/models/lemma_classifier/lstm_model.py b/stanza/models/lemma_classifier/lstm_model.py
index e87548bd7..a22120109 100644
--- a/stanza/models/lemma_classifier/lstm_model.py
+++ b/stanza/models/lemma_classifier/lstm_model.py
@@ -21,7 +21,7 @@ class LemmaClassifierLSTM(LemmaClassifier):
     From the LSTM output, we get the embedding of the specific token that we classify on.
     That embedding is fed into an MLP for classification.
     """
-    def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words,
+    def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_id, known_words, target_words,
                  use_charlm=False, charlm_forward_file=None, charlm_backward_file=None):
         """
         Args:
@@ -30,6 +30,7 @@ def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_
             upos_to_id (Mapping[str, int]): A dictionary mapping UPOS tag strings to their respective IDs
             pt_embedding (Pretrain): pretrained embeddings
             known_words (list(str)): Words which are in the training data
+            target_words (set(str)): a set of the words which might need lemmatization
             use_charlm (bool): Whether or not to use the charlm embeddings
             charlm_forward_file (str): The path to the forward pass model for the character language model
             charlm_backward_file (str): The path to the forward pass model for the character language model.
@@ -41,7 +42,7 @@ def __init__(self, model_args, output_dim, pt_embedding, label_decoder, upos_to_
         Raises:
             FileNotFoundError: if the forward or backward charlm file cannot be found.
         """
-        super(LemmaClassifierLSTM, self).__init__(label_decoder)
+        super(LemmaClassifierLSTM, self).__init__(label_decoder, target_words)
         self.model_args = model_args
 
         self.hidden_dim = model_args['hidden_dim']
@@ -113,6 +114,7 @@ def get_save_dict(self):
             "args": self.model_args,
             "upos_to_id": self.upos_to_id,
             "known_words": self.known_words,
+            "target_words": self.target_words,
         }
         skipped = [k for k in save_dict["params"].keys() if self.is_unsaved_module(k)]
         for k in skipped:
diff --git a/stanza/models/lemma_classifier/train_lstm_model.py b/stanza/models/lemma_classifier/train_lstm_model.py
index 64266baac..53b57d840 100644
--- a/stanza/models/lemma_classifier/train_lstm_model.py
+++ b/stanza/models/lemma_classifier/train_lstm_model.py
@@ -72,8 +72,8 @@ def __init__(self, model_args: dict, embedding_file: str, use_charlm: bool = Fal
         else:
             raise ValueError("Must enter a valid loss function (e.g. 'ce' or 'weighted_bce')")
 
-    def build_model(self, label_decoder, upos_to_id, known_words):
-        return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words,
+    def build_model(self, label_decoder, upos_to_id, known_words, target_words):
+        return LemmaClassifierLSTM(self.model_args, self.output_dim, self.pt_embedding, label_decoder, upos_to_id, known_words, target_words,
                                    use_charlm=self.use_charlm, charlm_forward_file=self.charlm_forward_file, charlm_backward_file=self.charlm_backward_file)
 
 def build_argparse():
diff --git a/stanza/models/lemma_classifier/train_transformer_model.py b/stanza/models/lemma_classifier/train_transformer_model.py
index 77d8be6fa..27115b5f4 100644
--- a/stanza/models/lemma_classifier/train_transformer_model.py
+++ b/stanza/models/lemma_classifier/train_transformer_model.py
@@ -72,8 +72,8 @@ def set_layer_learning_rates(self, transformer_lr: float, mlp_lr: float) -> torc
         ])
         return optimizer
 
-    def build_model(self, label_decoder, upos_to_id, known_words):
-        return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder)
+    def build_model(self, label_decoder, upos_to_id, known_words, target_words):
+        return LemmaClassifierWithTransformer(model_args=self.model_args, output_dim=self.output_dim, transformer_name=self.transformer_name, label_decoder=label_decoder, target_words=target_words)
 
 def main(args=None, predefined_args=None):
diff --git a/stanza/models/lemma_classifier/transformer_model.py b/stanza/models/lemma_classifier/transformer_model.py
index 5f3215119..bb7816252 100644
--- a/stanza/models/lemma_classifier/transformer_model.py
+++ b/stanza/models/lemma_classifier/transformer_model.py
@@ -14,7 +14,7 @@ logger = logging.getLogger('stanza.lemmaclassifier')
 
 class LemmaClassifierWithTransformer(LemmaClassifier):
-    def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping):
+    def __init__(self, model_args: dict, output_dim: int, transformer_name: str, label_decoder: Mapping, target_words: set):
         """
         Model architecture:
 
@@ -27,8 +27,9 @@ def __init__(self, model_args: dict, output_dim: int, transformer_name: str, lab
             output_dim (int): Dimension of the output from the MLP
             transformer_name (str): name of the HF transformer to use
             label_decoder (dict): a map of the labels available to the model
+            target_words (set(str)): a set of the words which might need lemmatization
         """
-        super(LemmaClassifierWithTransformer, self).__init__(label_decoder)
+        super(LemmaClassifierWithTransformer, self).__init__(label_decoder, target_words)
         self.model_args = model_args
 
         # Choose transformer
@@ -50,6 +51,7 @@ def get_save_dict(self):
         save_dict = {
             "params": self.state_dict(),
             "label_decoder": self.label_decoder,
+            "target_words": self.target_words,
             "model_type": self.model_type(),
             "args": self.model_args,
         }
diff --git a/stanza/models/lemma_classifier/utils.py b/stanza/models/lemma_classifier/utils.py
index fa8ec3b21..36996dbf7 100644
--- a/stanza/models/lemma_classifier/utils.py
+++ b/stanza/models/lemma_classifier/utils.py
@@ -45,6 +45,10 @@ def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_coun
 
         logger.debug("Final label decoder: %s  Should be strings to ints", label_decoder)
 
+        # words which we are analyzing
+        target_words = set()
+
+        # all known words in the dataset, not just target words
         known_words = set()
 
         with open(data_path, "r+", encoding="utf-8") as f:
@@ -78,6 +82,7 @@ def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_coun
             if get_counts:
                 counts[label_decoder[label]] += 1
 
+            target_words.add(words[target_idx])
             known_words.update(words)
 
         self.sentences = sentences
@@ -93,6 +98,7 @@ def __init__(self, data_path: str, batch_size: int =DEFAULT_BATCH_SIZE, get_coun
         self.shuffle = shuffle
 
         self.known_words = [x.lower() for x in sorted(known_words)]
+        self.target_words = set(x.lower() for x in target_words)
 
     def __len__(self):
         """
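
To make the data flow concrete (this closing example is not part of the patch, and the
training rows are invented), the dataset change above records the word at each training
example's target index and finally lowercases the whole set; that set is saved with the
model and, through target_words / target_indices, is what a Pipeline integration could
consult to decide which tokens to route to the classifier at all:

    # standalone sketch of the collection loop added to utils.py, on made-up (words, target_idx) rows
    training_rows = [
        (["The", "dog", "'s", "toy"], 2),
        (["He", "'S", "tired"], 1),
    ]

    target_words = set()
    for words, target_idx in training_rows:
        target_words.add(words[target_idx])

    # mirrors `self.target_words = set(x.lower() for x in target_words)` in the patch
    target_words = set(x.lower() for x in target_words)
    print(target_words)   # {"'s"} -- the cased variants collapse to a single lowercased entry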