
Fix torch.compile on nn.Module instead of on LightningModule (#587)

tesfaldet authored Sep 1, 2023
1 parent 429947f commit 2654bad
Showing 6 changed files with 21 additions and 11 deletions.
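
The gist of the change: torch.compile does not modify a module in place; it returns a dynamo OptimizedModule wrapper around it. Compiling the whole LightningModule, as src/train.py did before this commit, therefore hands Trainer.fit a wrapper rather than the original LightningModule, and the wrapper's state-dict keys gain an `_orig_mod.` prefix, which complicates checkpointing. Compiling only the inner nn.Module keeps the Lightning wrapper intact. A minimal illustration of the wrapping behavior, assuming PyTorch 2.x (not part of the diff):

    import torch
    import torch.nn as nn

    net = nn.Linear(4, 2)
    compiled = torch.compile(net)

    # torch.compile returns a wrapper class, not the module's own class
    print(type(compiled).__name__)         # OptimizedModule
    print(compiled._orig_mod is net)       # True: the original module is kept inside
    print(list(compiled.state_dict())[0])  # "_orig_mod.weight": prefixed keys
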
1 change: 1 addition & 0 deletions configs/experiment/example.yaml
@@ -28,6 +28,7 @@ model:
     lin1_size: 128
     lin2_size: 256
     lin3_size: 64
+  compile: false
 
 data:
   batch_size: 64
3 changes: 3 additions & 0 deletions configs/model/mnist.yaml
@@ -20,3 +20,6 @@ net:
   lin2_size: 128
   lin3_size: 64
   output_size: 10
+
+# compile model for faster training with pytorch 2.0
+compile: false
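
Because the flag now lives under the model config rather than in the top-level configs/train.yaml, it can be toggled per run with an ordinary Hydra command-line override (a usage sketch, assuming the template's standard entry point):

    python src/train.py model.compile=true
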
3 changes: 0 additions & 3 deletions configs/train.yaml
@@ -42,9 +42,6 @@ train: True
 # lightning chooses best weights based on the metric specified in checkpoint callback
 test: True
 
-# compile model for faster training with pytorch 2.0
-compile: False
-
 # simply provide checkpoint path to resume training
 ckpt_path: null
1 change: 1 addition & 0 deletions environment.yaml
@@ -21,6 +21,7 @@ channels:
 # compatibility is usually guaranteed
 
 dependencies:
+  - python=3.10
   - pytorch=2.*
   - torchvision=0.*
   - lightning=2.*
20 changes: 16 additions & 4 deletions src/models/mnist_module.py
@@ -44,6 +44,7 @@ def __init__(
         net: torch.nn.Module,
         optimizer: torch.optim.Optimizer,
         scheduler: torch.optim.lr_scheduler,
+        compile: bool,
     ) -> None:
         """Initialize a `MNISTLitModule`.
@@ -176,10 +177,21 @@ def on_test_epoch_end(self) -> None:
         """Lightning hook that is called when a test epoch ends."""
         pass
 
-    def configure_optimizers(self) -> Dict[str, Any]:
-        """Configures optimizers and learning-rate schedulers to be used for training.
+    def setup(self, stage: str) -> None:
+        """Lightning hook that is called at the beginning of fit (train + validate), validate,
+        test, or predict.
+
+        This is a good hook when you need to build models dynamically or adjust something about
+        them. This hook is called on every process when using DDP.
+
+        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+        """
+        if self.hparams.compile and stage == "fit":
+            self.net = torch.compile(self.net)
+
+    def configure_optimizers(self) -> Dict[str, Any]:
+        """Choose what optimizers and learning-rate schedulers to use in your optimization.
 
-        Normally you'd need one, but in the case of GANs or similar you might need multiple.
+        Normally you'd need one. But in the case of GANs or similar you might have multiple.
 
         Examples:
             https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
@@ -202,4 +214,4 @@ def configure_optimizers(self) -> Dict[str, Any]:
 
 
 if __name__ == "__main__":
-    _ = MNISTLitModule(None, None, None)
+    _ = MNISTLitModule(None, None, None, None)
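
For reference, the pattern the new setup() hook adopts, condensed into a self-contained sketch (a hypothetical minimal module, not the template's full MNISTLitModule):

    import torch
    from lightning import LightningModule


    class LitSketch(LightningModule):
        def __init__(self, net: torch.nn.Module, compile: bool) -> None:
            super().__init__()
            # save_hyperparameters() exposes `compile` as self.hparams.compile;
            # the network is excluded so it is not duplicated into hparams
            self.save_hyperparameters(logger=False, ignore=["net"])
            self.net = net

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)

        def setup(self, stage: str) -> None:
            # compile only the inner nn.Module, and only when training starts;
            # `self` remains a plain LightningModule, so Trainer hooks and
            # checkpointing still see the class they expect
            if self.hparams.compile and stage == "fit":
                self.net = torch.compile(self.net)

Trainer calls setup("fit") on every process before training begins, so under DDP each rank compiles its own replica.
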
4 changes: 0 additions & 4 deletions src/train.py
@@ -74,10 +74,6 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     log.info("Logging hyperparameters!")
     utils.log_hyperparameters(object_dict)
 
-    if cfg.get("compile"):
-        log.info("Compiling model!")
-        model = torch.compile(model)
-
     if cfg.get("train"):
         log.info("Starting training!")
         trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
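
With compilation owned by the module itself, the training script needs no torch.compile branch at all: the flag rides along in cfg.model and takes effect when fit begins. A quick sanity check of the behavior, reusing the hypothetical LitSketch from above:

    import torch.nn as nn

    model = LitSketch(net=nn.Linear(4, 2), compile=True)
    model.setup(stage="fit")  # what Trainer.fit triggers on each process
    print(type(model).__name__)      # LitSketch: the Lightning wrapper is untouched
    print(type(model.net).__name__)  # OptimizedModule: only the inner net was compiled
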
