Commit
Update docs for torch-directml 0.2.2 (#593)
* update docs for next torch-directml release

* Minor readme spacing issues

---------

Co-authored-by: Sheil Kumar <[email protected]>
Co-authored-by: Dwayne Robinson <[email protected]>
3 people authored Jun 15, 2024
1 parent 4d65cad commit 372a622
Showing 32 changed files with 106,715 additions and 634 deletions.
1 change: 1 addition & 0 deletions PyTorch/README.md
@@ -22,6 +22,7 @@ For `torch-directml` samples find brief summaries below or explore the [cv](./cv
* [resnet50 - an image classification model](./cv/resnet50)
* [maskrcnn - an object detection model](./cv/objectDetection/maskrcnn/)
* [llm - a text generation and chatbot app supporting various language models](./llm/)
* [whisper - a general-purpose speech recognition model](./audio/whisper/)

## External Links

21 changes: 21 additions & 0 deletions PyTorch/audio/whisper/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
94 changes: 94 additions & 0 deletions PyTorch/audio/whisper/README.md
@@ -0,0 +1,94 @@
# Speech Recognition with Whisper
This sample shows how to run OpenAI's automatic speech recognition (ASR) [Whisper model](https://github.com/openai/whisper/blob/main/README.md) with the DirectML backend.

- [About Whisper](#about-whisper)
- [Setup](#setup)
- [Run the Whisper Model](#run-the-whisper-model)
- [Basic Settings](#basic-settings)
- [External Links](#external-links)
- [Model License](#model-license)


## About Whisper

The [OpenAI Whisper](https://github.com/openai/whisper/) model is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.

Whisper supports five model sizes, four with English-only versions and all five with multilingual versions.
| Size | Parameters | English-only model | Multilingual model | Required VRAM |
|:---------:|:----------:|:------------------:|:------------------:|:-------------:|
| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB |
| base | 74 M | `base.en` | `base` | ~1 GB |
| small | 244 M | `small.en` | `small` | ~2 GB |
| medium | 769 M | `medium.en` | `medium` | ~5 GB |
| large v3 | 1550 M | N/A | `large-v3` | ~10 GB |

For more information on the model, please refer to the [OpenAI Whisper GitHub repo](https://github.com/openai/whisper/).
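
For example, choosing a multilingual size such as `small` lets the sample script described below auto-detect the spoken language before decoding, while the `.en` variants always assume English:

```bash
python run.py --input_file test/samples_jfk.wav --model_size small
```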


## Setup
Once you've set up `torch-directml` following our [Windows](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-windows) and [WSL](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-wsl) guidance, install the following requirements for running the app:


```bash
conda install ffmpeg
pip install -r requirements.txt
```
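
Optionally, you can confirm the DirectML device is visible to PyTorch before running the sample. The snippet below is a minimal sanity check; it only assumes `torch` and `torch_directml` imported successfully:

```python
import torch
import torch_directml

# Pick the default DirectML device (the same call run.py uses below).
dml = torch_directml.device(torch_directml.default_device())

# A tiny tensor op on that device confirms the backend works end to end.
x = torch.ones(2, 2, device=dml)
print((x * 2).to("cpu"))
```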


## Run the Whisper model
Run Whisper on the DirectML backend with a sample audio file using the following command:
```bash
python run.py --input_file <audio_file> --model_size "tiny.en"
```


For example, you should see output similar to the following:
```
> python run.py --input_file test/samples_jfk.wav --model_size "tiny.en"
100%|█████████████████████████████████████| 72.1M/72.1M [00:09<00:00, 7.90MiB/s]
test/samples_jfk.wav
And so my fellow Americans ask not what your country can do for you ask what you can do for your country.
```


Note that by default, [OpenAI Whisper](https://github.com/openai/whisper/) uses a naive implementation of scaled dot product attention. To improve performance further by leveraging DirectML's scaled dot product attention, run `run.py` with the `--use_dml_attn` flag:

```bash
python run.py --input_file <audio_file> --model_size "tiny.en" --use_dml_attn
```
Based on this flag, the `MultiHeadAttention` module in `model.py` chooses between the naive Whisper scaled dot product attention and DirectML's scaled dot product attention:
```python
if use_dml_attn:
    wv, qk = self.dml_sdp_attn(q, k, v, mask, cross_attention=cross_attention)
else:
    wv, qk = self.qkv_attention(q, k, v, mask)
```
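
For reference, the DirectML path is conceptually similar to PyTorch's fused `torch.nn.functional.scaled_dot_product_attention`. The sketch below is illustrative only: the function name is made up, and the actual `dml_sdp_attn` in `model.py` also returns the attention weights `qk` used for alignment.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def fused_sdp_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # One fused call replaces the explicit softmax(q @ k^T / sqrt(d)) @ v sequence;
    # on a DirectML device this can dispatch to an optimized attention kernel.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```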

## Basic Settings

The following is a list of the basic settings supported by `run.py`:



| Flag | Description | Default |
| ---------------------- | ------------------------------------------------------------ | ------- |
| `--help` | Show this help message. | - |
| `--input_file`         | [Required] Path to the input audio file.                      | -       |
| `--model_size`         | Size of the Whisper model to use. Options: [`tiny.en`, `tiny`, `base.en`, `base`, `small.en`, `small`, `medium.en`, `medium`, `large-v3`] | `tiny.en` |
| `--fp16`               | Runs inference with fp16 precision.                           | True    |
| `--use_dml_attn`       | Runs inference with the DirectML scaled dot product attention implementation. | False   |
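
For example, transcribing the bundled sample clip with the multilingual `small` model, fp16 precision, and the DirectML attention path looks like this:

```bash
python run.py --input_file test/samples_jfk.wav --model_size small --fp16 --use_dml_attn
```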


## External Links
- [Whisper Base Hugging Face Repository](https://huggingface.co/openai/whisper-base.en)
- [Whisper Tiny Hugging Face Repository](https://huggingface.co/openai/whisper-tiny.en)
- [Whisper Small Hugging Face Repository](https://huggingface.co/openai/whisper-small.en)
- [Whisper Medium Hugging Face Repository](https://huggingface.co/openai/whisper-medium.en)
- [Whisper Large v3 Hugging Face Repository](https://huggingface.co/openai/whisper-large-v3)
- [Whisper GitHub Repo](https://github.com/openai/whisper)



## Model License

Whisper's code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
6 changes: 6 additions & 0 deletions PyTorch/audio/whisper/requirements.txt
@@ -0,0 +1,6 @@
numba
numpy
tqdm
more-itertools
tiktoken
ffmpeg-python
45 changes: 45 additions & 0 deletions PyTorch/audio/whisper/run.py
@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import whisper
import torch_directml
import argparse


def main(args):
    device = torch_directml.device(torch_directml.default_device())
    model = whisper.load_model(args.model_size, device=device, use_dml_attn=args.use_dml_attn)

    # Load audio and pad/trim it to fit 30 seconds
    audio = whisper.load_audio(args.input_file)
    audio = whisper.pad_or_trim(audio)

    n_mels = 80
    if args.model_size == "large-v3":
        n_mels = 128

    mel = whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device)
    language = "en"
    if "en" not in args.model_size:
        _, probs = model.detect_language(mel)
        language = max(probs, key=probs.get)
        print(f"Detected language: {language}")

    options = whisper.DecodingOptions(language=language, fp16=args.fp16)
    result = whisper.decode(model, mel, options)

    print(result.text)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run Whisper model on specified audio file with warmup.')
    parser.add_argument('--model_size', type=str, default='tiny.en', help='Size of the Whisper model to use.')
    parser.add_argument('--input_file', type=str, required=True, help='Path to the input audio file.')
    parser.add_argument('--fp16', action="store_true", help='Runs inference with fp16 precision.')
    parser.add_argument('--use_dml_attn', action="store_true", help='Use DirectML attention implementation.')
    args = parser.parse_args()

    main(args)
Binary file added PyTorch/audio/whisper/test/samples_jfk.wav
Binary file not shown.
160 changes: 160 additions & 0 deletions PyTorch/audio/whisper/whisper/__init__.py
@@ -0,0 +1,160 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
# from .version import __version__

_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    model_bytes = open(download_target, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target


def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_MODELS.keys())


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
    use_dml_attn: bool = False,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory
    use_dml_attn: bool
        whether to use the DirectML scaled dot product attention implementation

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    # with (
    #     io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    # ) as fp:
    #     # checkpoint = torch.load(fp, map_location=device)
    #     checkpoint = torch.load(fp, mmap=True, weights_only=True)
    # del checkpoint_file
    checkpoint = torch.load(checkpoint_file, mmap=True, weights_only=True)

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims, use_dml_attn=use_dml_attn)

    model.load_state_dict(checkpoint["model_state_dict"])
    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
