diff --git a/stanza/models/tokenization/data.py b/stanza/models/tokenization/data.py
index 3ff919b0ba..3c4746a6d8 100644
--- a/stanza/models/tokenization/data.py
+++ b/stanza/models/tokenization/data.py
@@ -1,5 +1,5 @@
 from bisect import bisect_right
-from copy import copy
+from copy import copy, deepcopy
 import numpy as np
 import random
 import logging
@@ -355,15 +355,20 @@ def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
             features[i, :len(f_), :] = f_
             raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))
 
+        # so we always return text that's not been <UNK>'ed, but will return
+        # IDs with <UNK>s in them. REVIEW: check if the raw text is used
+        # anywhere else such that the lack of UNKs will cause a problem
+        dropped_units = deepcopy(raw_units)
+
         if unit_dropout > 0 and not self.eval:
             # dropout characters/units at training time and replace them with UNKs
             mask = np.random.random_sample(units.shape) < unit_dropout
             mask[units == padid] = 0
             units[mask] = unkid
-            for i in range(len(raw_units)):
-                for j in range(len(raw_units[i])):
+            for i in range(len(dropped_units)):
+                for j in range(len(dropped_units[i])):
                     if mask[i, j]:
-                        raw_units[i][j] = '<UNK>'
+                        dropped_units[i][j] = '<UNK>'
 
         # dropout unit feature vector in addition to only torch.dropout in the model.
         # experiments showed that only torch.dropout hurts the model
@@ -372,8 +377,8 @@ def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
         if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:
             mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout
             mask_feat[units == padid] = 0
-            for i in range(len(raw_units)):
-                for j in range(len(raw_units[i])):
+            for i in range(len(dropped_units)):
+                for j in range(len(dropped_units[i])):
                     if mask_feat[i,j]:
                         features[i,j,:] = 0
 
diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py
index 1f60987126..8ce08775db 100644
--- a/stanza/models/tokenization/model.py
+++ b/stanza/models/tokenization/model.py
@@ -1,12 +1,63 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
+from itertools import tee
+
+from stanza.models.common.seq2seq_constant import PAD, UNK, UNK_ID
+
+class SentenceAnalyzer(nn.Module):
+    def __init__(self, args, pretrain, hidden_dim, device=None, dropout=0):
+        super().__init__()
+
+        assert pretrain is not None, "2nd pass sentence analyzer is missing pretrain word vectors"
+
+        self.args = args
+        self.vocab = pretrain.vocab
+        self.embeddings = nn.Embedding.from_pretrained(
+            torch.from_numpy(pretrain.emb), freeze=True)
+
+        self.emb_proj = nn.Linear(pretrain.emb.shape[1], hidden_dim)
+        self.lstm = nn.LSTM(hidden_dim*3, hidden_dim, bidirectional=True,
+                            batch_first=True, num_layers=args['rnn_layers'])
+
+        self.dropout = nn.Dropout(dropout)
+
+        self.hidden = hidden_dim
+
+        # this is zero-initialized so that the second pass is initially the
+        # identity function; it then changes only as needed during training
+        # and otherwise stays at zero
+        self.final_proj = nn.Parameter(torch.zeros(hidden_dim*2, 1), requires_grad=True)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(self, words, tok_embeds, word_tok_mapping, padding_mask):
+        # map the words to pretrain IDs
+        token_ids = [[self.vocab[j.strip()] for j in i] for i in words]
+        embs = self.embeddings(torch.tensor(token_ids, device=self.device))
+        net = self.emb_proj(embs)
+        # we now want to concatenate the token embeddings with the word embeddings
+        final_inp = torch.zeros(tok_embeds.size(0), tok_embeds.size(1), self.hidden*3).to(tok_embeds.device)
+        final_inp[:,:,:tok_embeds.size(2)] = tok_embeds
+        # we want to set the index slots relevant to the word token embeddings
+        # to True, and everything else to False (including the slots for tok_embeds)
+        final_inp_second_idx = word_tok_mapping.unsqueeze(-1).repeat(1,1,self.hidden*3)
+        final_inp_second_idx[:,:,:tok_embeds.size(2)] = False
+        final_inp[final_inp_second_idx] = net[padding_mask].view(-1)
+
+        net = self.lstm(self.dropout(final_inp))[0]
+        return net @ self.final_proj
+
 
 class Tokenizer(nn.Module):
-    def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
+    def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, pretrain=None):
         super().__init__()
 
         self.args = args
+        self.pretrain = pretrain
         feat_dim = args['feat_dim']
 
         self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0)
@@ -36,12 +87,18 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
         if self.args['use_mwt']:
             self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)
 
+        if args['sentence_second_pass']:
+            self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim, dropout=dropout)
+            # initially, don't use the 2nd pass much (the sigmoid of this is near 0,
+            # meaning the second pass will pretty much not be mixed in)
+            self.sent_2nd_mix = nn.Parameter(torch.full((1,), -5.0), requires_grad=True)
+
         self.dropout = nn.Dropout(dropout)
         self.dropout_feat = nn.Dropout(feat_dropout)
 
         self.toknoise = nn.Dropout(self.args['tok_noise'])
 
-    def forward(self, x, feats):
+    def forward(self, x, feats, text, detach_2nd_pass=False):
         emb = self.embeddings(x)
         emb = self.dropout(emb)
         feats = self.dropout_feat(feats)
@@ -87,12 +144,80 @@ def forward(self, x, feats):
         nontok = F.logsigmoid(-tok0)
         tok = F.logsigmoid(tok0)
-        nonsent = F.logsigmoid(-sent0)
-        sent = F.logsigmoid(sent0)
         if self.args['use_mwt']:
             nonmwt = F.logsigmoid(-mwt0)
             mwt = F.logsigmoid(mwt0)
 
+        nonsent = F.logsigmoid(-sent0)
+        sent = F.logsigmoid(sent0)
+
+        # use the rough predictions from the char tokenizer to create word tokens,
+        # then use those word tokens + contextual/fixed word embeddings to refine
+        # the sentence predictions
+        if self.args["sentence_second_pass"]:
+            # these are the draft predictions for only the token-level decisions,
+            # which we can use to slice the text
+            if self.args['use_mwt']:
+                draft_pred_locs = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2).argmax(dim=2)
+            else:
+                draft_pred_locs = torch.cat([nontok, tok+nonsent, tok+sent], 2).argmax(dim=2)
+
+            draft_pred_locs = (draft_pred_locs > 0)
+            # these boolean indices are *inclusive*, so whether it's predicted
+            # or not we need to split on the last token if we want to keep the
+            # final word
+            draft_pred_locs[:,-1] = True
+
+            # both: batch x [variable: text token count]
+            extracted_tokens = []
+            partial = []
+            last = 0
+            last_batch = -1
+
+            nonzero = draft_pred_locs.nonzero().cpu().tolist()
+            for i,j in nonzero:
+                if i != last_batch:
+                    last_batch = i
+                    last = 0
+                    if i != 0:
+                        extracted_tokens.append(partial)
+                        partial = []
+
+                substring = text[i][last:j+1]
+                last = j+1
+
+                partial.append("".join(substring))
+            extracted_tokens.append(partial)
+
+            # dynamically pad the batch tokens to size
+            # why pad to at least a fixed size? it must be wider
+            # than our kernel
+            max_size = max(max([len(i) for i in extracted_tokens]),
+                           self.args["sentence_analyzer_kernel"])
+            batch_tokens_padded = []
+            batch_tokens_isntpad = []
+            for i in extracted_tokens:
+                batch_tokens_padded.append(i + [PAD for _ in range(max_size-len(i))])
+                batch_tokens_isntpad.append([True for _ in range(len(i))] +
+                                            [False for _ in range(max_size-len(i))])
+            pad_mask = torch.tensor(batch_tokens_isntpad, device=inp.device)
+
+            # pass the aligned result to the second pass classifier
+            second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded, inp, draft_pred_locs, pad_mask)
+
+            mix = torch.sigmoid(self.sent_2nd_mix)
+
+            # update the sent0 value
+            if detach_2nd_pass:
+                sent0 = (1-mix.detach())*sent0 + mix.detach()*second_pass_scores.detach()
+            else:
+                sent0 = (1-mix)*sent0 + mix*second_pass_scores
+
+            nonsent = F.logsigmoid(-sent0)
+            sent = F.logsigmoid(sent0)
+
         if self.args['use_mwt']:
             pred = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2)
         else:
             pred = torch.cat([nontok, tok+nonsent, tok+sent], 2)
diff --git a/stanza/models/tokenization/trainer.py b/stanza/models/tokenization/trainer.py
index 254f419f92..6d892a5a5c 100644
--- a/stanza/models/tokenization/trainer.py
+++ b/stanza/models/tokenization/trainer.py
@@ -14,7 +14,8 @@
 logger = logging.getLogger('stanza')
 
 class Trainer(BaseTrainer):
-    def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None):
+    def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None, pretrain=None):
+        self.pretrain = pretrain
         if model_file is not None:
             # load everything from file
             self.load(model_file)
@@ -24,23 +25,36 @@ def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_f
             self.vocab = vocab
             self.lexicon = lexicon
             self.dictionary = dictionary
-            self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
+            self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'], pretrain=pretrain)
+
+            if self.args["sentence_second_pass"]:
+                assert bool(pretrain), "context-aware sentence analysis requires pretrained word vectors; download them!"
+
         self.model = self.model.to(device)
         self.criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
         self.optimizer = utils.get_optimizer("adam", self.model, lr=self.args['lr0'], betas=(.9, .9), weight_decay=self.args['weight_decay'])
         self.feat_funcs = self.args.get('feat_funcs', None)
         self.lang = self.args['lang'] # language determines how token normalization is done
+        self.pretrain = pretrain
+        self.global_step_counter_ = 0
+        self.train_2nd_pass = False
+
+    @property
+    def steps(self):
+        return self.global_step_counter_
 
     def update(self, inputs):
+        self.global_step_counter_ += 1
         self.model.train()
-        units, labels, features, _ = inputs
+        units, labels, features, text = inputs
 
         device = next(self.model.parameters()).device
         units = units.to(device)
         labels = labels.to(device)
         features = features.to(device)
 
-        pred = self.model(units, features)
+        # detach the 2nd pass if we are not yet training the second pass
+        pred = self.model(units, features, text, not self.train_2nd_pass)
 
         self.optimizer.zero_grad()
         classes = pred.size(2)
@@ -54,13 +68,13 @@ def update(self, inputs):
 
     def predict(self, inputs):
         self.model.eval()
-        units, _, features, _ = inputs
+        units, _, features, text = inputs
 
         device = next(self.model.parameters()).device
         units = units.to(device)
         features = features.to(device)
 
-        pred = self.model(units, features)
+        pred = self.model(units, features, text)
 
         return pred.data.cpu().numpy()
 
@@ -69,7 +83,8 @@ def save(self, filename):
                 'model': self.model.state_dict() if self.model is not None else None,
                 'vocab': self.vocab.state_dict(),
                 'lexicon': self.lexicon,
-                'config': self.args
+                'config': self.args,
+                'steps': self.global_step_counter_
                 }
         try:
             torch.save(params, filename, _use_new_zipfile_serialization=False)
@@ -88,11 +103,13 @@ def load(self, filename):
             # Default to True as many currently saved models
             # were built with mwt layers
             self.args['use_mwt'] = True
-        self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
+        self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'], pretrain=self.pretrain)
         self.model.load_state_dict(checkpoint['model'])
         self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
         self.lexicon = checkpoint['lexicon']
+        self.global_step_counter_ = checkpoint.get("steps", 0)
+
         if self.lexicon is not None:
             self.dictionary = create_dictionary(self.lexicon)
         else:
diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py
index 8e78798b2c..badf718439 100644
--- a/stanza/models/tokenizer.py
+++ b/stanza/models/tokenizer.py
@@ -27,6 +27,8 @@
 from stanza.models.tokenization.trainer import Trainer
 from stanza.models.tokenization.data import DataLoader, TokenizationDataset
 from stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions, load_lexicon, create_dictionary
+from stanza.models.common import pretrain
+
 from stanza.models import _training_logging
 
 logger = logging.getLogger('stanza')
@@ -46,6 +48,13 @@ def build_argparse():
     parser.add_argument('--lang', type=str, help="Language")
     parser.add_argument('--shorthand', type=str, help="UD treebank shorthand")
 
+    parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors.')
+    parser.add_argument('--wordvec_file', type=str, default=None, help='Word vectors filename.')
+    parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
+    parser.add_argument('--pretrain_max_vocab', type=int, default=250000)
+
+    parser.add_argument('--sentence_analyzer_kernel', type=int, default=4)
+
     parser.add_argument('--mode', default='train', choices=['train', 'predict'])
     parser.add_argument('--skip_newline', action='store_true', help="Whether to skip newline characters in input. Particularly useful for languages like Chinese.")
@@ -54,6 +63,8 @@ def build_argparse():
     parser.add_argument('--conv_filters', type=str, default="1,9", help="Configuration of conv filters. ,, separates layers and , separates filter sizes in the same layer.")
     parser.add_argument('--no-residual', dest='residual', action='store_false', help="Add linear residual connections")
     parser.add_argument('--no-hierarchical', dest='hierarchical', action='store_false', help="\"Hierarchical\" RNN tokenizer")
+    parser.add_argument('--no_sentence_second_pass', dest='sentence_second_pass', action='store_false', help="Don't refine sentence predictions with a second pass; predict sentences together with tokens instead")
+    parser.add_argument('--second_pass_start_steps', type=int, help="How many steps to wait before starting to train the second pass classifier", default=5000)
     parser.add_argument('--hier_invtemp', type=float, default=0.5, help="Inverse temperature used in propagating tokenization predictions between RNN layers")
     parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well")
     parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN")
@@ -113,6 +124,17 @@ def model_file_name(args):
         return save_name
     return os.path.join(args['save_dir'], save_name)
 
+def load_pretrain(args):
+    pt = None
+    if args['sentence_second_pass']:
+        pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])
+        if os.path.exists(pretrain_file):
+            vec_file = None
+        else:
+            vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
+        pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])
+    return pt
+
 def main(args=None):
     args = parse_args(args=args)
 
@@ -164,7 +186,9 @@ def train(args):
         args['use_mwt'] = train_batches.has_mwt()
         logger.info("Found {}mwts in the training data.  Setting use_mwt to {}".format(("" if args['use_mwt'] else "no "), args['use_mwt']))
 
-    trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device'])
+    # load pretrained word vectors if needed
+    pretrain = load_pretrain(args)
+    trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device'], pretrain=pretrain)
 
     if args['load_name'] is not None:
         load_name = os.path.join(args['save_dir'], args['load_name'])
@@ -190,7 +214,11 @@ def train(args):
     for step in range(1, steps+1):
         batch = train_batches.next(unit_dropout=args['unit_dropout'], feat_unit_dropout = args['feat_unit_dropout'])
 
+        if trainer.steps > args["second_pass_start_steps"]:
+            trainer.train_2nd_pass = True
+
         loss = trainer.update(batch)
+
         if step % args['report_steps'] == 0:
             logger.info("Step {:6d}/{:6d} Loss: {:.3f}".format(step, steps, loss))
             if args['wandb']:
@@ -234,7 +262,8 @@ def train(args):
 
 def evaluate(args):
     mwt_dict = load_mwt_dict(args['mwt_json_file'])
-    trainer = Trainer(model_file=args['load_name'] or args['save_name'], device=args['device'])
+    pretrain = load_pretrain(args)
+    trainer = Trainer(model_file=args['load_name'] or args['save_name'], device=args['device'], pretrain=pretrain)
     loaded_args, vocab = trainer.args, trainer.vocab
 
     for k in loaded_args:
diff --git a/stanza/pipeline/tokenize_processor.py b/stanza/pipeline/tokenize_processor.py
index f2fc242db2..92ae8f3ad9 100644
--- a/stanza/pipeline/tokenize_processor.py
+++ b/stanza/pipeline/tokenize_processor.py
@@ -37,11 +37,14 @@ class TokenizeProcessor(UDProcessor):
     MAX_SEQ_LENGTH_DEFAULT = 1000
 
     def _set_up_model(self, config, pipeline, device):
+        # get pretrained word vectors
+        self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None
+
         # set up trainer
         if config.get('pretokenized'):
            self._trainer = None
        else:
-            self._trainer = Trainer(model_file=config['model_path'], device=device)
+            self._trainer = Trainer(model_file=config['model_path'], device=device, pretrain=self._pretrain)
 
         # get and typecheck the postprocessor
         postprocessor = config.get('postprocessor')
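
Note (not part of the patch): a minimal standalone sketch of the gated mixing used in Tokenizer.forward above. sent_2nd_mix starts at -5.0, so sigmoid(-5.0) is roughly 0.0067 and the second-pass scores barely perturb the first-pass sentence logits until the gate learns to open. The tensor shapes below are illustrative assumptions, not taken from the patch.

    import torch

    # first-pass sentence logits and stand-in second-pass scores (batch x seq x 1)
    sent0 = torch.randn(2, 8, 1)
    second_pass_scores = torch.randn(2, 8, 1)

    # the gate parameter, initialized the same way as in the patch
    sent_2nd_mix = torch.nn.Parameter(torch.full((1,), -5.0))
    mix = torch.sigmoid(sent_2nd_mix)   # ~0.0067 at initialization

    # convex combination of the two score sources; near-identity at first
    mixed = (1 - mix) * sent0 + mix * second_pass_scores
    print(mix.item(), (mixed - sent0).abs().max().item())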