train.py
import os
import tensorflow as tf
from data_utils import load_data
import transformers
import pickle
import argparse
from bert_models.tf_distilbert_for_ordinal_regression import TFDistilBertForOrdinalRegression
from bert_models.tf_distilbert_for_classification import TFDistilBertForClassification
from metrics import *
from math import ceil
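
# Fine-tunes DistilBERT to predict star ratings, either as 5-way
# classification or as ordinal regression (see the --ordinal flag below).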
parser = argparse.ArgumentParser()
parser.add_argument('--version', '-v', type=str, default='final',
                    help='Version of the model, used for naming the saved files.')
parser.add_argument('--pretrain', '-p', action='store_true',
                    help='Whether to use the model pretrained on the corpus.')
parser.add_argument('--ordinal', '-o', action='store_true',
                    help='Whether to use ordinal regression instead of classification.')
parser.add_argument('--as_features', '-f', action='store_true',
                    help='Whether to freeze the BERT layers and use them only as features instead of fine-tuning.')
parser.add_argument('--loss_weights', '-lw', nargs='+', default=[1, 1, 1, 1, 1],
                    help='Loss weights for each possible star label.')
parser.add_argument('--disable_layer_norm', '-dn', action='store_true',
                    help='Whether to disable layer normalization before the classification output.')
parser.add_argument('--epochs', '-e', type=int, default=4, help='Number of epochs to train for.')
parser.add_argument('--batch_size', '-b', type=int, default=16,
                    help='Per-replica batch size to use for training.')
args = parser.parse_args()
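
# Replicate the model across all available GPUs; each replica processes its
# own slice of the batch and gradients are aggregated every step.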
strategy = tf.distribute.MirroredStrategy()

def make_weighted_loss(loss_type, weights, ord=False):
    """
    Returns a weighted version of `loss_type`, scaling each example's loss by
    the weight of its true label.

    Parameters
    ==========
    loss_type : Class[tf.keras.losses.Loss]
        The class of the original loss function.
    weights : list
        A list of weights, one for each possible label.
    ord : bool
        If true, assumes that the labels are ordinal labels, so the class
        index is recovered by summing each label vector.
    """
    weights = tf.constant(weights)
    # Keep per-example losses unreduced so each one can be reweighted.
    loss_fn = loss_type(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def loss(y_true, y_pred):
        # Ordinal labels are multi-hot, so their row sum gives the class index.
        indices = tf.cast(tf.reduce_sum(y_true, axis=1) if ord else y_true, tf.int32)
        losses = loss_fn(y_true, y_pred)
        weighted_losses = tf.reshape(tf.gather(weights, indices), [-1]) * losses
        return tf.reduce_mean(weighted_losses)

    return loss
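
# Worked example (hypothetical values): with weights [1, 2, 1, 1, 1] and sparse
# labels [0, 1] in a batch, the two per-example losses are scaled by 1.0 and
# 2.0 respectively before being averaged.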
assert len(args.loss_weights) == 5, "Must have exactly 5 loss weights (one for each star rating)"
loss_weights = [float(i) for i in args.loss_weights]
weighted_loss = (loss_weights != [1, 1, 1, 1, 1])
SparseCategoricalCrossentropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
WeightedSparseCategoricalCrossentropy = make_weighted_loss(tf.keras.losses.SparseCategoricalCrossentropy, loss_weights, ord=False)
BinaryCrossentropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
WeightedBinaryCrossentropy = make_weighted_loss(tf.keras.losses.BinaryCrossentropy, loss_weights, ord=True)
if weighted_loss:
print(f"Using loss weights {loss_weights}")
if args.ordinal:
    model_type = TFDistilBertForOrdinalRegression
    # Use the weighted loss variant when non-uniform weights were supplied.
    loss = WeightedBinaryCrossentropy if weighted_loss else BinaryCrossentropy
    num_labels = 4
    metrics = [ord_pred_accuracy, ord_pred_abs_error]
else:
    model_type = TFDistilBertForClassification
    loss = WeightedSparseCategoricalCrossentropy if weighted_loss else SparseCategoricalCrossentropy
    num_labels = 5
    metrics = ['accuracy', pred_abs_error]
if args.pretrain:
    config = transformers.DistilBertConfig.from_pretrained('pretrained/', num_labels=num_labels)
else:
    config = transformers.DistilBertConfig.from_pretrained('distilbert-base-uncased', num_labels=num_labels)

with strategy.scope():
    if args.pretrain:
        # from_pt=True converts the PyTorch checkpoint saved by pretraining into TF weights.
        model = model_type.from_pretrained('pretrained/', config=config, as_features=args.as_features,
                                           use_layer_norm=not args.disable_layer_norm, from_pt=True)
    else:
        model = model_type.from_pretrained('distilbert-base-uncased', config=config, as_features=args.as_features,
                                           use_layer_norm=not args.disable_layer_norm)
    model.compile(optimizer='adam', loss=loss, metrics=metrics)
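
# The effective global batch size scales with the number of replicas, so each
# GPU still processes args.batch_size examples per step.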
batch_size_per_replica = args.batch_size
batch_size = batch_size_per_replica * strategy.num_replicas_in_sync
train_dataset = load_data(split='train', ordinal=args.ordinal, batch_size=batch_size)
valid_dataset = load_data(split='valid', ordinal=args.ordinal, batch_size=batch_size)
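
# Learning rate schedule: linear warmup over the first 10% of training steps,
# then linear decay to zero, as is common for BERT fine-tuning.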
steps_per_epoch = ceil(747010 / batch_size)  # the training split has 747,010 examples
num_training_steps = steps_per_epoch * args.epochs
num_warmup_steps = num_training_steps // 10
base_learning_rate = 2e-5
current_epoch = 0

def learning_rate_schedule(batch, logs):
    # Convert the per-epoch batch index into a global step count.
    step = batch + (current_epoch * steps_per_epoch)
    if step < num_warmup_steps:
        new_lr = base_learning_rate * float(step + 1) / float(max(1, num_warmup_steps))
    else:
        new_lr = base_learning_rate * max(0.0, float(num_training_steps - step) /
                                          float(max(1, num_training_steps - num_warmup_steps)))
    tf.keras.backend.set_value(model.optimizer.lr, new_lr)

lr_callback = tf.keras.callbacks.LambdaCallback(on_batch_begin=learning_rate_schedule)

def record_epoch(epoch, logs):
    # Track the current epoch so learning_rate_schedule can compute the global step.
    global current_epoch
    current_epoch = epoch

epoch_callback = tf.keras.callbacks.LambdaCallback(on_epoch_begin=record_epoch)
checkpoint_dir = os.path.join('checkpoints', args.version)
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, save_weights_only=True)
history = model.fit(train_dataset, epochs=args.epochs,
                    callbacks=[epoch_callback, checkpoint_callback, lr_callback],
                    validation_data=valid_dataset)

save_dir = os.path.join('models', args.version)
os.makedirs(save_dir, exist_ok=True)  # also creates the 'models' parent directory if needed
model.save_pretrained(save_dir)

train_history_dir = os.path.join('train_history', args.version)
os.makedirs(train_history_dir, exist_ok=True)
with open(os.path.join(train_history_dir, 'train_history.pickle'), 'wb') as f:
    pickle.dump(history.history, f)
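
# Example invocation (the loss-weight values are illustrative only):
#   python train.py -v final -p -o -e 4 -b 16 -lw 1.0 1.5 2.0 1.5 1.0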