Skip to content

Commit

Permalink
NLP Example (#39)
Browse files Browse the repository at this point in the history
* added HuggingFace example

Signed-off-by: cmvcordova <[email protected]>

* changed to evaluate.load, revamped location test

* "fix" for a weird inter-dep between configs

Signed-off-by: Fabrice Normandin <[email protected]>

* restored gitignore, install.md

* hf_example typing nitpicks

* more fixes

* nitpicks and change from pass to slow on overfit test

* removed three files

* torch.set_num_threads(1) reverted to comment to avoid template-wide application

* moved utils back into hf datamodule

* reverted main to master

* restored main, added typehints in hf_text.py

* Nitpicky suggestions from code review

---------

Signed-off-by: cmvcordova <[email protected]>
Signed-off-by: Fabrice Normandin <[email protected]>
Co-authored-by: cmvcordova <[email protected]>
Co-authored-by: Fabrice Normandin <[email protected]>
  • Loading branch information
cmvcordova and lebrice authored Aug 30, 2024
1 parent 422ddba commit 32ae062
Show file tree
Hide file tree
Showing 21 changed files with 1,116 additions and 363 deletions.
2 changes: 2 additions & 0 deletions project/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from project.algorithms.hf_example import HFExample
from project.algorithms.jax_example import JaxExample
from project.algorithms.no_op import NoOp

Expand All @@ -7,4 +8,5 @@
"ExampleAlgorithm",
"JaxExample",
"NoOp",
"HFExample",
]
3 changes: 2 additions & 1 deletion project/algorithms/example_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Example showing how the test suite can be used to add tests for a new algorithm."""

import torch.nn
from transformers import PreTrainedModel

from project.algorithms.testsuites.algorithm_tests import LearningAlgorithmTests
from project.datamodules.image_classification.image_classification import (
Expand All @@ -13,7 +14,7 @@

@run_for_all_configs_of_type("algorithm", ExampleAlgorithm)
@run_for_all_configs_of_type("datamodule", ImageClassificationDataModule)
@run_for_all_configs_of_type("network", torch.nn.Module)
@run_for_all_configs_of_type("network", torch.nn.Module, excluding=PreTrainedModel)
class TestExampleAlgo(LearningAlgorithmTests[ExampleAlgorithm]):
"""Tests for the `ExampleAlgorithm`.
Expand Down
124 changes: 124 additions & 0 deletions project/algorithms/hf_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from datetime import datetime
from pathlib import Path

import torch
from evaluate import load as load_metric
from lightning import LightningModule
from torch.optim import AdamW
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
PreTrainedModel,
get_linear_schedule_with_warmup,
)

from project.datamodules.text.hf_text import HFDataModule


def pretrained_network(model_name_or_path: str | Path, **kwargs) -> PreTrainedModel:
config = AutoConfig.from_pretrained(model_name_or_path, **kwargs)
return AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=config)


class HFExample(LightningModule):
"""Example of a lightning module used to train a huggingface model."""

def __init__(
self,
datamodule: HFDataModule,
network: PreTrainedModel,
hf_metric_name: str,
learning_rate: float = 2e-5,
adam_epsilon: float = 1e-8,
warmup_steps: int = 0,
weight_decay: float = 0.0,
**kwargs,
):
super().__init__()

self.save_hyperparameters()
self.num_labels = datamodule.num_labels
self.task_name = datamodule.task_name
self.network = network
self.hf_metric_name = hf_metric_name
self.metric = load_metric(
self.hf_metric_name,
self.task_name,
experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S"),
)

def forward(
self,
input_ids: torch.Tensor,
token_type_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor,
):
return self.network(
input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels
)

def model_step(self, batch: dict[str, torch.Tensor]):
input_ids = batch["input_ids"]
token_type_ids = batch["token_type_ids"]
attention_mask = batch["attention_mask"]
labels = batch["labels"]

outputs = self.forward(input_ids, token_type_ids, attention_mask, labels)
loss = outputs.loss
logits = outputs.logits

if self.num_labels > 1:
preds = torch.argmax(logits, axis=1)
else:
preds = logits.squeeze()

return loss, preds, labels

def training_step(self, batch: dict[str, torch.Tensor], batch_idx: int):
loss, preds, labels = self.model_step(batch)
self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
return {"loss": loss, "preds": preds, "labels": labels}

def validation_step(
self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int = 0
):
val_loss, preds, labels = self.model_step(batch)
self.log("val/loss", val_loss, on_step=False, on_epoch=True, prog_bar=True)
return {"val/loss": val_loss, "preds": preds, "labels": labels}

