Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] V1.1.0 #327

Draft
wants to merge 33 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
264730b
Add solarization
Mar 23, 2022
19b2afb
Make cupy kernel use the previous state's device instead of defaultin…
ericjang Jun 11, 2022
84eef0d
single-worker from_iterable_dataset implementation
ericjang Jul 2, 2022
655c7f8
reproduce label bug with a test
Jul 6, 2022
4d54dcb
Allow Pillow resizing for dataset creation
warner-benjamin Mar 3, 2023
b04ec5b
Add Random Cutout & Random Erasing
warner-benjamin Mar 6, 2023
0a0908c
correct traversalorder check
sanagno Mar 24, 2023
2251f18
Allowing nopython and setting class defaults
dendraR2 Mar 29, 2023
38dc51a
http link to submodules
dendraR2 Mar 29, 2023
31b4282
Merge pull request #298 from warner-benjamin/random_cutout_erasing
andrewilyas Mar 29, 2023
21edbf2
type annotation
sanagno Mar 30, 2023
4bbacb3
reverting class attributes change
dendraR2 Mar 31, 2023
c21e36e
Merge pull request #302 from sanagno/main
andrewilyas Mar 31, 2023
38d8b97
Merge pull request #306 from rsmith013/allow_nopython
andrewilyas Mar 31, 2023
d27f3e0
Improve Parameter docs
warner-benjamin Apr 3, 2023
720cc2e
Merge pull request #297 from warner-benjamin/pillow_resize
andrewilyas Apr 5, 2023
3bbe945
Correct CIFAR10_MEAN and CIFAR10_STD in train_cifar.py
epistoteles May 31, 2023
24aa79c
Merge branch 'v1.1.0' into add_solarize
andrewilyas Jun 19, 2023
7ed86a5
Merge pull request #203 from bordesf/add_solarize
andrewilyas Jun 19, 2023
cdab564
Add RandomVerticalFlip
wouterzwerink Jun 28, 2023
43d8cf3
Merge branch 'v1.1.0' into main
andrewilyas Jun 28, 2023
d35f553
Merge pull request #331 from wouterzwerink/main
andrewilyas Jun 28, 2023
c5e40a4
Deepcopy mutable loader args
wouterzwerink Jun 29, 2023
ba96e7d
Add warning for field length
wouterzwerink Jun 29, 2023
4d56ee4
Merge pull request #334 from wouterzwerink/warn-for-field-length
andrewilyas Jun 29, 2023
7c3e270
Merge pull request #333 from wouterzwerink/deepcopy-immutable-loader-…
andrewilyas Jun 29, 2023
53b9d3e
Fix missing docs
wouterzwerink Jul 5, 2023
caccd85
Merge pull request #336 from wouterzwerink/v1.1.0
andrewilyas Jul 5, 2023
3616d1d
Merge pull request #322 from epistoteles/patch-1
andrewilyas Jul 10, 2023
e5f235b
Merge branch 'main' into v1.1.0
andrewilyas May 6, 2024
b68bc61
Merge pull request #224 from ericjang/main
andrewilyas May 6, 2024
9299848
Update __init__.py
andrewilyas May 6, 2024
39b2dd8
Update setup.py
andrewilyas May 6, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[submodule "examples/imagenet-example"]
path = examples/imagenet-example
url = git@github.com:libffcv/ffcv-imagenet.git
url = https://github.com/libffcv/ffcv-imagenet.git
1 change: 1 addition & 0 deletions docs/benchmarks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ We compare our results against existing data loading platforms:
- `Pytorch DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_: This is the default option that comes with the Pytorch library and uses individual JPEG files as the source.
- `Webdataset <https://github.com/webdataset/webdataset>`_: This loader requires pre-processed files aggregated in multiple big `.tar` archives.
- `DALI <https://docs.nvidia.com/deeplearning/dali/user-guide/docs/>`_: Data loading pipeline developed by Nvidia. In this experiment we used the default file format which is the same as that of the Pytorch DataLoader.

The specific instantiation of DALI that we apply is the PyTorch ImageNet example DALI code found in the `NVIDIA DeepLearningExamples repository <https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/ConvNets/resnet50v1.5>`_.
We use the DGX-1 configuration and remove all the model optimization, benchmarking only the dataloader.

