Skip to content

Commit

Permalink
Add attention model option for training many models
Browse files Browse the repository at this point in the history
  • Loading branch information
SecroLoL authored and AngledLuffa committed Sep 16, 2024
1 parent 9720d44 commit f8455f4
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions stanza/models/lemma_classifier/train_many.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def train_n_models(num_models: int, base_path: str, args):
args.save_name = new_save_name
train_lstm_main(predefined_args=args)

if args.change_param == "attn_model":
for i in range(num_models):
new_save_name = os.path.join(base_path, f"attn_model_{args.num_heads}_heads_{i}.pt")
args.save_name = new_save_name
train_lstm_main(predefined_args=args)

def train_n_tfmrs(num_models: int, base_path: str, args):

if args.multi_train_type == "tfmr":
Expand Down

0 comments on commit f8455f4

Please sign in to comment.