Skip to content

Commit

Permalink
rename model_ckpt to pretrained_model
Browse files Browse the repository at this point in the history
  • Loading branch information
bagustris committed May 27, 2024
1 parent 92a5605 commit 7fc73d4
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions ini_file.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,8 @@
* device = cpu
  * **patience**: Number of epochs to wait without improvement in the result before stopping (early stopping)
* patience = 5
* **model_ckpt**: Base model for finetuning/transfer learning. Variants of wav2vec2, Hubert, and WavLM are tested to work.
* model_ckpt = microsoft/wavlm-base
* **pretrained_model**: Base model for finetuning/transfer learning. Variants of wav2vec2, Hubert, and WavLM are tested to work.
* pretrained_model = microsoft/wavlm-base

### EXPL
* **model**: Which model to use to estimate feature importance.
Expand Down
8 changes: 4 additions & 4 deletions nkululeko/models/model_tuned.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ def __init__(self, df_train, df_test, feats_train, feats_test):

def _init_model(self):
model_path = "facebook/wav2vec2-large-robust-ft-swbd-300h"
model_ckpt = self.util.config_val("MODEL", "model_ckpt", model_path)
pretrained_model = self.util.config_val("MODEL", "pretrained_model", model_path)
self.num_layers = None
self.sampling_rate = 16000
self.max_duration_sec = 8.0
self.accumulation_steps = 4

# print finetuning information via debug
self.util.debug(f"Finetuning from model: {model_ckpt}")
self.util.debug(f"Finetuning from model: {pretrained_model}")

# create dataset
dataset = {}
Expand Down Expand Up @@ -92,7 +92,7 @@ def _init_model(self):
value in target_mapping.items()}

self.config = transformers.AutoConfig.from_pretrained(
model_ckpt,
pretrained_model,
num_labels=len(target_mapping),
label2id=target_mapping,
id2label=target_mapping_reverse,
Expand Down Expand Up @@ -124,7 +124,7 @@ def _init_model(self):
assert self.processor.feature_extractor.sampling_rate == self.sampling_rate

self.model = Model.from_pretrained(
model_ckpt,
pretrained_model,
config=self.config,
)
self.model.freeze_feature_extractor()
Expand Down
2 changes: 1 addition & 1 deletion tests/exp_emodb_finetune.ini
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ type = []
type = finetune
device = 1
batch_size = 8
model_ckpt = microsoft/wavlm-base
pretrained_model = microsoft/wavlm-base

0 comments on commit 7fc73d4

Please sign in to comment.