Expand Down
4 changes: 2 additions & 2 deletions docs/ffcv_examples/custom_transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ Doing so requires providing implementation for two functions:
# Return the code to run this operation
@abstractmethod
def generate_code(self) -> Callable:
raise NotImplementedError
raise NotImplementedError()

@abstractmethod
def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
raise NotImplementedError
raise NotImplementedError()

Advancing state and pre-allocating memory
------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions examples/cifar/train_cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def make_dataloaders(train_dataset=None, val_dataset=None, batch_size=None, num_
}

start_time = time.time()
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
CIFAR_MEAN = [125.307, 122.950, 113.865]
CIFAR_STD = [62.993, 62.089, 66.705]
loaders = {}

for name in ['train', 'test']:
Expand Down
2 changes: 1 addition & 1 deletion ffcv/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .loader import Loader
from .writer import DatasetWriter
__version__ = '1.0.2'
__version__ = '1.1.0'

__all__ = ['Loader']
2 changes: 1 addition & 1 deletion ffcv/benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ def __init__(self, **kwargs):

@abstractmethod
def run(self):
raise NotImplemented()
raise NotImplementedError()
10 changes: 5 additions & 5 deletions ffcv/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@ class Field(ABC):
@property
@abstractmethod
def metadata_type(self) -> np.dtype:
raise NotImplemented
raise NotImplementedError()

@staticmethod
@abstractmethod
def from_binary(binary: ARG_TYPE) -> Field:
raise NotImplementedError
raise NotImplementedError()

@abstractmethod
def to_binary(self) -> ARG_TYPE:
raise NotImplementedError
raise NotImplementedError()

@abstractmethod
def encode(field, metadata_destination, malloc):
raise NotImplementedError
raise NotImplementedError()

@abstractmethod
def get_decoder_class(self) -> Type[Operation]:
raise NotImplementedError
raise NotImplementedError()
102 changes: 74 additions & 28 deletions ffcv/fields/rgb_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@
import cv2
import numpy as np
from numba.typed import Dict
from PIL.Image import Image
from PIL import Image

try:
LANCZOS = Image.Resampling.LANCZOS
except AttributeError:
from PIL.Image import LANCZOS

from .base import Field, ARG_TYPE
from ..pipeline.operation import Operation
from ..pipeline.state import State
from ..pipeline.compiler import Compiler
from ..pipeline.allocation_query import AllocationQuery
from ..libffcv import imdecode, memcpy, resize_crop
from ..utils import pil_to_numpy

if TYPE_CHECKING:
from ..memory_managers.base import MemoryManager
Expand All @@ -34,15 +40,33 @@ def encode_jpeg(numpy_image, quality):
return result.reshape(-1)


def resizer(image, target_resolution):
if target_resolution is None:
return image
original_size = np.array([image.shape[1], image.shape[0]])
ratio = target_resolution / original_size.max()
def resizer(image, max_resolution, min_resolution, interpolation=(cv2.INTER_AREA, LANCZOS)):
pillow_resize = isinstance(image, Image.Image)
if max_resolution is None and min_resolution is None:
return pil_to_numpy(image) if pillow_resize else image

if pillow_resize:
original_size = np.array([image.size[0], image.size[1]])
else:
original_size = np.array([image.shape[1], image.shape[0]])

if max_resolution is not None:
ratio = max_resolution / original_size.max()
elif min_resolution is not None:
ratio = min_resolution / original_size.min()
else:
ratio = 1

if ratio < 1:
new_size = (ratio * original_size).astype(int)
image = cv2.resize(image, tuple(new_size), interpolation=cv2.INTER_AREA)
return image
if pillow_resize:
image = image.resize(new_size, resample=interpolation[1])
else:
image = cv2.resize(image, tuple(new_size), interpolation=interpolation[0])
if pillow_resize:
return pil_to_numpy(image)
else:
return image


def get_random_crop(height, width, scale, ratio):
Expand Down Expand Up @@ -214,7 +238,7 @@ def decode(batch_indices, my_storage, metadata, storage_state):
@property
@abstractmethod
def get_crop_generator():
raise NotImplementedError
raise NotImplementedError()


