forked from mmatena/model_merging
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hf_util.py
52 lines (38 loc) · 1.42 KB
/
hf_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Utilities for HuggingFace."""
from typing import Tuple, Union
import tensorflow as tf
from transformers import TFBertPreTrainedModel
from transformers import TFRobertaPreTrainedModel
# def get_body_and_head(
# model: Union[TFBertPreTrainedModel, TFRobertaPreTrainedModel]
# ) -> Tuple[tf.keras.layers.Layer, tf.keras.layers.Layer]:
# body, *head = model.layers
# if not head:
# head = None
# elif len(head) > 1:
# raise ValueError(
# f"Expected model to have a single 'head' layer. Instead found {len(head)}. TODO: Support this."
# )
# else:
# head = head[0]
# return body, head
# def get_body(model):
# return get_body_and_head(model)[0]
# def get_mergeable_variables(model):
# return get_body_and_head(model)[0].trainable_variables
# TODO: try this one because I don't think they ever use the head
def get_body_and_head(
    model: Union[TFBertPreTrainedModel, TFRobertaPreTrainedModel]
) -> tf.keras.layers.Layer:
    """Return the first layer of ``model``.

    Despite the name, only a single layer is returned: the first entry of
    ``model.layers``. Any remaining layers are discarded. (An earlier
    version returned a ``(body, head)`` tuple instead.)

    NOTE(review): the first layer is presumably the shared encoder "body"
    of the HuggingFace model -- confirm against the model classes in use.

    Args:
        model: Model whose ``layers`` attribute is inspected.

    Returns:
        The first element of ``model.layers``.

    Raises:
        ValueError: If ``model.layers`` is empty (the starred unpacking
            needs at least one element).
    """
    first_layer, *_remaining = model.layers
    return first_layer
def get_body(model):
    """Return the body layer of ``model``.

    Thin wrapper that forwards to ``get_body_and_head`` and hands back its
    result unchanged.
    """
    body_layer = get_body_and_head(model)
    return body_layer
def get_mergeable_variables(model):
    """Return the trainable variables of the model's body layer.

    The layer is obtained from ``get_body_and_head(model)``; its
    ``trainable_variables`` attribute is returned as-is.
    """
    body_layer = get_body_and_head(model)
    return body_layer.trainable_variables
def clone_model(model):
    """Create a new model of the same class carrying the same weights.

    A fresh instance is built from ``model.config``, called once on
    ``model.dummy_inputs``, and then given the weights read from ``model``.

    Args:
        model: Object exposing ``config``, ``dummy_inputs`` and
            ``get_weights()``, whose class can be constructed from the
            config alone.

    Returns:
        The newly constructed instance; ``model`` itself is only read from.
    """
    model_cls = model.__class__
    replica = model_cls(model.config)
    # Call the replica once before copying weights -- presumably so its
    # variables are created and set_weights has something to assign into.
    replica(model.dummy_inputs)
    source_weights = model.get_weights()
    replica.set_weights(source_weights)
    return replica