def configure_optimizers(self):
"""Prepare optimizer and schedule (linear warmup and decay)"""
model = self.network
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in model.named_parameters()
if not any(nd_param in n for nd_param in no_decay)
],
"weight_decay": self.hparams.weight_decay,
},
{
"params": [
p
for n, p in model.named_parameters()
if any(nd_param in n for nd_param in no_decay)
],
"weight_decay": 0.0,
},
]
optimizer = AdamW(
optimizer_grouped_parameters,
lr=self.hparams.learning_rate,
eps=self.hparams.adam_epsilon,
)

scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=self.hparams.warmup_steps,
num_training_steps=self.trainer.estimated_stepping_batches,
)
scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
return [optimizer], [scheduler]
71 changes: 71 additions & 0 deletions project/algorithms/hf_example_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import lightning
import pytest
from torch import Tensor
from transformers import PreTrainedModel

from project.algorithms.hf_example import HFExample
from project.datamodules.text.hf_text import HFDataModule
from project.utils.testutils import run_for_all_configs_of_type

from .testsuites.algorithm_tests import LearningAlgorithmTests


class RecordTrainingLossCb(lightning.Callback):
def __init__(self):
self.losses = []

def on_train_batch_end(
self,
trainer,
pl_module,
outputs,
batch,
batch_idx,
):
self.losses.append(outputs["loss"].detach())


@run_for_all_configs_of_type("algorithm", HFExample)
@run_for_all_configs_of_type("datamodule", HFDataModule)
@run_for_all_configs_of_type("network", PreTrainedModel)
class TestHFExample(LearningAlgorithmTests[HFExample]):
"""Tests for the HF example."""

@pytest.fixture(scope="session")
def forward_pass_input(self, training_batch: dict[str, Tensor]):
assert isinstance(training_batch, dict)
return training_batch

@pytest.mark.slow
def test_overfit_batch(
self,
algorithm: HFExample,
datamodule: HFDataModule,
accelerator: str,
devices: int | list[int],
training_batch: dict[str, Tensor],
num_steps: int = 3,
):
"""Test that the loss decreases on a single batch."""
get_loss_cb = RecordTrainingLossCb()
trainer = lightning.Trainer(
accelerator=accelerator,
callbacks=[get_loss_cb],
devices=devices,
enable_checkpointing=False,
deterministic=True,
overfit_batches=1,
limit_train_batches=1,
max_epochs=num_steps,
)
trainer.fit(algorithm, datamodule)
losses_at_each_epoch: list[Tensor] = get_loss_cb.losses

assert (
len(losses_at_each_epoch) == num_steps
), f"Expected {num_steps} losses, got {len(losses_at_each_epoch)}"

assert losses_at_each_epoch[0] > losses_at_each_epoch[-1], (
f"Loss did not decrease on overfit: final loss= {losses_at_each_epoch[-1]},"
f"initial loss={losses_at_each_epoch[0]}"
)
28 changes: 20 additions & 8 deletions project/algorithms/testsuites/algorithm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@
AlgorithmType = TypeVar("AlgorithmType", bound=LightningModule)


def forward_pass(algorithm: LightningModule, input: PyTree[torch.Tensor]):
"""Performs the forward pass with the lightningmodule, unpacking the inputs if necessary."""
if len(inspect.signature(algorithm.forward).parameters) == 1:
return algorithm(input)
assert isinstance(input, dict)
return algorithm(**input)


@pytest.mark.incremental
class LearningAlgorithmTests(Generic[AlgorithmType], ABC):
"""Suite of unit tests for an "Algorithm" (LightningModule).
Expand Down Expand Up @@ -64,9 +72,9 @@ def test_forward_pass_is_deterministic(
input."""

with seeded_rng(seed):
out1 = algorithm(forward_pass_input)
out1 = forward_pass(algorithm, forward_pass_input)
with seeded_rng(seed):
out2 = algorithm(forward_pass_input)
out2 = forward_pass(algorithm, forward_pass_input)
torch.testing.assert_close(out1, out2)

