From 036a12927446defddc0422a0313d11097ac08385 Mon Sep 17 00:00:00 2001
From: Benedikt Fuchs
Date: Mon, 4 Dec 2023 18:26:08 +0100
Subject: [PATCH 1/2] set the vocab size correctly when recreating the full
 embedding

---
 tests/test_contextual_reduce.py                    | 2 ++
 transformer_smaller_training_vocab/modify_model.py | 1 +
 2 files changed, 3 insertions(+)

diff --git a/tests/test_contextual_reduce.py b/tests/test_contextual_reduce.py
index aa9adb2..8315a99 100644
--- a/tests/test_contextual_reduce.py
+++ b/tests/test_contextual_reduce.py
@@ -69,6 +69,7 @@ def test_saving_while_reduction_can_be_loaded_afterwards():
         "Home sweet home",
         "ay ay ay",
     ]
+    initial_vocab_size = model.config.vocab_size
     with tempfile.TemporaryDirectory() as tdir:
         with reduce_train_vocab(model=model, tokenizer=tokenizer, texts=texts):
             model.save_pretrained(tdir)
@@ -77,3 +78,4 @@
         new_tokenizer = AutoTokenizer.from_pretrained(tdir)
         assert new_model.config.vocab_size == 13
         assert len(new_tokenizer) == 13
+        assert model.config.vocab_size == initial_vocab_size
diff --git a/transformer_smaller_training_vocab/modify_model.py b/transformer_smaller_training_vocab/modify_model.py
index 6c7474e..b0376ae 100644
--- a/transformer_smaller_training_vocab/modify_model.py
+++ b/transformer_smaller_training_vocab/modify_model.py
@@ -54,5 +54,6 @@ def recreate_embedding(
     for reduced_id, full_id in enumerate(keep_token_ids):
         saved_embeddings[full_id] = embedding_weights[reduced_id]
     new_input_embedding = nn.Embedding(saved_embeddings.size(0), saved_embeddings.size(1), _weight=saved_embeddings)
+    model.config.vocab_size = saved_embeddings.size(0)
     model.set_input_embeddings(new_input_embedding)
     model.get_input_embeddings().to(model_device)

From b9a193e87cb338b615c13c429642b6bca7c4100e Mon Sep 17 00:00:00 2001
From: Benedikt Fuchs
Date: Mon, 4 Dec 2023 18:30:24 +0100
Subject: [PATCH 2/2] also test for tokenizer size

---
 tests/test_contextual_reduce.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/test_contextual_reduce.py b/tests/test_contextual_reduce.py
index 8315a99..f3aae2e 100644
--- a/tests/test_contextual_reduce.py
+++ b/tests/test_contextual_reduce.py
@@ -79,3 +79,4 @@
         assert new_model.config.vocab_size == 13
         assert len(new_tokenizer) == 13
         assert model.config.vocab_size == initial_vocab_size
+        assert len(tokenizer) == initial_vocab_size
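
Note: a minimal sketch of the end-to-end behavior these two patches pin down,
using the reduce_train_vocab context manager exercised by the test above. The
checkpoint name and example texts are illustrative, not part of the patches:

    from transformers import AutoModel, AutoTokenizer

    from transformer_smaller_training_vocab import reduce_train_vocab

    # Any AutoModel/AutoTokenizer pair should behave the same way;
    # "prajjwal1/bert-tiny" is just a small checkpoint for the sketch.
    model = AutoModel.from_pretrained("prajjwal1/bert-tiny")
    tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
    texts = ["Home sweet home", "ay ay ay"]

    initial_vocab_size = model.config.vocab_size  # full vocab, e.g. 30522

    with reduce_train_vocab(model=model, tokenizer=tokenizer, texts=texts):
        # Inside the context, the embedding matrix only covers the tokens
        # occurring in `texts` plus special tokens, so the config reports
        # the reduced size (13 in the test above).
        assert model.config.vocab_size < initial_vocab_size

    # On exit, recreate_embedding() rebuilds the full embedding matrix;
    # patch 1/2 makes it also reset config.vocab_size to match, and
    # patch 2/2 additionally asserts the tokenizer regained its full size.
    assert model.config.vocab_size == initial_vocab_size
    assert len(tokenizer) == initial_vocab_size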