* added HuggingFace example (Signed-off-by: cmvcordova <[email protected]>)
* changed to evaluate.load, revamped location test
* "fix" for a weird inter-dep between configs (Signed-off-by: Fabrice Normandin <[email protected]>)
* restored gitignore, install.md
* hf_example typing nitpicks
* more fixes
* nitpicks and change from pass to slow on overfit test
* removed three files
* torch.set_num_threads(1) reverted to comment to avoid template-wide application
* moved utils back into hf datamodule
* reverted main to master
* restored main, added type hints in hf_text.py
* nitpicky suggestions from code review

Signed-off-by: cmvcordova <[email protected]>
Signed-off-by: Fabrice Normandin <[email protected]>
Co-authored-by: cmvcordova <[email protected]>
Co-authored-by: Fabrice Normandin <[email protected]>
1 parent 422ddba · commit 32ae062 · 21 changed files with 1,116 additions and 363 deletions.
project/algorithms/hf_example.py
@@ -0,0 +1,124 @@
from datetime import datetime
from pathlib import Path

import torch
from evaluate import load as load_metric
from lightning import LightningModule
from torch.optim import AdamW
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    PreTrainedModel,
    get_linear_schedule_with_warmup,
)

from project.datamodules.text.hf_text import HFDataModule


def pretrained_network(model_name_or_path: str | Path, **kwargs) -> PreTrainedModel:
    config = AutoConfig.from_pretrained(model_name_or_path, **kwargs)
    return AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=config)


class HFExample(LightningModule):
    """Example of a LightningModule used to train a HuggingFace model."""

    def __init__(
        self,
        datamodule: HFDataModule,
        network: PreTrainedModel,
        hf_metric_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        **kwargs,
    ):
        super().__init__()

        self.save_hyperparameters()
        self.num_labels = datamodule.num_labels
        self.task_name = datamodule.task_name
        self.network = network
        self.hf_metric_name = hf_metric_name
        self.metric = load_metric(
            self.hf_metric_name,
            self.task_name,
            # A unique experiment id avoids cache collisions between concurrent runs.
            experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S"),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
    ):
        return self.network(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels
        )

    def model_step(self, batch: dict[str, torch.Tensor]):
        input_ids = batch["input_ids"]
        token_type_ids = batch["token_type_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]

        outputs = self.forward(input_ids, token_type_ids, attention_mask, labels)
        loss = outputs.loss
        logits = outputs.logits

        if self.num_labels > 1:
            # Classification: predict the most likely class.
            preds = torch.argmax(logits, dim=1)
        else:
            # Regression: the single logit is the prediction.
            preds = logits.squeeze()

        return loss, preds, labels

    def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
        loss, preds, labels = self.model_step(batch)
        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        return {"loss": loss, "preds": preds, "labels": labels}

    def validation_step(
        self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int = 0
    ):
        val_loss, preds, labels = self.model_step(batch)
        self.log("val/loss", val_loss, on_step=False, on_epoch=True, prog_bar=True)
        return {"val/loss": val_loss, "preds": preds, "labels": labels}

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)."""
        model = self.network
        # Standard HuggingFace practice: no weight decay on biases and LayerNorm weights.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if not any(nd_param in n for nd_param in no_decay)
                ],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [
                    p
                    for n, p in model.named_parameters()
                    if any(nd_param in n for nd_param in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.learning_rate,
            eps=self.hparams.adam_epsilon,
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        # Step the LR scheduler on every optimizer step, not once per epoch.
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
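For orientation, here is a minimal sketch of how these pieces could be wired together by hand rather than through the Hydra configs shown below. The HFDataModule keyword arguments mirror the hf_text datamodule config further down; forwarding num_labels through pretrained_network's **kwargs is an assumption about how the label count would reach AutoConfig.from_pretrained.

from lightning import Trainer

# A sketch, not the template's actual entry point.
datamodule = HFDataModule(
    tokenizer="albert-base-v2",
    hf_dataset_path="glue",
    task_name="cola",
    max_seq_length=128,
    train_batch_size=32,
    eval_batch_size=32,
)
network = pretrained_network("albert-base-v2", num_labels=2)  # CoLA is binary classification.
model = HFExample(datamodule=datamodule, network=network, hf_metric_name="glue")
Trainer(fast_dev_run=True).fit(model, datamodule=datamodule)  # single-batch smoke test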
@@ -0,0 +1,71 @@
import lightning
import pytest
from torch import Tensor
from transformers import PreTrainedModel

from project.algorithms.hf_example import HFExample
from project.datamodules.text.hf_text import HFDataModule
from project.utils.testutils import run_for_all_configs_of_type

from .testsuites.algorithm_tests import LearningAlgorithmTests


class RecordTrainingLossCb(lightning.Callback):
    """Callback that records the training loss at the end of every training batch."""

    def __init__(self):
        self.losses: list[Tensor] = []

    def on_train_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
    ):
        self.losses.append(outputs["loss"].detach())


@run_for_all_configs_of_type("algorithm", HFExample)
@run_for_all_configs_of_type("datamodule", HFDataModule)
@run_for_all_configs_of_type("network", PreTrainedModel)
class TestHFExample(LearningAlgorithmTests[HFExample]):
    """Tests for the HF example."""

    @pytest.fixture(scope="session")
    def forward_pass_input(self, training_batch: dict[str, Tensor]):
        assert isinstance(training_batch, dict)
        return training_batch

    @pytest.mark.slow
    def test_overfit_batch(
        self,
        algorithm: HFExample,
        datamodule: HFDataModule,
        accelerator: str,
        devices: int | list[int],
        training_batch: dict[str, Tensor],
        num_steps: int = 3,
    ):
        """Test that the loss decreases when overfitting a single batch."""
        get_loss_cb = RecordTrainingLossCb()
        trainer = lightning.Trainer(
            accelerator=accelerator,
            callbacks=[get_loss_cb],
            devices=devices,
            enable_checkpointing=False,
            deterministic=True,
            # Train on the same single batch every epoch, so each epoch takes one
            # step on that batch and the recorded losses should decrease.
            overfit_batches=1,
            limit_train_batches=1,
            max_epochs=num_steps,
        )
        trainer.fit(algorithm, datamodule)
        losses_at_each_epoch: list[Tensor] = get_loss_cb.losses

        assert (
            len(losses_at_each_epoch) == num_steps
        ), f"Expected {num_steps} losses, got {len(losses_at_each_epoch)}"

        assert losses_at_each_epoch[0] > losses_at_each_epoch[-1], (
            f"Loss did not decrease on overfit: final loss={losses_at_each_epoch[-1]}, "
            f"initial loss={losses_at_each_epoch[0]}"
        )
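Since test_overfit_batch is marked with pytest.mark.slow, it only runs when slow tests are selected. A sketch of selecting it programmatically, assuming the test module lives at project/algorithms/hf_example_test.py:

import pytest

# Hypothetical invocation: equivalent to `pytest -m slow <path>`; the path is assumed.
pytest.main(["-m", "slow", "project/algorithms/hf_example_test.py"])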
@@ -0,0 +1,6 @@
_target_: project.algorithms.hf_example.HFExample
# NOTE: Why _partial_? Because the config doesn't create the algo directly, it creates a function
# that will accept the datamodule and network and return the algo.
_partial_: true
_recursive_: false
hf_metric_name: glue
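For context, this is roughly what `_partial_: true` does at instantiation time, assuming standard Hydra semantics; `algorithm_config`, `datamodule` and `network` are illustrative names, not part of the template:

from hydra.utils import instantiate

# With _partial_: true, instantiate() returns a functools.partial of HFExample
# (with hf_metric_name already bound) instead of constructing the object directly.
algorithm_partial = instantiate(algorithm_config)
algorithm = algorithm_partial(datamodule=datamodule, network=network)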
@@ -0,0 +1,7 @@
_target_: project.datamodules.HFDataModule
tokenizer: albert-base-v2
hf_dataset_path: glue
task_name: cola
max_seq_length: 128
train_batch_size: 32
eval_batch_size: 32
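Roughly, this config implies tokenizing each example to max_seq_length with the named tokenizer. A sketch of that preprocessing (not the actual HFDataModule implementation):

from transformers import AutoTokenizer

# Sketch of the preprocessing implied by the config above; labels come from the dataset itself.
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
encoding = tokenizer(
    "This sentence is grammatical.",
    max_length=128,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
print(encoding.keys())  # input_ids, token_type_ids, attention_mask: the keys HFExample.model_step reads.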
@@ -0,0 +1,15 @@
# @package _global_

defaults:
  - override /network: albert-base-v2
  - override /datamodule: hf_text
  - override /algorithm: hf_example
  - override /trainer/callbacks: none

trainer:
  min_epochs: 1
  max_epochs: 2
  limit_train_batches: 2
  limit_val_batches: 1
  num_sanity_val_steps: 0
  enable_checkpointing: False
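Assuming a typical layout for this template, the experiment could also be composed programmatically; the config_path, config_name and experiment file name below are all assumptions:

from hydra import compose, initialize

# Hypothetical: compose this experiment config via Hydra's compose API.
with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="config", overrides=["experiment=hf_example"])
assert cfg.trainer.max_epochs == 2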
@@ -0,0 +1,2 @@
_target_: project.algorithms.hf_example.pretrained_network
model_name_or_path: albert-base-v2
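Instantiating this config amounts to the direct call below; the num_labels example is hypothetical, showing the kind of kwarg that **kwargs would forward to AutoConfig.from_pretrained:

from project.algorithms.hf_example import pretrained_network

# What instantiating the network config boils down to.
network = pretrained_network("albert-base-v2")
# Hypothetical: pretrained_network("albert-base-v2", num_labels=2)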
@@ -4,6 +4,8 @@ accelerator: auto
strategy: auto
devices: 1

deterministic: true

min_epochs: 1
max_epochs: 10