# @pytest.mark.timeout(10)
Expand Down Expand Up @@ -147,7 +155,7 @@ def test_forward_pass_is_reproducible(
):
"""Check that the forward pass is reproducible given the same input and random seed."""
with seeded_rng(seed):
out = algorithm(forward_pass_input)
out = forward_pass(algorithm, forward_pass_input)
tensor_regression.check(
{"input": forward_pass_input, "out": out},
default_tolerance={"rtol": 1e-5, "atol": 1e-6}, # some tolerance for changes.
Expand Down Expand Up @@ -179,16 +187,20 @@ def test_backward_pass_is_reproducible(
tmp_path=tmp_path,
)
# BUG: Fix issue in tensor_regression calling .numpy() on cuda tensors.
assert (
isinstance(gradients_callback.batch, list | tuple)
and len(gradients_callback.batch) == 2
)
assert isinstance(gradients_callback.grads, dict)
assert isinstance(gradients_callback.outputs, dict)
batch = gradients_callback.batch
if isinstance(batch, list | tuple):
cpu_batch = {str(i): t.cpu() for i, t in enumerate(batch)}
else:
assert isinstance(batch, dict) and all(
isinstance(v, torch.Tensor) for v in batch.values()
)
cpu_batch = {k: v.cpu() for k, v in batch.items()}
tensor_regression.check(
{
# FIXME: This is ugly, and specific to the image classification example.
"batch": {str(i): t.cpu() for i, t in enumerate(gradients_callback.batch)},
"batch": cpu_batch,
"grads": {
k: v.cpu() if v is not None else None
for k, v in gradients_callback.grads.items()
Expand Down
6 changes: 6 additions & 0 deletions project/configs/algorithm/hf_example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_target_: project.algorithms.hf_example.HFExample
# NOTE: Why _partial_? Because the config doesn't create the algo directly, it creates a function
# that will accept the datamodule and network and return the algo.
_partial_: true
_recursive_: false
hf_metric_name: glue
7 changes: 7 additions & 0 deletions project/configs/datamodule/hf_text.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
_target_: project.datamodules.HFDataModule
tokenizer: albert-base-v2
hf_dataset_path: glue
task_name: cola
max_seq_length: 128
train_batch_size: 32
eval_batch_size: 32
15 changes: 15 additions & 0 deletions project/configs/experiment/albert-cola-glue.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# @package _global_

defaults:
- override /network: albert-base-v2
- override /datamodule: hf_text
- override /algorithm: hf_example
- override /trainer/callbacks: none

trainer:
min_epochs: 1
max_epochs: 2
limit_train_batches: 2
limit_val_batches: 1
num_sanity_val_steps: 0
enable_checkpointing: False
2 changes: 2 additions & 0 deletions project/configs/network/albert-base-v2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_target_: project.algorithms.hf_example.pretrained_network
model_name_or_path: albert-base-v2
2 changes: 2 additions & 0 deletions project/configs/trainer/callbacks/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@ lr_monitor:

device_utilisation:
_target_: lightning.pytorch.callbacks.DeviceStatsMonitor
throughput:
_target_: project.algorithms.callbacks.samples_per_second.MeasureSamplesPerSecondCallback
2 changes: 2 additions & 0 deletions project/configs/trainer/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ accelerator: auto
strategy: auto
devices: 1

deterministic: true

min_epochs: 1
max_epochs: 10

Expand Down
33 changes: 11 additions & 22 deletions project/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,14 @@
setup_logging,
)
from project.main import PROJECT_NAME
from project.utils.hydra_config_utils import get_config_loader
from project.utils.hydra_utils import resolve_dictconfig
from project.utils.testutils import (
PARAM_WHEN_USED_MARK_NAME,
default_marks_for_config_combinations,
default_marks_for_config_name,
seeded_rng,
)
from project.utils.typing_utils import is_sequence_of
from project.utils.typing_utils import is_mapping_of, is_sequence_of
from project.utils.typing_utils.protocols import DataModule

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -257,27 +256,17 @@ def network(
return network


# BUG: The network has a default config of `resnet18`, which tries to get the
# num_classes from the datamodule. However, the hf_text datamodule doesn't have that attribute,
# and we load the datamodule using the entire experiment config, so loading the network raises an
# error!
# - instantiate(experiment_config).datamodule
# - instantiate(experiment_dictconfig['datamodule'])


@pytest.fixture(scope="session")
def datamodule(
experiment_config: Config,
_common_setup_experiment_part: None,
datamodule_config: str | None,
overrides: list[str] | None,
) -> DataModule:
def datamodule(experiment_config: Config) -> DataModule:
"""Fixture that creates the datamodule for the given config."""
if datamodule_config:
# Load only the datamodule? (assuming it doesn't depend on the network or anything else...)
from hydra.types import RunMode

config = get_config_loader().load_configuration(
f"datamodule/{datamodule_config}.yaml",
overrides=overrides or [],
run_mode=RunMode.RUN,
)
datamodule_config = config["datamodule"]
assert isinstance(datamodule_config, DictConfig)
datamodule = instantiate_datamodule(datamodule_config)
return datamodule
# NOTE: creating the datamodule by itself instead of with everything else.
return instantiate_datamodule(experiment_config.datamodule)

Expand Down Expand Up @@ -319,7 +308,7 @@ def training_batch(
batch = tuple(t.to(device=device) for t in batch)
return batch
else:
assert isinstance(batch, dict) and is_sequence_of(batch.values(), Tensor)
assert is_mapping_of(batch, str, torch.Tensor)
batch = {k: v.to(device=device) for k, v in batch.items()}
return batch

Expand Down
Loading

0 comments on commit 32ae062

Please sign in to comment.