class RandomResizedCropRGBImageDecoder(ResizedCropRGBImageDecoder):
Expand Down Expand Up @@ -278,25 +302,44 @@ class RGBImageField(Field):
size), and 'proportion' (JPEG compress a random subset of the data with
size specified by the ``compress_probability`` argument). By default: 'raw'.
max_resolution : int, optional
If specified, will resize images to have maximum side length equal to
this value before saving, by default None
If specified, resize images to have maximum side length equal to this
value if maximum side length is larger. By default: None
min_resolution : int, optional
If specified, resize images to have minimum side length equal to this
value if minimum side length is larger. By default: None
smart_threshold : int, optional
When `write_mode='smart`, will compress an image if it would take more than `smart_threshold` times to use RAW instead of jpeg.
When `write_mode='smart`, will compress an image if RAW byte size is
larger than `smart_threshold`.
jpeg_quality : int, optional
The quality parameter for JPEG encoding (ignored for
``write_mode='raw'``), by default 90
The quality parameter for JPEG encoding (ignored for ``write_mode='raw'``).
By default 90
compress_probability : float, optional
Ignored unless ``write_mode='proportion'``; in the latter case it is the
probability with which image is JPEG-compressed, by default 0.5.
probability with which image is JPEG-compressed. By default 0.5.
interpolation : optional
The OpenCV interpolation flag for resizing images with OpenCV.
By default INTER_AREA.
resample : optional
The Pillow resampling filter for resizing images with Pillow. By default LANCZOS.
pillow_resize : bool, optional
Use Pillow to resize images instead of OpenCV. By default False (OpenCV).
"""
def __init__(self, write_mode='raw', max_resolution: int = None,
smart_threshold: int = None, jpeg_quality: int = 90,
compress_probability: float = 0.5) -> None:
min_resolution: int = None, smart_threshold: int = None,
jpeg_quality: int = 90, compress_probability: float = 0.5,
interpolation = cv2.INTER_AREA, resample = LANCZOS,
pillow_resize:bool = False) -> None:
self.write_mode = write_mode
self.smart_threshold = smart_threshold
self.max_resolution = max_resolution
self.min_resolution = min_resolution
self.jpeg_quality = int(jpeg_quality)
self.proportion = compress_probability
self.interpolation = interpolation
self.resample = resample
self.pillow_resize = pillow_resize
if max_resolution is not None and min_resolution is not None:
raise ValueError(f'Can only set one of {max_resolution=} or {min_resolution=}')

@property
def metadata_type(self) -> np.dtype:
Expand All @@ -318,21 +361,24 @@ def to_binary(self) -> ARG_TYPE:
return np.zeros(1, dtype=ARG_TYPE)[0]

def encode(self, destination, image, malloc):
if isinstance(image, Image):
image = np.array(image)

if not isinstance(image, np.ndarray):
if not isinstance(image, np.ndarray) and not isinstance(image, Image.Image):
raise TypeError(f"Unsupported image type {type(image)}")

if image.dtype != np.uint8:
raise ValueError("Image type has to be uint8")
if self.pillow_resize:
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
else:
if isinstance(image, Image.Image):
image = pil_to_numpy(image)

if image.shape[2] != 3:
raise ValueError(f"Invalid shape for rgb image: {image.shape}")
image = resizer(image, self.max_resolution, self.min_resolution,
(self.interpolation, self.resample))

assert image.dtype == np.uint8
if len(image.shape) > 2 and image.shape[2] != 3:
raise ValueError(f"Invalid shape for rgb image: {image.shape}")

image = resizer(image, self.max_resolution)
if image.dtype != np.uint8:
raise ValueError("Image type has to be uint8")

write_mode = self.write_mode
as_jpg = None
Expand Down Expand Up @@ -362,4 +408,4 @@ def encode(self, destination, image, malloc):
destination['data_ptr'], storage = malloc(image.nbytes)
storage[:] = image_bytes
else:
raise ValueError(f"Unsupported write mode {self.write_mode}")
raise ValueError(f"Unsupported write mode {self.write_mode}")
9 changes: 5 additions & 4 deletions ffcv/loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Callable, Mapping, Sequence, Type, Union, Literal
from collections import defaultdict
from collections.abc import Collection
from copy import deepcopy
from enum import Enum, unique, auto

from ffcv.fields.base import Field
Expand Down Expand Up @@ -40,7 +41,7 @@ class OrderOption(Enum):

]

ORDER_MAP: Mapping[ORDER_TYPE, TraversalOrder] = {
ORDER_MAP: Mapping[ORDER_TYPE, Type[TraversalOrder]] = {
OrderOption.RANDOM: Random,
OrderOption.SEQUENTIAL: Sequential,
OrderOption.QUASI_RANDOM: QuasiRandom
Expand Down Expand Up @@ -122,8 +123,8 @@ def __init__(self,
'order': order,
'distributed': distributed,
'seed': seed,
'indices': indices,
'pipelines': pipelines,
'indices': deepcopy(indices),
'pipelines': deepcopy(pipelines),
'drop_last': drop_last,
'batches_ahead': batches_ahead,
'recompile': recompile
Expand Down Expand Up @@ -158,7 +159,7 @@ def __init__(self,
if order in ORDER_MAP:
self.traversal_order: TraversalOrder = ORDER_MAP[order](self)
elif issubclass(order, TraversalOrder):
self.traversal_order: TraversalOrder = order(self, **order_kwargs)
self.traversal_order: TraversalOrder = order(self)
else:
raise ValueError(f"Order {order} is not a supported order type or a subclass of TraversalOrder")

Expand Down
2 changes: 1 addition & 1 deletion ffcv/memory_managers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def schedule_epoch(self, batches: Sequence[Sequence[int]]) -> MemoryContext:

@abstractmethod
def compile_reader(self, address, size) -> Callable:
raise NotImplemented()
raise NotImplementedError()

@property
@abstractmethod
Expand Down
8 changes: 5 additions & 3 deletions ffcv/pipeline/compiler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pdb
from numba import njit, set_num_threads, prange, warnings as nwarnings, get_num_threads
from numba import jit, set_num_threads, prange, warnings as nwarnings, get_num_threads
from numba.core.errors import NumbaPerformanceWarning
from multiprocessing import cpu_count
import torch as ch
Expand All @@ -25,10 +25,12 @@ def compile(cls, code, signature=None):
parallel = False
if hasattr(code, 'is_parallel'):
parallel = code.is_parallel and cls.num_threads > 1

nopython = getattr(code, 'nopython', True)

if cls.is_enabled:
return njit(signature, fastmath=True, nogil=True, error_model='numpy',
parallel=parallel)(code)
return jit(signature, fastmath=True, nogil=nopython, error_model='numpy',
parallel=parallel, nopython=nopython, forceobj=not nopython)(code)
return code

@classmethod
Expand Down
31 changes: 19 additions & 12 deletions ffcv/pipeline/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from distutils.log import warn
import warnings
import ast
import sys

try:
# Useful for debugging
Expand All @@ -23,11 +23,18 @@
import torch as ch
import numpy as np

# This is the starting state of the pipeline
INITIAL_STATE = State(jit_mode=True,
device=ch.device('cpu'),
dtype=np.dtype('u1'),
shape=None)
if "sphinx" in sys.modules:
# Sphinx fails on jit+gpu assert due to improper initialization of device
INITIAL_STATE = State(jit_mode=False,
device=ch.device('cpu'),
dtype=np.dtype('u1'),
shape=None)
else:
# This is the starting state of the pipeline
INITIAL_STATE = State(jit_mode=True,
device=ch.device('cpu'),
dtype=np.dtype('u1'),
shape=None)


class Node(ABC):
Expand All @@ -40,34 +47,34 @@ def __init__(self):
@property
@abstractmethod
def is_jitted(self):
raise NotImplemented()
raise NotImplementedError()

@property
@abstractmethod
def parent(self):
raise NotImplemented()
raise NotImplementedError()

@property
@abstractmethod
def arg_id(self):
raise NotImplemented()
raise NotImplementedError()

@property
@abstractmethod
def result_id(self):
raise NotImplemented()
raise NotImplementedError()

@property
@abstractmethod
def result_id(self):
raise NotImplemented()
raise NotImplementedError()

def get_shared_code_ast(self, done_ops):
return ast.Pass()

@abstractmethod
def generate_code(self):
raise NotImplemented()
raise NotImplementedError()

def recompile(self):
self._code = self.generate_code()
Expand Down
Loading