evaluation.py — forked from mmatena/model_merging (34 lines, 28 loc, 1.16 KB)
"""Scripts for evaluation of models."""
import datasets as hfds
import tensorflow as tf
def load_metric_for_glue_task(task: str):
    """Load the HF ``glue`` metric for *task*.

    Dashed task names (``sst-2``, ``sts-b``) are normalized to the
    un-dashed config names that the ``glue`` metric expects.
    """
    # Alias table: dashed spellings -> canonical glue config names.
    aliases = {'sst-2': 'sst2', 'sts-b': 'stsb'}
    return hfds.load_metric("glue", aliases.get(task, task))
def evaluate_model(model, dataset: tf.data.Dataset, metric: hfds.Metric, mergeable_models=None):
    """Evaluate a classification model over a dataset and compute a metric.

    Args:
        model: Callable whose output has a ``.logits`` attribute (e.g. a HF
            TF sequence-classification model).
        dataset: Batched ``tf.data.Dataset`` yielding ``(inputs, labels)``.
        metric: HF metric accumulating batches via ``add_batch``.
        mergeable_models: Unused; kept (now optional) for backward
            compatibility with existing callers.

    Returns:
        The dict produced by ``metric.compute()``.
    """
    for model_input, gold_references in dataset:
        # Predicted class = argmax over logits.
        # NOTE(review): argmax assumes a classification task — for a
        # regression task like sts-b this would be wrong; confirm callers.
        logits = model(model_input).logits
        model_predictions = tf.argmax(logits, axis=-1)
        metric.add_batch(predictions=model_predictions, references=gold_references)
    return metric.compute()
def average_score(score):
    """Return the arithmetic mean of the values in a metric-score mapping.

    Args:
        score: Non-empty mapping from metric name to numeric value
            (e.g. the dict returned by ``metric.compute()``).

    Returns:
        The mean of ``score``'s values as a float.

    Raises:
        ValueError: If ``score`` is empty (previously an opaque
            ``ZeroDivisionError``).
    """
    if not score:
        raise ValueError("score must contain at least one metric value")
    # len(score) equals len(score.values()); no need to materialize the view.
    return sum(score.values()) / len(score)