Tweak to Algorithm.shared_step_end
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jun 25, 2024
1 parent 4633846 commit e1b8b2d
Showing 2 changed files with 25 additions and 15 deletions.
31 changes: 24 additions & 7 deletions project/algorithms/algorithm.py
@@ -115,21 +115,38 @@ def training_step_end(self, step_output: StepOutputDict) -> StepOutputDict:
         """
         return self.shared_step_end(step_output, phase="train")
 
-    def validation_step_end(self, step_output: StepOutputDict) -> StepOutputDict:
+    def validation_step_end[Out: torch.Tensor | StepOutputDict](self, step_output: Out) -> Out:
         return self.shared_step_end(step_output, phase="val")
 
-    def test_step_end(self, step_output: StepOutputDict) -> StepOutputDict:
+    def test_step_end[Out: torch.Tensor | StepOutputDict](self, step_output: Out) -> Out:
         return self.shared_step_end(step_output, phase="test")
 
-    def shared_step_end(self, step_output: StepOutputDict, phase: PhaseStr) -> StepOutputDict:
-        fused_output = step_output.copy()
-        loss: Tensor | float | None = step_output.get("loss", None)
+    def shared_step_end[Out: torch.Tensor | StepOutputDict](
+        self, step_output: Out, phase: PhaseStr
+    ) -> Out:
+        """This is a default implementation for `[train/validation/test]_step_end`.
+
+        This does the following:
+        - Averages out the `loss` tensor if it was left unreduced.
+        - the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP)
+        """
 
-        if isinstance(loss, Tensor) and loss.shape:
+        if (
+            isinstance(step_output, dict)
+            and isinstance((loss := step_output.get("loss")), torch.Tensor)
+            and loss.shape
+        ):
             # Replace the loss with its mean. This is useful when automatic
             # optimization is enabled, for example in the example algo, where each replica
             # returns the un-reduced cross-entropy loss. Here we need to reduce it to a scalar.
-            fused_output["loss"] = loss.mean()
+            fused_output = step_output | {"loss": loss.mean()}
+
+        else:
+            assert isinstance(step_output, torch.Tensor)
+            loss = step_output
+            # todo: find out if this was already logged, to not log it twice.
+            self.log(f"{phase}/loss", torch.as_tensor(loss).mean(), sync_dist=True)
+            fused_output = step_output
 
         if loss is not None:
             # todo: find out if this was already logged, to not log it twice.
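To make the new behavior concrete, here is a minimal, self-contained sketch of the reduction logic that `shared_step_end` now applies. Everything below is a hypothetical stand-in for illustration: a free function rather than the actual method, with the `self.log(...)` call omitted.

```python
# Hypothetical stand-in for the reduction logic in Algorithm.shared_step_end
# (the real method also logs the loss via self.log).
import torch


def shared_step_end(step_output: torch.Tensor | dict, phase: str) -> torch.Tensor | dict:
    if (
        isinstance(step_output, dict)
        and isinstance((loss := step_output.get("loss")), torch.Tensor)
        and loss.shape
    ):
        # Non-scalar loss (e.g. one value per sample or per replica):
        # reduce it to its mean. The dict merge builds a new dict, so the
        # caller's `step_output` is left untouched.
        return step_output | {"loss": loss.mean()}
    # A bare, already-reduced tensor passes through unchanged.
    return step_output


# An un-reduced per-sample loss is averaged down to a scalar:
out = shared_step_end({"loss": torch.ones(8), "logits": torch.zeros(8, 10)}, phase="train")
assert isinstance(out, dict) and out["loss"].ndim == 0

# A scalar tensor output is returned as-is:
assert shared_step_end(torch.tensor(0.5), phase="val").item() == 0.5
```

Note that the `step_output | {"loss": loss.mean()}` merge produces a new dict, which is why the new code can drop the explicit `step_output.copy()` that the old implementation needed.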
9 changes: 1 addition & 8 deletions project/algorithms/manual_optimization_example.py
@@ -69,14 +69,7 @@ def validation_step(
     def shared_step(
         self, batch: tuple[Tensor, Tensor], batch_index: int, phase: PhaseStr
     ) -> ClassificationOutputs:
-        """Performs a training/validation/test step.
-
-        This must return a dictionary with at least the 'y' and 'logits' keys, and an optional
-        `loss` entry. This is so that the training of the model is easier to parallelize the
-        training across GPUs:
-        - the cross entropy loss gets calculated using the global batch size
-        - the main metrics are logged inside `training_step_end` (supposed to be better for DP/DDP)
-        """
+        """Performs a training/validation/test step."""
         x, y = batch
         logits = self(x)
 
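For context on the contract the trimmed docstring used to spell out, here is a hypothetical sketch of a `shared_step` that returns an un-reduced per-sample loss, which the `shared_step_end` change above then averages over the global batch. The model, names, and signature here are illustrative, not the actual example class:

```python
# Hypothetical sketch of a step that leaves its loss un-reduced.
import torch
import torch.nn.functional as F
from torch import Tensor


def shared_step(model: torch.nn.Module, batch: tuple[Tensor, Tensor], phase: str) -> dict:
    x, y = batch
    logits = model(x)
    # reduction="none" keeps one loss value per sample, so under DP/DDP each
    # replica can return its slice and the reduction happens once, over the
    # global batch, in `shared_step_end`.
    loss = F.cross_entropy(logits, y, reduction="none")
    return {"y": y, "logits": logits, "loss": loss}
```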
