Contextual tokenizer #1415

Open · wants to merge 20 commits into dev (changes from all commits)
17 changes: 11 additions & 6 deletions stanza/models/tokenization/data.py
@@ -1,5 +1,5 @@
from bisect import bisect_right
from copy import copy
from copy import copy, deepcopy
import numpy as np
import random
import logging
@@ -355,15 +355,20 @@ def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
features[i, :len(f_), :] = f_
raw_units.append(r_ + ['<PAD>'] * (pad_len - len(r_)))

# so we always return text that has not been <UNK>ed, but we will return
# IDs with <UNK>s in them. REVIEW: check whether the raw text is used
# anywhere else such that the lack of UNKs would cause a problem
dropped_units = deepcopy(raw_units)

if unit_dropout > 0 and not self.eval:
# dropout characters/units at training time and replace them with UNKs
mask = np.random.random_sample(units.shape) < unit_dropout
mask[units == padid] = 0
units[mask] = unkid
for i in range(len(raw_units)):
for j in range(len(raw_units[i])):
for i in range(len(dropped_units)):
for j in range(len(dropped_units[i])):
if mask[i, j]:
raw_units[i][j] = '<UNK>'
dropped_units[i][j] = '<UNK>'

# dropout unit feature vector in addition to only torch.dropout in the model.
# experiments showed that only torch.dropout hurts the model
@@ -372,8 +377,8 @@ def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']):
if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval:
mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout
mask_feat[units == padid] = 0
for i in range(len(raw_units)):
for j in range(len(raw_units[i])):
for i in range(len(dropped_units)):
for j in range(len(dropped_units[i])):
if mask_feat[i,j]:
features[i,j,:] = 0

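To illustrate the data.py change above: unit dropout now <UNK>s a deep copy of the raw units, so the untouched raw text can still be handed downstream to the second-pass analyzer. A minimal standalone sketch of that behavior, with toy names (`drop_units` is not the actual DataLoader method):

```python
import numpy as np
from copy import deepcopy

def drop_units(units, raw_units, unkid, padid, unit_dropout=0.33):
    """Toy version of the dropout step: UNK the id matrix and a *copy* of
    the raw strings, leaving the original raw_units untouched.

    units: np.ndarray of unit ids (batch x seq); raw_units: list of lists of strings.
    """
    dropped_units = deepcopy(raw_units)
    mask = np.random.random_sample(units.shape) < unit_dropout
    mask[units == padid] = 0          # never drop padding positions
    units[mask] = unkid               # ids are replaced in place
    for i in range(len(dropped_units)):
        for j in range(len(dropped_units[i])):
            if mask[i, j]:
                dropped_units[i][j] = '<UNK>'
    return units, dropped_units       # raw_units is still the clean text
```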
133 changes: 129 additions & 4 deletions stanza/models/tokenization/model.py
@@ -1,12 +1,63 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from itertools import tee

from stanza.models.common.seq2seq_constant import PAD, UNK, UNK_ID

class SentenceAnalyzer(nn.Module):
def __init__(self, args, pretrain, hidden_dim, device=None, dropout=0):
super().__init__()

assert pretrain is not None, "2nd pass sentence analyzer is missing pretrained word vectors"

self.args = args
self.vocab = pretrain.vocab
self.embeddings = nn.Embedding.from_pretrained(
torch.from_numpy(pretrain.emb), freeze=True)

self.emb_proj = nn.Linear(pretrain.emb.shape[1], hidden_dim)
self.lstm = nn.LSTM(hidden_dim*3, hidden_dim, bidirectional=True,
batch_first=True, num_layers=args['rnn_layers'])

self.dropout = nn.Dropout(dropout)

self.hidden = hidden_dim

# this is zero-initialized so that the second pass is initially the identity
# function; it then changes only as needed during training and otherwise
# stays zero
self.final_proj = nn.Parameter(torch.zeros(hidden_dim*2, 1), requires_grad=True)

@property
def device(self):
return next(self.parameters()).device

def forward(self, words, tok_embeds, word_tok_mapping, padding_mask):
# map the vocab to pretrain IDs
token_ids = [[self.vocab[j.strip()] for j in i] for i in words]
embs = self.embeddings(torch.tensor(token_ids, device=self.device))
net = self.emb_proj(embs)
# we now want to concatenate the token embeddings with the word embeddings
final_inp = torch.zeros(tok_embeds.size(0), tok_embeds.size(1),
self.hidden*3).to(tok_embeds.device)
final_inp[:,:,:tok_embeds.size(2)] = tok_embeds
# we want to set the positions relevant to the word token embeddings
# to True, and everything else to False (including the slots for tok_embeds)
final_inp_second_idx = word_tok_mapping.unsqueeze(-1).repeat(1,1,self.hidden*3)
final_inp_second_idx[:,:,:tok_embeds.size(2)] = False
final_inp[final_inp_second_idx] = net[padding_mask].view(-1)

net = self.lstm(self.dropout(final_inp))[0]
return net @ self.final_proj


class Tokenizer(nn.Module):
def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, pretrain=None):
super().__init__()

self.args = args
self.pretrain = pretrain
feat_dim = args['feat_dim']

self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0)
@@ -36,12 +87,18 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout):
if self.args['use_mwt']:
self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False)

if args['sentence_second_pass']:
self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim, dropout)
# initially, don't use the 2nd pass much (the sigmoid of this is near 0, meaning it will
# mostly not be mixed in)
self.sent_2nd_mix = nn.Parameter(torch.full((1,), -5.0), requires_grad=True)

self.dropout = nn.Dropout(dropout)
self.dropout_feat = nn.Dropout(feat_dropout)

self.toknoise = nn.Dropout(self.args['tok_noise'])

def forward(self, x, feats):
def forward(self, x, feats, text, detach_2nd_pass=False):
emb = self.embeddings(x)
emb = self.dropout(emb)
feats = self.dropout_feat(feats)
@@ -87,12 +144,80 @@ def forward(self, x, feats):

nontok = F.logsigmoid(-tok0)
tok = F.logsigmoid(tok0)
nonsent = F.logsigmoid(-sent0)
sent = F.logsigmoid(sent0)
if self.args['use_mwt']:
nonmwt = F.logsigmoid(-mwt0)
mwt = F.logsigmoid(mwt0)

nonsent = F.logsigmoid(-sent0)
sent = F.logsigmoid(sent0)

# use the rough predictions from the char tokenizer to create word tokens
# then use those word tokens + contextual/fixed word embeddings to refine
# sentence predictions

if self.args["sentence_second_pass"]:
# these are the draft predictions for token-level decisions only,
# which we can use to slice the text
if self.args['use_mwt']:
draft_pred_locs = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2).argmax(dim=2)
else:
draft_pred_locs = torch.cat([nontok, tok+nonsent, tok+sent], 2).argmax(dim=2)

draft_pred_locs = (draft_pred_locs > 0)
# these boolean indices are *inclusive*, so, predicted or not,
# we need to split on the last token if we want to keep the
# final word
draft_pred_locs[:,-1] = True

# both: batch x [variable: text token count]
extracted_tokens = []
partial = []
last = 0
last_batch = -1

nonzero = draft_pred_locs.nonzero().cpu().tolist()
for i,j in nonzero:
if i != last_batch:
last_batch = i
last = 0
if i != 0:
extracted_tokens.append(partial)
partial = []

substring = text[i][last:j+1]
last = j+1

partial.append("".join(substring))
extracted_tokens.append(partial)

# dynamically pad the batch tokens to size
# why pad to at least a fixed size? it must be at least as wide
# as our kernel
max_size = max(max([len(i) for i in extracted_tokens]),
self.args["sentence_analyzer_kernel"])
batch_tokens_padded = []
batch_tokens_isntpad = []
for i in extracted_tokens:
batch_tokens_padded.append(i + [PAD for _ in range(max_size-len(i))])
batch_tokens_isntpad.append([True for _ in range(len(i))] +
[False for _ in range(max_size-len(i))])
pad_mask = torch.tensor(batch_tokens_isntpad)


# pass the aligned result to the second pass classifier
second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded, inp, draft_pred_locs, pad_mask)

mix = F.sigmoid(self.sent_2nd_mix)

# update sent0 value
if detach_2nd_pass:
sent0 = (1-mix.detach())*sent0 + mix.detach()*second_pass_scores.detach()
else:
sent0 = (1-mix)*sent0 + mix*second_pass_scores

nonsent = F.logsigmoid(-sent0)
sent = F.logsigmoid(sent0)

if self.args['use_mwt']:
pred = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2)
else:
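To make the new forward-pass logic in model.py easier to follow, here is a hedged, self-contained sketch of the two added steps: slicing the input text into draft word tokens using the boolean boundary mask, and blending the second-pass sentence scores through a sigmoid gate that starts near zero. Names, shapes, and the `extract_draft_tokens` helper are illustrative, not the module's real interface:

```python
import torch

PAD = '<PAD>'  # assumed pad symbol, mirroring seq2seq_constant.PAD

def extract_draft_tokens(text, draft_pred_locs, min_width=4):
    """text: batch of character lists; draft_pred_locs: bool tensor (batch x seq)
    marking *inclusive* token end positions. Returns padded token strings + pad mask."""
    draft_pred_locs[:, -1] = True          # always close the final token
    extracted = []
    for chars, ends in zip(text, draft_pred_locs):
        tokens, last = [], 0
        for j in ends.nonzero().flatten().tolist():
            tokens.append("".join(chars[last:j + 1]))
            last = j + 1
        extracted.append(tokens)

    # pad to at least the kernel width of the sentence analyzer
    max_size = max(max(len(t) for t in extracted), min_width)
    padded = [t + [PAD] * (max_size - len(t)) for t in extracted]
    pad_mask = torch.tensor([[True] * len(t) + [False] * (max_size - len(t))
                             for t in extracted])
    return padded, pad_mask

# gate mixing: sigmoid(-5) ~= 0.0067, so the second pass starts almost silent
# and only takes over as training pushes the gate parameter up
mix = torch.sigmoid(torch.full((1,), -5.0))
# sent0 = (1 - mix) * sent0 + mix * second_pass_scores
```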
33 changes: 25 additions & 8 deletions stanza/models/tokenization/trainer.py
@@ -14,7 +14,8 @@
logger = logging.getLogger('stanza')

class Trainer(BaseTrainer):
def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None):
def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None, pretrain=None):
self.pretrain = pretrain
if model_file is not None:
# load everything from file
self.load(model_file)
@@ -24,23 +25,36 @@ def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_f
self.vocab = vocab
self.lexicon = lexicon
self.dictionary = dictionary
self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'], pretrain=pretrain)

if self.args["sentence_second_pass"]:
assert bool(pretrain), "context-aware sentence analysis requires pretrained wordvectors; download them!"

self.model = self.model.to(device)
self.criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device)
self.optimizer = utils.get_optimizer("adam", self.model, lr=self.args['lr0'], betas=(.9, .9), weight_decay=self.args['weight_decay'])
self.feat_funcs = self.args.get('feat_funcs', None)
self.lang = self.args['lang'] # language determines how token normalization is done
self.pretrain = pretrain
self.global_step_counter_ = 0
self.train_2nd_pass = False

@property
def steps(self):
return self.global_step_counter_

def update(self, inputs):
self.global_step_counter_ += 1
self.model.train()
units, labels, features, _ = inputs
units, labels, features, text = inputs

device = next(self.model.parameters()).device
units = units.to(device)
labels = labels.to(device)
features = features.to(device)

pred = self.model(units, features)
# we detach the 2nd pass if we are not yet training the second pass
pred = self.model(units, features, text, not self.train_2nd_pass)

self.optimizer.zero_grad()
classes = pred.size(2)
@@ -54,13 +68,13 @@ def update(self, inputs):

def predict(self, inputs):
self.model.eval()
units, _, features, _ = inputs
units, _, features, text = inputs

device = next(self.model.parameters()).device
units = units.to(device)
features = features.to(device)

pred = self.model(units, features)
pred = self.model(units, features, text)

return pred.data.cpu().numpy()

@@ -69,7 +83,8 @@ def save(self, filename):
'model': self.model.state_dict() if self.model is not None else None,
'vocab': self.vocab.state_dict(),
'lexicon': self.lexicon,
'config': self.args
'config': self.args,
'steps': self.global_step_counter_
}
try:
torch.save(params, filename, _use_new_zipfile_serialization=False)
@@ -88,11 +103,13 @@ def load(self, filename):
# Default to True as many currently saved models
# were built with mwt layers
self.args['use_mwt'] = True
self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'])
self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'], pretrain=self.pretrain)
self.model.load_state_dict(checkpoint['model'])
self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
self.lexicon = checkpoint['lexicon']

self.global_step_counter_ = checkpoint.get("steps", 0)

if self.lexicon is not None:
self.dictionary = create_dictionary(self.lexicon)
else:
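One detail worth calling out in the trainer change: the global step counter is now checkpointed, so the second-pass warm-up schedule survives a save/load cycle, and `checkpoint.get("steps", 0)` keeps older checkpoints (which lack the key) loadable. A minimal sketch of that pattern, with hypothetical helper names and file paths:

```python
import torch

def save_checkpoint(model, step, filename="tokenizer.pt"):
    params = {
        'model': model.state_dict(),
        'steps': step,                       # new: persist the warm-up progress
    }
    torch.save(params, filename)

def load_checkpoint(model, filename="tokenizer.pt"):
    checkpoint = torch.load(filename, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    # older checkpoints have no 'steps' key; default to 0 so they still load
    return checkpoint.get('steps', 0)
```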
33 changes: 31 additions & 2 deletions stanza/models/tokenizer.py
@@ -27,6 +27,8 @@
from stanza.models.tokenization.trainer import Trainer
from stanza.models.tokenization.data import DataLoader, TokenizationDataset
from stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions, load_lexicon, create_dictionary
from stanza.models.common import pretrain

from stanza.models import _training_logging

logger = logging.getLogger('stanza')
@@ -46,6 +48,13 @@ def build_argparse():
parser.add_argument('--lang', type=str, help="Language")
parser.add_argument('--shorthand', type=str, help="UD treebank shorthand")

parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors.')
parser.add_argument('--wordvec_file', type=str, default=None, help='Word vectors filename.')
parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read')
parser.add_argument('--pretrain_max_vocab', type=int, default=250000)

parser.add_argument('--sentence_analyzer_kernel', type=int, default=4)

parser.add_argument('--mode', default='train', choices=['train', 'predict'])
parser.add_argument('--skip_newline', action='store_true', help="Whether to skip newline characters in input. Particularly useful for languages like Chinese.")

@@ -54,6 +63,8 @@
parser.add_argument('--conv_filters', type=str, default="1,9", help="Configuration of conv filters. ,, separates layers and , separates filter sizes in the same layer.")
parser.add_argument('--no-residual', dest='residual', action='store_false', help="Add linear residual connections")
parser.add_argument('--no-hierarchical', dest='hierarchical', action='store_false', help="\"Hierarchical\" RNN tokenizer")
parser.add_argument('--no_sentence_second_pass', dest='sentence_second_pass', action='store_false', help="predict the sentences together with tokens instead of after")
parser.add_argument('--second_pass_start_steps', type=int, help="when (how many steps) to start training the second pass classifier", default=5000)
parser.add_argument('--hier_invtemp', type=float, default=0.5, help="Inverse temperature used in propagating tokenization predictions between RNN layers")
parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well")
parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN")
@@ -113,6 +124,17 @@ def model_file_name(args):
return save_name
return os.path.join(args['save_dir'], save_name)

def load_pretrain(args):
pt = None
if args['sentence_second_pass']:
pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang'])
if os.path.exists(pretrain_file):
vec_file = None
else:
vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand'])
pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab'])
return pt

def main(args=None):
args = parse_args(args=args)

@@ -164,7 +186,9 @@ def train(args):
args['use_mwt'] = train_batches.has_mwt()
logger.info("Found {}mwts in the training data. Setting use_mwt to {}".format(("" if args['use_mwt'] else "no "), args['use_mwt']))

trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device'])
# load pretrained vectors if needed
pretrain = load_pretrain(args)
trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device'], pretrain=pretrain)

if args['load_name'] is not None:
load_name = os.path.join(args['save_dir'], args['load_name'])
@@ -190,7 +214,11 @@ def train(args):
for step in range(1, steps+1):
batch = train_batches.next(unit_dropout=args['unit_dropout'], feat_unit_dropout = args['feat_unit_dropout'])

if trainer.steps > args["second_pass_start_steps"]:
trainer.train_2nd_pass = True

loss = trainer.update(batch)

if step % args['report_steps'] == 0:
logger.info("Step {:6d}/{:6d} Loss: {:.3f}".format(step, steps, loss))
if args['wandb']:
@@ -234,7 +262,8 @@

def evaluate(args):
mwt_dict = load_mwt_dict(args['mwt_json_file'])
trainer = Trainer(model_file=args['load_name'] or args['save_name'], device=args['device'])
pretrain = load_pretrain(args)
trainer = Trainer(model_file=args['load_name'] or args['save_name'], device=args['device'], pretrain=pretrain)
loaded_args, vocab = trainer.args, trainer.vocab

for k in loaded_args:
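For anyone trying the new flags in tokenizer.py: the contextual second pass is on by default (disable it with `--no_sentence_second_pass`) and only needs a pretrain file to be discoverable. Something along the following lines should work; the data paths and treebank shorthand are placeholders, and the `--wordvec_pretrain_file` / `--second_pass_start_steps` flags are the ones added in this PR:

```python
from stanza.models import tokenizer

# hypothetical training invocation with the contextual second pass enabled
tokenizer.main([
    '--txt_file', 'data/tokenize/en_ewt.train.txt',
    '--label_file', 'data/tokenize/en_ewt-ud-train.toklabels',
    '--lang', 'en',
    '--shorthand', 'en_ewt',
    '--wordvec_pretrain_file', 'en/pretrain/ewt.pt',
    '--second_pass_start_steps', '5000',
    '--mode', 'train',
])
```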
5 changes: 4 additions & 1 deletion stanza/pipeline/tokenize_processor.py
@@ -37,11 +37,14 @@ class TokenizeProcessor(UDProcessor):
MAX_SEQ_LENGTH_DEFAULT = 1000

def _set_up_model(self, config, pipeline, device):
# get pretrained word vectors
self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None

# set up trainer
if config.get('pretokenized'):
self._trainer = None
else:
self._trainer = Trainer(model_file=config['model_path'], device=device)
self._trainer = Trainer(model_file=config['model_path'], device=device, pretrain=self.pretrain)

# get and typecheck the postprocessor
postprocessor = config.get('postprocessor')
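Downstream, the tokenize processor now reads the pretrain from its config, so a pipeline run would presumably look like the sketch below. The `tokenize_pretrain_path` key is an assumption based on how processor-prefixed options map onto `config['pretrain_path']`; the model paths are placeholders:

```python
import stanza

# assumed usage: point the tokenize processor at a pretrain file so the
# contextual second pass has word vectors to embed the draft tokens with
nlp = stanza.Pipeline(
    lang='en',
    processors='tokenize',
    tokenize_model_path='saved_models/tokenize/en_ewt_tokenizer.pt',
    tokenize_pretrain_path='saved_models/tokenize/en_ewt.pretrain.pt',
)
doc = nlp("Dr. Smith went to Washington. He arrived Tuesday.")
print([sentence.text for sentence in doc.sentences])
```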