
Commit

set the vocab size correctly when recreating the full embedding
Benedikt Fuchs committed Dec 4, 2023
1 parent 00f49cf commit 036a129
Showing 2 changed files with 3 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/test_contextual_reduce.py
@@ -69,6 +69,7 @@ def test_saving_while_reduction_can_be_loaded_afterwards():
         "Home sweet home",
         "ay ay ay",
     ]
+    initial_vocab_size = model.config.vocab_size
     with tempfile.TemporaryDirectory() as tdir:
         with reduce_train_vocab(model=model, tokenizer=tokenizer, texts=texts):
             model.save_pretrained(tdir)
@@ -77,3 +78,4 @@ def test_saving_while_reduction_can_be_loaded_afterwards():
     new_tokenizer = AutoTokenizer.from_pretrained(tdir)
     assert new_model.config.vocab_size == 13
     assert len(new_tokenizer) == 13
+    assert model.config.vocab_size == initial_vocab_size
1 change: 1 addition & 0 deletions transformer_smaller_training_vocab/modify_model.py
@@ -54,5 +54,6 @@ def recreate_embedding(
     for reduced_id, full_id in enumerate(keep_token_ids):
         saved_embeddings[full_id] = embedding_weights[reduced_id]
     new_input_embedding = nn.Embedding(saved_embeddings.size(0), saved_embeddings.size(1), _weight=saved_embeddings)
+    model.config.vocab_size = saved_embeddings.size(0)
     model.set_input_embeddings(new_input_embedding)
     model.get_input_embeddings().to(model_device)
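
The restored embedding's first dimension is the full vocabulary size, and the fix mirrors that value back into model.config.vocab_size so the config no longer reports the reduced size after restoration. Below is a self-contained sketch of that restore step in plain PyTorch; the helper name, shapes, and token ids are illustrative, not the library's actual code.

import torch
from torch import nn

def restore_full_embedding(saved_embeddings, keep_token_ids, reduced_weights):
    # Scatter the (possibly fine-tuned) reduced rows back into the full-size weight matrix.
    for reduced_id, full_id in enumerate(keep_token_ids):
        saved_embeddings[full_id] = reduced_weights[reduced_id]
    # Rebuild the full embedding; num_embeddings is the original vocab size, which is
    # the value the commit now writes back into model.config.vocab_size.
    return nn.Embedding(saved_embeddings.size(0), saved_embeddings.size(1), _weight=saved_embeddings)

full = torch.zeros(13, 4)      # 13 = full vocab size, 4 = hidden size (illustrative)
reduced = torch.randn(3, 4)    # rows that were kept and trained, for token ids 2, 5 and 7
embedding = restore_full_embedding(full, [2, 5, 7], reduced)
assert embedding.num_embeddings == 13  # model.config.vocab_size should now report this value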
