@bandish-shah released this on 25 Oct 00:36

🚀 Composer v0.11.0

Composer v0.11.0 is released! Install via pip:

pip install --upgrade mosaicml==0.11.0

New Features

  1. 🧰 FSDP Beta Support

    Composer now supports PyTorch FSDP! PyTorch FSDP is a strategy for distributed training, similar to PyTorch DDP, that distributes work using data-parallelism only. On top of this, FSDP shards model parameters, gradients, and optimizer state to dramatically reduce device memory requirements, enabling users to easily scale and train large models.

    Here's how easy it is to use FSDP with Composer:

    import torch.nn as nn
    from composer import Trainer
    from composer.models import ComposerModel
    
    class Block(nn.Module):
        ...
    
    # Your custom model
    class Model(nn.Module):
        def __init__(self, n_layers):
            super().__init__()
            self.blocks = nn.ModuleList([
                Block(...) for _ in range(n_layers)
            ])
            self.head = nn.Linear(...)
        def forward(self, inputs):
            ...
    
        # FSDP Wrap Function
        def fsdp_wrap_fn(self, module):
            return isinstance(module, Block)
    
        # Activation Checkpointing Function
        def activation_checkpointing_fn(self, module):
            return isinstance(module, Block)
    
    # ComposerModel wrapper, used by the Trainer
    # to compute loss, metrics, etc.
    class MyComposerModel(ComposerModel):
    
        def __init__(self, n_layers):
            super().__init__()
            self.model = Model(n_layers)
            ...
    
        def forward(self, batch):
            ...
    
        def eval_forward(self, batch, outputs=None):
            ...
    
        def loss(self, outputs, batch):
            ...
    
    # Pass your ComposerModel and fsdp_config into the Trainer
    composer_model = MyComposerModel(n_layers=3)
    fsdp_config = {
        'sharding_strategy': 'FULL_SHARD',
        'min_params': 1e8,
        'cpu_offload': False, # Not supported yet
        'mixed_precision': 'DEFAULT',
        'backward_prefetch': 'BACKWARD_POST',
        'activation_checkpointing': False,
        'activation_cpu_offload': False,
        'verbose': True
    }
    
    trainer = Trainer(
        model=composer_model,
        fsdp_config=fsdp_config,
        ...
    )
    
    trainer.fit()

    For more information, please see our FSDP docs.

  2. 🚰 Streaming v0.1

    We've spun off Streaming datasets into its own repository! Streaming is a high-performance drop-in replacement for PyTorch's IterableDataset, enabling users to stream training data from cloud-based object stores. Streaming ships with built-in support for popular open-source datasets (ADE20K, C4, COCO, Enwiki, ImageNet, etc.).

    To get started, install the Streaming PyPI package:

    pip install mosaicml-streaming

    You can use the streaming Dataset class with the PyTorch native DataLoader class as follows:

    import torch
    from streaming import Dataset
    
    dataloader = torch.utils.data.DataLoader(dataset=Dataset(remote='s3://...'))
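
    As a rough end-to-end sketch (the bucket path here is illustrative, and batch_size/num_workers are ordinary PyTorch DataLoader arguments, not Streaming-specific settings), you can iterate the resulting DataLoader like any other IterableDataset-backed loader; shards are fetched from the object store and cached locally as you iterate:

    import torch
    from streaming import Dataset

    # Hypothetical remote location; point this at your own converted dataset.
    dataset = Dataset(remote='s3://my-bucket/my-dataset')

    # Standard PyTorch DataLoader arguments; shards are downloaded and cached
    # on the fly as the loader pulls samples from the streamed dataset.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=8)

    for batch in dataloader:
        ...  # your training step here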

    For more information, please check out the Streaming docs.

  3. ✔👉 Simplified Checkpointing Interface

    With this release we’ve greatly simplified configuration of loading and saving checkpoints in Composer.

    To save checkpoints to S3, all you need to do is:

    • Set save_folder to the full URI of your save directory destination (e.g. 's3://my-bucket/{run_name}/checkpoints')
    • Optionally, set save_filename to the pattern you want for your checkpoint file names

    from composer.trainer import Trainer
    
    # Checkpoint saving to S3.
    trainer = Trainer(
        model=model,
        save_folder="s3://my-bucket/{run_name}/checkpoints",
        run_name='my-run',
        save_interval="1ep",
        save_filename="ep{epoch}.pt",
        save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
        ...
    )
    
    trainer.fit()

    Likewise, to load checkpoints from S3, all you have to do is:

    • Set load_path to the full URI of your desired checkpoint file (e.g. 's3://my-bucket/my-run/checkpoints/ep13.pt')

    from composer.trainer import Trainer
    
    # Checkpoint loading from S3.
    new_trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration="10ep",
        load_path="s3://my-bucket/my-run/checkpoints/ep13.pt",
    )
    
    new_trainer.fit()
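
    Because checkpoints can now live directly in object storage, resuming an interrupted run can use the very same save_folder. The following is only a sketch (it reuses model and train_dataloader from the snippets above and assumes the Trainer's autoresume option, which the auto resumption fix below touches): with autoresume enabled, a run_name, and the same save_folder, the Trainer looks for the latest checkpoint belonging to that run and resumes from it if one exists.

    from composer.trainer import Trainer

    # A sketch, not this release's documented example: autoresume=True asks the
    # Trainer to find the latest checkpoint for 'my-run' in save_folder and
    # resume from it; if none is found, training starts from scratch.
    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration="10ep",
        run_name='my-run',
        save_folder="s3://my-bucket/{run_name}/checkpoints",
        save_interval="1ep",
        autoresume=True,
    )

    trainer.fit()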

    For more information, please see our Checkpointing guide.

  4. Improved Distributed Experience

    We’ve made it easier to write your own custom distributed entry points by exposing our distributed API. You can now leverage all of our helpful distributed functions and contexts.

    For example, let's say we need to download a dataset in a distributed training application. To avoid race conditions where different ranks try to write the dataset to the same place, we need to ensure that only rank 0 downloads the dataset first:

    import datetime
    from composer.trainer.devices import DeviceGPU
    from composer.utils import dist
    
    dist.initialize_dist(DeviceGPU(), datetime.timedelta(seconds=30)) # Initialize distributed module
    
    if dist.get_local_rank() == 0: # Download dataset on rank zero
        dataset = download_my_dataset()
    dist.barrier() # All ranks wait until dataset is downloaded
    
    # Create and train your model!
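
    Beyond barriers, the distributed API exposes helpers such as dist.get_world_size(), dist.get_global_rank(), and dist.all_gather_object(). Here is a small sketch (the per-rank sample count is an illustrative value, not something from this release) that gathers a Python object from every rank:

    from composer.utils import dist

    # Assumes the distributed module was initialized as shown above.
    world_size = dist.get_world_size()    # total number of processes
    global_rank = dist.get_global_rank()  # this process's rank across all nodes

    # Gather a picklable object from every rank onto all ranks, e.g. to sum
    # per-rank sample counts.
    per_rank_count = 1000  # illustrative value
    all_counts = dist.all_gather_object(per_rank_count)

    if global_rank == 0:
        print(f'world size: {world_size}, total samples: {sum(all_counts)}')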

    For more information, please check out our Distributed API docs.

Bug Fixes

  • fix loss and eval_forward for HF models (#1597)
  • add more robust casting to int for fsdp min_params (#1608)
  • Deepspeed Docs Typo (#1605)
  • Fix mmdet typo (#1618)
  • Blurpool idempotent (#1625)
  • When model is not on meta device, initialization should occur on compute device not CPU (#1623)
  • Auto resumption (#1615)
  • Adjust speed monitor (#1645)
  • Hot fix console logging (#1643)
  • Lazy Logging + pretty print dict for hparams (#1653)
  • Fix many failing notebook tests (#1646)

What's Changed

Full Changelog: v0.10.1...v0.11.0