From 264730b476fd7e80799acfc48822eae176b5808c Mon Sep 17 00:00:00 2001 From: Florian Bordes Date: Wed, 23 Mar 2022 12:24:56 -0700 Subject: [PATCH 01/19] Add solarization --- ffcv/transforms/__init__.py | 3 ++- ffcv/transforms/solarization.py | 48 +++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 ffcv/transforms/solarization.py diff --git a/ffcv/transforms/__init__.py b/ffcv/transforms/__init__.py index bc8fa321..da99eb33 100644 --- a/ffcv/transforms/__init__.py +++ b/ffcv/transforms/__init__.py @@ -9,6 +9,7 @@ from .translate import RandomTranslate from .mixup import ImageMixup, LabelMixup, MixupToOneHot from .module import ModuleWrapper +from .solarization import Solarization __all__ = ['ToTensor', 'ToDevice', 'ToTorchImage', 'NormalizeImage', @@ -16,4 +17,4 @@ 'RandomResizedCrop', 'RandomHorizontalFlip', 'RandomTranslate', 'Cutout', 'ImageMixup', 'LabelMixup', 'MixupToOneHot', 'Poison', 'ReplaceLabel', - 'ModuleWrapper'] \ No newline at end of file + 'ModuleWrapper', 'Solarization'] \ No newline at end of file diff --git a/ffcv/transforms/solarization.py b/ffcv/transforms/solarization.py new file mode 100644 index 00000000..15d657f6 --- /dev/null +++ b/ffcv/transforms/solarization.py @@ -0,0 +1,48 @@ +""" +Random Solarization +""" +from numpy.random import rand +from typing import Callable, Optional, Tuple +from ..pipeline.allocation_query import AllocationQuery +from ..pipeline.operation import Operation +from ..pipeline.state import State +from ..pipeline.compiler import Compiler + +class Solarization(Operation): + """Solarize the image randomly with a given probability by inverting all pixel + values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format, + where ... means it can have an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Parameters + ---------- + solarization_prob (float): probability of the image being solarized. Default value is 0.5 + threshold (float): all pixels equal or above this value are inverted. + """ + + def __init__(self, solarization_prob: float = 0.5, threshold: float = 128): + super().__init__() + self.solarization_prob = solarization_prob + self.threshold = threshold + + def generate_code(self) -> Callable: + my_range = Compiler.get_iterator() + solarization_prob = self.solarization_prob + threshold = self.threshold + + def solarize(images, dst): + should_solarize = rand(images.shape[0]) < solarization_prob + for i in my_range(images.shape[0]): + if should_solarize[i]: + mask = (images[i] >= threshold) + dst[i] = images[i] * (1-mask) + (255 - images[i])*mask + else: + dst[i] = images[i] + return dst + + solarize.is_parallel = True + return solarize + + def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: + assert previous_state.jit_mode + return (previous_state, AllocationQuery(previous_state.shape, previous_state.dtype)) From 19b2afbce0905d96c4258e1252bea127d9e3662c Mon Sep 17 00:00:00 2001 From: Eric Date: Sat, 11 Jun 2022 13:41:48 -0700 Subject: [PATCH 02/19] Make cupy kernel use the previous state's device instead of defaulting to cuda:0. --- ffcv/transforms/normalize.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/ffcv/transforms/normalize.py b/ffcv/transforms/normalize.py index a04f45e7..720c1cb1 100644 --- a/ffcv/transforms/normalize.py +++ b/ffcv/transforms/normalize.py @@ -49,6 +49,7 @@ def __init__(self, mean: np.ndarray, std: np.ndarray, self.lookup_table = table self.previous_shape = None self.mode = 'cpu' + self.gpu_device_int = 0 def generate_code(self) -> Callable: if self.mode == 'cpu': @@ -62,27 +63,29 @@ def generate_code_gpu(self) -> Callable: import pytorch_pfn_extras as ppe tn = np.zeros((), dtype=self.dtype).dtype.name - kernel = cp.ElementwiseKernel(f'uint8 input, raw {tn} table', f'{tn} output', 'output = table[input * 3 + i % 3];') - final_type = ch_dtype_from_numpy(self.original_dtype) + with cp.cuda.Device(self.gpu_device_int): + kernel = cp.ElementwiseKernel(f'uint8 input, raw {tn} table', f'{tn} output', 'output = table[input * 3 + i % 3];') + final_type = ch_dtype_from_numpy(self.original_dtype) s = self def normalize_convert(images, result): - B, C, H, W = images.shape - table = self.lookup_table.view(-1) - assert images.is_contiguous(memory_format=ch.channels_last), 'Images need to be in channel last' - result = result[:B] - result_c = result.view(-1) - images = images.permute(0, 2, 3, 1).view(-1) + with cp.cuda.Device(self.gpu_device_int): + B, C, H, W = images.shape + table = self.lookup_table.view(-1) + assert images.is_contiguous(memory_format=ch.channels_last), 'Images need to be in channel last' + result = result[:B] + result_c = result.view(-1) + images = images.permute(0, 2, 3, 1).view(-1) - current_stream = ch.cuda.current_stream() - with ppe.cuda.stream(current_stream): - kernel(images, table, result_c) + current_stream = ch.cuda.current_stream() + with ppe.cuda.stream(current_stream): + kernel(images, table, result_c) - # Mark the result as channel last - final_result = result.reshape(B, H, W, C).permute(0, 3, 1, 2) + # Mark the result as channel last + final_result = result.reshape(B, H, W, C).permute(0, 3, 1, 2) - assert final_result.is_contiguous(memory_format=ch.channels_last), 'Images need to be in channel last' + assert final_result.is_contiguous(memory_format=ch.channels_last), 'Images need to be in channel last' - return final_result.view(final_type) + return final_result.view(final_type) return normalize_convert @@ -123,8 +126,7 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Option new_state = replace(previous_state, dtype=self.dtype) gpu_type = ch_dtype_from_numpy(self.dtype) - - + self.gpu_device_int = previous_state.device.index # Copy the lookup table into the proper device try: self.lookup_table = ch.from_numpy(self.lookup_table) From 84eef0df95b1e5d87da4dbd979aec7756dc66f22 Mon Sep 17 00:00:00 2001 From: Eric Date: Fri, 1 Jul 2022 22:29:49 -0700 Subject: [PATCH 03/19] single-worker from_iterable_dataset implementation --- ffcv/writer.py | 37 +++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/ffcv/writer.py b/ffcv/writer.py index a783d954..fbfaafaf 100644 --- a/ffcv/writer.py +++ b/ffcv/writer.py @@ -45,7 +45,8 @@ def handle_sample(sample, dest_ix, field_names, metadata, allocator, fields): allocator.set_current_sample(dest_ix) # We extract the sample in question from the dataset # We write each field individually to the metadata region - for field_name, field, field_value in zip(field_names, fields.values(), sample): + # TODO: this is brittle and leads to silent failures if lengths are mismatched. + for field_name, field, field_value in zip(field_names, fields.values(), sample): destination = metadata[field_name][dest_ix: dest_ix + 1] field.encode(destination, field_value, allocator.malloc) # We managed to write all the data without reaching @@ -118,6 +119,24 @@ def worker_job_indexed_dataset(input_queue, metadata_sm, metadata_type, fields, allocations_queue.put(allocator.allocations) +def worker_job_iterable_dataset(input_queue, metadata_sm, metadata_type, fields, + allocator, done_number, allocations_queue, dataset): + metadata = np.frombuffer(metadata_sm.buf, dtype=metadata_type) + field_names = metadata_type.names + + # This `with` block ensures that all the pages allocated have been written + # onto the file + with allocator: + # pop the solo dummy chunk off the queue. + chunk = input_queue.get() + del chunk + for dest_ix, sample in enumerate(dataset): + handle_sample(sample, dest_ix, field_names, metadata, allocator, fields) + with done_number.get_lock(): + done_number.value = 1 + allocations_queue.put(allocator.allocations) + + class DatasetWriter(): """Writes given dataset into FFCV format (.beton). Supports indexable objects (e.g., PyTorch Datasets) and webdataset. @@ -256,7 +275,6 @@ def _write_common(self, num_samples, queue_content, work_fn, extra_worker_args): for p in processes: content = allocations_queue.get() allocation_list.extend(content) - self.finalize(allocation_list) self.metadata_sm.close() self.metadata_sm.unlink() @@ -295,6 +313,21 @@ def from_indexed_dataset(self, dataset, worker_job_indexed_dataset, (dataset, )) + def from_iterable_dataset(self, dataset): + """Read dataset from iterable dataset and write to .beton. + + Shuffled indices not allowed, and for simplicity we only use + one worker. + """ + # this will create one chunk that is unused by the worker job. + # we have to make sure the done number is updated to match the + # num_samples=len(indices) + indices = [[0, 0]] + chunksize = 100 + self._write_common(len(indices), chunks(indices, chunksize), + worker_job_iterable_dataset, (dataset, )) + + def from_webdataset(self, shards: List[str], pipeline: Callable): """Read from webdataset-like format. See https://docs.ffcv.io/writing_datasets.html#webdataset for sample usage. From 655c7f8f7963343a2d6482c371ef78a4c5df58d6 Mon Sep 17 00:00:00 2001 From: Eric Jang Date: Tue, 5 Jul 2022 23:27:16 -0700 Subject: [PATCH 04/19] reproduce label bug with a test --- tests/test_image_ndarray_pipeline.py | 89 ++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 tests/test_image_ndarray_pipeline.py diff --git a/tests/test_image_ndarray_pipeline.py b/tests/test_image_ndarray_pipeline.py new file mode 100644 index 00000000..a9a0afcb --- /dev/null +++ b/tests/test_image_ndarray_pipeline.py @@ -0,0 +1,89 @@ +import numpy as np +import torch as ch +from torch.utils.data import Dataset +from assertpy import assert_that +from tempfile import NamedTemporaryFile +from torchvision.datasets import CIFAR10 +from torch.utils.data import Subset +import torch +from ffcv.writer import DatasetWriter +from ffcv.fields import IntField, RGBImageField, NDArrayField, FloatField +from ffcv.loader import Loader +from ffcv.pipeline.compiler import Compiler + +from ffcv.fields.ndarray import NDArrayDecoder +from ffcv.fields.decoders import SimpleRGBImageDecoder +from ffcv.transforms import ToTensor, ToDevice + + +class DummyDataset(Dataset): + + def __init__(self, length, label_dtype, height, width): + self.length = length + self.height = height + self.width = width + self.label_dtype = label_dtype + + def __len__(self): + return self.length + + def __getitem__(self, index): + if index > self.length: + raise IndexError + dims = (self.height, self.width, 3) + image_data = ((np.ones(dims) * index) % 255).astype('uint8') + if self.label_dtype == np.ndarray: + label = np.ones(2, dtype=np.float32) * index + else: + label = index + result = image_data, label + return result + +def create_and_validate_ndarray(length, dtype, mode='raw'): + + dataset = DummyDataset(length=length, label_dtype=np.ndarray, height=500, width=300) + + with NamedTemporaryFile() as handle: + name = handle.name + fields = { + 'value': RGBImageField(write_mode=mode, jpeg_quality=95), + 'label': NDArrayField(shape=(2,), dtype=np.dtype('float32')), + } + writer = DatasetWriter(name, fields, num_workers=4) + writer.from_indexed_dataset(dataset, chunksize=5) + Compiler.set_enabled(False) + loader = Loader(name, batch_size=5, num_workers=2) + labels = [] + for images, label in loader: + labels.append(label[:, 0]) + expected = np.arange(length).astype(np.float32) + labels = torch.concat(labels).numpy() + np.testing.assert_array_equal(expected, labels) + +def create_and_validate_int(length, dtype, mode='raw'): + + dataset = DummyDataset(length=length, label_dtype=int, height=500, width=300) + + with NamedTemporaryFile() as handle: + name = handle.name + fields = { + 'value': RGBImageField(write_mode=mode, jpeg_quality=95), + 'label': IntField(), + } + writer = DatasetWriter(name, fields, num_workers=4) + writer.from_indexed_dataset(dataset, chunksize=5) + Compiler.set_enabled(False) + loader = Loader(name, batch_size=5, num_workers=2) + labels = [] + for images, label in loader: + labels.append(label[:, 0]) + expected = np.arange(length).astype(np.float32) + labels = torch.concat(labels).numpy().astype(np.float32) + np.testing.assert_array_equal(expected, labels) + +def test_simple_jpg_image_pipeline_ndarray(): + create_and_validate_ndarray(100, 'jpg') + +def test_simple_jpg_image_pipeline_int(): + create_and_validate_int(100, 'jpg') + From 4d54dcb491900e996c1e69be8318ccefa0a96b90 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Fri, 3 Mar 2023 17:01:56 -0600 Subject: [PATCH 05/19] Allow Pillow resizing for dataset creation --- ffcv/fields/rgb_image.py | 89 +++++++++++++++++++++++++++--------- ffcv/utils.py | 28 +++++++++++- tests/test_image_pipeline.py | 46 +++++++++++-------- 3 files changed, 119 insertions(+), 44 deletions(-) diff --git a/ffcv/fields/rgb_image.py b/ffcv/fields/rgb_image.py index b6420f11..53f9ef4c 100644 --- a/ffcv/fields/rgb_image.py +++ b/ffcv/fields/rgb_image.py @@ -5,7 +5,12 @@ import cv2 import numpy as np from numba.typed import Dict -from PIL.Image import Image +from PIL import Image + +try: + LANCZOS = Image.Resampling.LANCZOS +except AttributeError: + from PIL.Image import LANCZOS from .base import Field, ARG_TYPE from ..pipeline.operation import Operation @@ -13,6 +18,7 @@ from ..pipeline.compiler import Compiler from ..pipeline.allocation_query import AllocationQuery from ..libffcv import imdecode, memcpy, resize_crop +from ..utils import pil_to_numpy if TYPE_CHECKING: from ..memory_managers.base import MemoryManager @@ -34,15 +40,33 @@ def encode_jpeg(numpy_image, quality): return result.reshape(-1) -def resizer(image, target_resolution): - if target_resolution is None: - return image - original_size = np.array([image.shape[1], image.shape[0]]) - ratio = target_resolution / original_size.max() +def resizer(image, max_resolution, min_resolution, interpolation=(cv2.INTER_AREA, LANCZOS)): + pillow_resize = isinstance(image, Image.Image) + if max_resolution is None and min_resolution is None: + return pil_to_numpy(image) if pillow_resize else image + + if pillow_resize: + original_size = np.array([image.size[0], image.size[1]]) + else: + original_size = np.array([image.shape[1], image.shape[0]]) + + if max_resolution is not None: + ratio = max_resolution / original_size.max() + elif min_resolution is not None: + ratio = min_resolution / original_size.min() + else: + ratio = 1 + if ratio < 1: new_size = (ratio * original_size).astype(int) - image = cv2.resize(image, tuple(new_size), interpolation=cv2.INTER_AREA) - return image + if pillow_resize: + image = image.resize(new_size, resample=interpolation[1]) + else: + image = cv2.resize(image, tuple(new_size), interpolation=interpolation[0]) + if pillow_resize: + return pil_to_numpy(image) + else: + return image def get_random_crop(height, width, scale, ratio): @@ -280,23 +304,41 @@ class RGBImageField(Field): max_resolution : int, optional If specified, will resize images to have maximum side length equal to this value before saving, by default None + min_resolution : int, optional + If specified, will resize images to have minimum side length equal to + this value before saving, by default None smart_threshold : int, optional - When `write_mode='smart`, will compress an image if it would take more than `smart_threshold` times to use RAW instead of jpeg. + When `write_mode='smart`, will compress an image if RAW byte size is + larger than `smart_threshold`. jpeg_quality : int, optional The quality parameter for JPEG encoding (ignored for ``write_mode='raw'``), by default 90 compress_probability : float, optional Ignored unless ``write_mode='proportion'``; in the latter case it is the probability with which image is JPEG-compressed, by default 0.5. + interpolation : optional + OpenCV interpolation flag for resizing images with OpenCV, by default INTER_AREA. + resample : optional + Pillow resampling filter for resizing images with Pillow, by default LANCZOS. + pillow_resize : bool, optional + Use Pillow to resize images instead of OpenCV, by default False (OpenCV). """ def __init__(self, write_mode='raw', max_resolution: int = None, - smart_threshold: int = None, jpeg_quality: int = 90, - compress_probability: float = 0.5) -> None: + min_resolution: int = None, smart_threshold: int = None, + jpeg_quality: int = 90, compress_probability: float = 0.5, + interpolation = cv2.INTER_AREA, resample = LANCZOS, + pillow_resize:bool = False) -> None: self.write_mode = write_mode self.smart_threshold = smart_threshold self.max_resolution = max_resolution + self.min_resolution = min_resolution self.jpeg_quality = int(jpeg_quality) self.proportion = compress_probability + self.interpolation = interpolation + self.resample = resample + self.pillow_resize = pillow_resize + if max_resolution is not None and min_resolution is not None: + raise ValueError(f'Can only set one of {max_resolution=} or {min_resolution=}') @property def metadata_type(self) -> np.dtype: @@ -318,21 +360,24 @@ def to_binary(self) -> ARG_TYPE: return np.zeros(1, dtype=ARG_TYPE)[0] def encode(self, destination, image, malloc): - if isinstance(image, Image): - image = np.array(image) - - if not isinstance(image, np.ndarray): + if not isinstance(image, np.ndarray) and not isinstance(image, Image.Image): raise TypeError(f"Unsupported image type {type(image)}") - if image.dtype != np.uint8: - raise ValueError("Image type has to be uint8") + if self.pillow_resize: + if isinstance(image, np.ndarray): + image = Image.fromarray(image) + else: + if isinstance(image, Image.Image): + image = pil_to_numpy(image) - if image.shape[2] != 3: - raise ValueError(f"Invalid shape for rgb image: {image.shape}") + image = resizer(image, self.max_resolution, self.min_resolution, + (self.interpolation, self.resample)) - assert image.dtype == np.uint8 + if len(image.shape) > 2 and image.shape[2] != 3: + raise ValueError(f"Invalid shape for rgb image: {image.shape}") - image = resizer(image, self.max_resolution) + if image.dtype != np.uint8: + raise ValueError("Image type has to be uint8") write_mode = self.write_mode as_jpg = None @@ -362,4 +407,4 @@ def encode(self, destination, image, malloc): destination['data_ptr'], storage = malloc(image.nbytes) storage[:] = image_bytes else: - raise ValueError(f"Unsupported write mode {self.write_mode}") + raise ValueError(f"Unsupported write mode {self.write_mode}") \ No newline at end of file diff --git a/ffcv/utils.py b/ffcv/utils.py index 51bb716f..bfafe4fb 100644 --- a/ffcv/utils.py +++ b/ffcv/utils.py @@ -1,6 +1,7 @@ import numpy as np from numba import types from numba.extending import intrinsic +import PIL.Image as Image def chunks(lst, n): @@ -34,7 +35,7 @@ def codegen(context, builder, signature, args): llrtype = context.get_value_type(rtype) return builder.inttoptr(src, llrtype) return sig, codegen - + from threading import Lock s_print_lock = Lock() @@ -43,4 +44,27 @@ def s_print(*a, **b): """Thread safe print function""" with s_print_lock: print(*a, **b) - \ No newline at end of file + + +# From https://uploadcare.com/blog/fast-import-of-pillow-images-to-numpy-opencv-arrays/ +# Up to 2.5 times faster with the same functionality and a smaller number of allocations than numpy.asarray(img) +def pil_to_numpy(img:Image.Image) -> np.ndarray: + "Fast conversion of Pillow `Image` to NumPy NDArray" + img.load() + # unpack data + enc = Image._getencoder(img.mode, 'raw', img.mode) + enc.setimage(img.im) + + # NumPy buffer for the result + shape, typestr = Image._conv_type_shape(img) + data = np.empty(shape, dtype=np.dtype(typestr)) + mem = data.data.cast('B', (data.data.nbytes,)) + + bufsize, s, offset = 65536, 0, 0 + while not s: + l, s, d = enc.encode(bufsize) + mem[offset:offset + len(d)] = d + offset += len(d) + if s < 0: + raise RuntimeError("encoder error %d in tobytes" % s) + return data \ No newline at end of file diff --git a/tests/test_image_pipeline.py b/tests/test_image_pipeline.py index 289ba638..5bb9580f 100644 --- a/tests/test_image_pipeline.py +++ b/tests/test_image_pipeline.py @@ -18,7 +18,7 @@ def __init__(self, length, height, width, reversed=False): self.height = height self.width = width self.reversed = reversed - + def __len__(self): return self.length @@ -32,18 +32,22 @@ def __getitem__(self, index): result = tuple(reversed(result)) return result -def create_and_validate(length, mode='raw', reversed=False): +def create_and_validate(length, mode='raw', reversed=False, max_resolution=None, + min_resolution=None, pillow_resize=False): dataset = DummyDataset(length, 500, 300, reversed=reversed) with NamedTemporaryFile() as handle: name = handle.name - + fields = { 'index': IntField(), - 'value': RGBImageField(write_mode=mode, jpeg_quality=95) + 'value': RGBImageField(write_mode=mode, jpeg_quality=95, + max_resolution=max_resolution, + min_resolution=min_resolution, + pillow_resize=pillow_resize) } - + if reversed: fields = { 'value': RGBImageField(write_mode=mode, jpeg_quality=95), @@ -51,13 +55,10 @@ def create_and_validate(length, mode='raw', reversed=False): } writer = DatasetWriter(name, fields, num_workers=2) - writer.from_indexed_dataset(dataset, chunksize=5) - Compiler.set_enabled(False) - loader = Loader(name, batch_size=5, num_workers=2) - + for res in loader: if not reversed: index, images = res @@ -65,28 +66,27 @@ def create_and_validate(length, mode='raw', reversed=False): images , index = res for i, image in zip(index, images): - if mode == 'raw': - assert_that(ch.all((image == (i % 255)).reshape(-1))).is_true() - else: - assert_that(ch.all((image == (i % 255)).reshape(-1))).is_true() - + assert_that(ch.all((image == (i % 255)).reshape(-1))).is_true() + if max_resolution is not None: + assert_that(image.shape[0] == max_resolution).is_true() + if min_resolution is not None: + assert_that(image.shape[1] == min_resolution).is_true() + def make_and_read_cifar_subset(length): my_dataset = Subset(CIFAR10(root='/tmp', train=True, download=True), range(length)) with NamedTemporaryFile() as handle: name = handle.name writer = DatasetWriter(name, { - 'image': RGBImageField(write_mode='smart', - max_resolution=32), + 'image': RGBImageField(write_mode='smart', + max_resolution=32), 'label': IntField(), }, num_workers=2) writer.from_indexed_dataset(my_dataset, chunksize=10) - Compiler.set_enabled(False) - loader = Loader(name, batch_size=5, num_workers=2) - + for index, images in loader: pass @@ -99,8 +99,14 @@ def test_simple_raw_image_pipeline(): def test_simple_raw_image_pipeline_rev(): create_and_validate(500, 'raw', True) +def test_simple_raw_image_pipeline_max(): + create_and_validate(500, 'raw', False, max_resolution=400) + +def test_simple_raw_image_pipeline_min(): + create_and_validate(500, 'raw', False, min_resolution=200, pillow_resize=True) + def test_simple_jpg_image_pipeline(): create_and_validate(500, 'jpg', False) def test_simple_jpg_image_pipeline_rev(): - create_and_validate(500, 'jpg', True) + create_and_validate(500, 'jpg', True) \ No newline at end of file From b04ec5b5d246edc1a8536f66a79e7fe5c47189db Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Mon, 6 Mar 2023 15:37:29 -0600 Subject: [PATCH 06/19] Add Random Cutout & Random Erasing --- ffcv/transforms/__init__.py | 6 ++- ffcv/transforms/cutout.py | 42 +++++++++++++++ ffcv/transforms/erasing.py | 105 ++++++++++++++++++++++++++++++++++++ tests/test_augmentations.py | 39 ++++++++++++-- 4 files changed, 185 insertions(+), 7 deletions(-) create mode 100644 ffcv/transforms/erasing.py diff --git a/ffcv/transforms/__init__.py b/ffcv/transforms/__init__.py index 2636a447..7406522e 100644 --- a/ffcv/transforms/__init__.py +++ b/ffcv/transforms/__init__.py @@ -1,4 +1,4 @@ -from .cutout import Cutout +from .cutout import Cutout, RandomCutout from .flip import RandomHorizontalFlip from .ops import ToTensor, ToDevice, ToTorchImage, Convert, View from .common import Squeeze @@ -10,12 +10,14 @@ from .mixup import ImageMixup, LabelMixup, MixupToOneHot from .module import ModuleWrapper from .color_jitter import RandomBrightness, RandomContrast, RandomSaturation +from .erasing import RandomErasing __all__ = ['ToTensor', 'ToDevice', 'ToTorchImage', 'NormalizeImage', 'Convert', 'Squeeze', 'View', 'RandomResizedCrop', 'RandomHorizontalFlip', 'RandomTranslate', - 'Cutout', 'ImageMixup', 'LabelMixup', 'MixupToOneHot', + 'Cutout', 'RandomCutout', 'RandomErasing', + 'ImageMixup', 'LabelMixup', 'MixupToOneHot', 'Poison', 'ReplaceLabel', 'ModuleWrapper', 'RandomBrightness', 'RandomContrast', 'RandomSaturation'] diff --git a/ffcv/transforms/cutout.py b/ffcv/transforms/cutout.py index 89237e0e..033c6b5a 100644 --- a/ffcv/transforms/cutout.py +++ b/ffcv/transforms/cutout.py @@ -2,6 +2,7 @@ Cutout augmentation (https://arxiv.org/abs/1708.04552) """ import numpy as np +from numpy.random import rand from typing import Callable, Optional, Tuple from dataclasses import replace @@ -50,3 +51,44 @@ def cutout_square(images, *_): def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: return replace(previous_state, jit_mode=True), None + + +class RandomCutout(Cutout): + """Random cutout data augmentation (https://arxiv.org/abs/1708.04552). + + Parameters + ---------- + prob : float + Probability of applying on each image. + crop_size : int + Size of the random square to cut out. + fill : Tuple[int, int, int], optional + An RGB color ((0, 0, 0) by default) to fill the cutout square with. + Useful for when a normalization layer follows cutout, in which case + you can set the fill such that the square is zero post-normalization. + """ + def __init__(self, prob: float, crop_size: int, fill: Tuple[int, int, int] = (0, 0, 0)): + super().__init__(crop_size, fill) + self.prob = np.clip(prob, 0., 1.) + + def generate_code(self) -> Callable: + my_range = Compiler.get_iterator() + crop_size = self.crop_size + fill = self.fill + prob = self.prob + + def cutout_square(images, *_): + should_cutout = rand(images.shape[0]) < prob + for i in my_range(images.shape[0]): + if should_cutout[i]: + # Generate random origin + coord = ( + np.random.randint(images.shape[1] - crop_size + 1), + np.random.randint(images.shape[2] - crop_size + 1), + ) + # Black out image in-place + images[i, coord[0]:coord[0] + crop_size, coord[1]:coord[1] + crop_size] = fill + return images + + cutout_square.is_parallel = True + return cutout_square \ No newline at end of file diff --git a/ffcv/transforms/erasing.py b/ffcv/transforms/erasing.py new file mode 100644 index 00000000..d5dcd08f --- /dev/null +++ b/ffcv/transforms/erasing.py @@ -0,0 +1,105 @@ +""" +Random Erasing augmentation (https://arxiv.org/abs/1708.04896) +""" + +# Implementation inspired by fastai https://docs.fast.ai/vision.augment.html#randomerasing +# fastai - Apache License 2.0 - Copyright (c) 2023 fast.ai + +import math +import numpy as np +from numpy.random import rand +from typing import Callable, Optional, Tuple +from dataclasses import replace + +from ..pipeline.compiler import Compiler +from ..pipeline.allocation_query import AllocationQuery +from ..pipeline.operation import Operation +from ..pipeline.state import State + + +class RandomErasing(Operation): + """Random erasing data augmentation (https://arxiv.org/abs/1708.04896). + + Parameters + ---------- + prob : float + Probability of applying on each image. + min_area : float + Minimum erased area as percentage of image size. + max_area : float + Maximum erased area as percentage of image size. + min_aspect : float + Minimum aspect ratio of erased area. + max_count : int + Maximum number of erased blocks per image. Erased Area is scaled by max_count. + fill_mean : Tuple[int, int, int], optional + The RGB color mean (ImageNet's (124, 116, 103) by default) to randomly fill the + erased area with. Should be the mean of dataset or pretrained dataset. + fill_std : Tuple[int, int, int], optional + The RGB color standard deviation (ImageNet's (58, 57, 57) by default) to randomly + fill the erased area with. Should be the st. dev of dataset or pretrained dataset. + fast_fill : bool + Default of True is ~2X faster by generating noise once per batch and randomly + selecting slices of the noise instead of generating unique noise per each image. + """ + def __init__(self, prob: float, min_area: float = 0.02, max_area: float = 0.3, + min_aspect: float = 0.3, max_count: int = 1, + fill_mean: Tuple[int, int, int] = (124, 116, 103), + fill_std: Tuple[int, int, int] = (58, 57, 57), + fast_fill : bool = True): + super().__init__() + self.prob = np.clip(prob, 0., 1.) + self.min_area = np.clip(min_area, 0., 1.) + self.max_area = np.clip(max_area, 0., 1.) + self.log_ratio = (math.log(np.clip(min_aspect, 0., 1.)), math.log(1/np.clip(min_aspect, 0., 1.))) + self.max_count = max_count + self.fill_mean = np.array(fill_mean) + self.fill_std = np.array(fill_std) + self.fast_fill = fast_fill + + def generate_code(self) -> Callable: + my_range = Compiler.get_iterator() + prob = self.prob + min_area = self.min_area + max_area = self.max_area + log_ratio = self.log_ratio + max_count = self.max_count + fill_mean = self.fill_mean + fill_std = self.fill_std + fast_fill = self.fast_fill + + def random_erase(images, *_): + if fast_fill: + noise = fill_mean + (fill_std * np.random.randn(images.shape[1], images.shape[2], images.shape[3])).astype(images.dtype) + + should_cutout = rand(images.shape[0]) < prob + for i in my_range(images.shape[0]): + if should_cutout[i]: + count = np.random.randint(1, max_count) if max_count > 1 else 1 + for j in range(count): + # Randomly select bounds + area = np.random.uniform(min_area, max_area, 1) * images.shape[1] * images.shape[2] / count + aspect = np.exp(np.random.uniform(log_ratio[0], log_ratio[1], 1)) + bound = ( + int(round(np.sqrt(area * aspect).item())), + int(round(np.sqrt(area / aspect).item())), + ) + # Select random erased area + coord = ( + np.random.randint(0, max(1, images.shape[1] - bound[0])), + np.random.randint(0, max(1, images.shape[2] - bound[1])), + ) + # Fill image with random noise in-place + if fast_fill: + images[i, coord[0]:coord[0] + bound[0], coord[1]:coord[1] + bound[1]] =\ + noise[coord[0]:coord[0] + bound[0], coord[1]:coord[1] + bound[1]] + else: + noise = fill_mean + (fill_std * np.random.randn(bound[0], bound[1], images.shape[3])).astype(images.dtype) + images[i, coord[0]:coord[0] + bound[0], coord[1]:coord[1] + bound[1]] = noise + return images + + random_erase.is_parallel = True + return random_erase + + def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: + return replace(previous_state, jit_mode=True), None \ No newline at end of file diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index 01f502b6..b01c9700 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -35,7 +35,7 @@ def run_test(length, pipeline, should_compile=False, aug_name=''): with NamedTemporaryFile() as handle: name = handle.name writer = DatasetWriter(name, { - 'image': RGBImageField(write_mode='smart', + 'image': RGBImageField(write_mode='smart', max_resolution=32), 'label': IntField(), }, num_workers=2) @@ -60,12 +60,12 @@ def run_test(length, pipeline, should_compile=False, aug_name=''): for it_num, ((images, labels), (original_images, original_labels)) in enumerate(zip(loader, unaugmented_loader)): tot_indices += labels.shape[0] tot_images += images.shape[0] - + for label, original_label in zip(labels, original_labels): assert_that(label).is_equal_to(original_label) - + if SAVE_IMAGES and it_num == 0: - save_image(make_grid(ch.concat([images, original_images])/255., images.shape[0]), + save_image(make_grid(ch.concat([images, original_images])/255., images.shape[0]), os.path.join(IMAGES_TMP_PATH, aug_name + '-' + str(uuid.uuid4()) + '.jpeg')) assert_that(tot_indices).is_equal_to(len(my_dataset)) @@ -80,6 +80,35 @@ def test_cutout(): ToTorchImage() ], comp, 'cutout') +def test_random_cutout(): + for comp in [True, False]: + run_test(100, [ + SimpleRGBImageDecoder(), + RandomCutout(0.75, 8), + ToTensor(), + ToTorchImage() + ], comp, 'random_cutout') + + +def test_random_erasing(): + for comp in [True, False]: + run_test(100, [ + SimpleRGBImageDecoder(), + RandomErasing(.75, max_count=3), + ToTensor(), + ToTorchImage() + ], comp, 'random_erasing') + + +def test_random_erasing_slow(): + for comp in [True, False]: + run_test(100, [ + SimpleRGBImageDecoder(), + RandomErasing(.75, fast_fill=False), + ToTensor(), + ToTorchImage() + ], comp, 'random_erasing_slow') + def test_flip(): for comp in [True, False]: @@ -129,7 +158,7 @@ def test_random_resized_crop(): for comp in [True, False]: run_test(100, [ SimpleRGBImageDecoder(), - RandomResizedCrop(scale=(0.08, 1.0), + RandomResizedCrop(scale=(0.08, 1.0), ratio=(0.75, 4/3), size=32), ToTensor(), From 0a0908c7ebc69151d1c56337c85f7131797e0a34 Mon Sep 17 00:00:00 2001 From: sanagno Date: Fri, 24 Mar 2023 16:33:21 +0100 Subject: [PATCH 07/19] correct traversalorder check --- ffcv/loader/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ffcv/loader/loader.py b/ffcv/loader/loader.py index 6e4240ea..97bafa6d 100644 --- a/ffcv/loader/loader.py +++ b/ffcv/loader/loader.py @@ -156,7 +156,7 @@ def __init__(self, if order in ORDER_MAP: self.traversal_order: TraversalOrder = ORDER_MAP[order](self) - elif isinstance(order, TraversalOrder): + elif issubclass(order, TraversalOrder): self.traversal_order: TraversalOrder = order(self) else: raise ValueError(f"Order {order} is not a supported order type or a subclass of TraversalOrder") From 2251f18210c2091661000d55c90b2d0cbf2078d5 Mon Sep 17 00:00:00 2001 From: "richard.smith" Date: Wed, 29 Mar 2023 12:12:21 +0100 Subject: [PATCH 08/19] Allowing nopython and setting class defaults --- ffcv/pipeline/compiler.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ffcv/pipeline/compiler.py b/ffcv/pipeline/compiler.py index 987356cc..a6996be6 100644 --- a/ffcv/pipeline/compiler.py +++ b/ffcv/pipeline/compiler.py @@ -1,5 +1,5 @@ import pdb -from numba import njit, set_num_threads, prange, warnings as nwarnings, get_num_threads +from numba import jit, set_num_threads, prange, warnings as nwarnings, get_num_threads from numba.core.errors import NumbaPerformanceWarning from multiprocessing import cpu_count import torch as ch @@ -7,6 +7,8 @@ class Compiler: + is_enabled: bool = True + num_threads: int = 1 @classmethod def set_enabled(cls, b): @@ -25,10 +27,12 @@ def compile(cls, code, signature=None): parallel = False if hasattr(code, 'is_parallel'): parallel = code.is_parallel and cls.num_threads > 1 + + nopython = getattr(code, 'nopython', True) if cls.is_enabled: - return njit(signature, fastmath=True, nogil=True, error_model='numpy', - parallel=parallel)(code) + return jit(signature, fastmath=True, nogil=nopython, error_model='numpy', + parallel=parallel, nopython=nopython)(code) return code @classmethod @@ -38,5 +42,5 @@ def get_iterator(cls): else: return range -Compiler.set_enabled(True) -Compiler.set_num_threads(1) +# Compiler.set_enabled(True) +# Compiler.set_num_threads(1) From 38dc51a6650b2ee3633bab7cc91b475200bf6552 Mon Sep 17 00:00:00 2001 From: "richard.smith" Date: Wed, 29 Mar 2023 14:43:22 +0100 Subject: [PATCH 09/19] http link to submodules --- .gitmodules | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index 21f138b4..9f99faa0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "examples/imagenet-example"] path = examples/imagenet-example - url = git@github.com:libffcv/ffcv-imagenet.git + url = https://github.com/libffcv/ffcv-imagenet.git From 21edbf20f2241d093630f044929e9d5dc2507633 Mon Sep 17 00:00:00 2001 From: sanagno Date: Thu, 30 Mar 2023 10:28:04 +0200 Subject: [PATCH 10/19] type annotation --- ffcv/loader/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ffcv/loader/loader.py b/ffcv/loader/loader.py index 97bafa6d..2a9af03e 100644 --- a/ffcv/loader/loader.py +++ b/ffcv/loader/loader.py @@ -40,7 +40,7 @@ class OrderOption(Enum): ] -ORDER_MAP: Mapping[ORDER_TYPE, TraversalOrder] = { +ORDER_MAP: Mapping[ORDER_TYPE, Type[TraversalOrder]] = { OrderOption.RANDOM: Random, OrderOption.SEQUENTIAL: Sequential, OrderOption.QUASI_RANDOM: QuasiRandom From 4bbacb3f55c09de2f033f9b16bc17ab311317d6f Mon Sep 17 00:00:00 2001 From: "richard.smith" Date: Fri, 31 Mar 2023 17:26:05 +0100 Subject: [PATCH 11/19] reverting class attributes change --- ffcv/pipeline/compiler.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ffcv/pipeline/compiler.py b/ffcv/pipeline/compiler.py index a6996be6..34cfae47 100644 --- a/ffcv/pipeline/compiler.py +++ b/ffcv/pipeline/compiler.py @@ -7,8 +7,6 @@ class Compiler: - is_enabled: bool = True - num_threads: int = 1 @classmethod def set_enabled(cls, b): @@ -32,7 +30,7 @@ def compile(cls, code, signature=None): if cls.is_enabled: return jit(signature, fastmath=True, nogil=nopython, error_model='numpy', - parallel=parallel, nopython=nopython)(code) + parallel=parallel, nopython=nopython, forceobj=not nopython)(code) return code @classmethod @@ -42,5 +40,5 @@ def get_iterator(cls): else: return range -# Compiler.set_enabled(True) -# Compiler.set_num_threads(1) +Compiler.set_enabled(True) +Compiler.set_num_threads(1) From d27f3e0b52053402d7465412e8007e779404e7b0 Mon Sep 17 00:00:00 2001 From: Benjamin Warner Date: Mon, 3 Apr 2023 13:56:29 -0500 Subject: [PATCH 12/19] Improve Parameter docs --- ffcv/fields/rgb_image.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/ffcv/fields/rgb_image.py b/ffcv/fields/rgb_image.py index 53f9ef4c..829ec8a3 100644 --- a/ffcv/fields/rgb_image.py +++ b/ffcv/fields/rgb_image.py @@ -302,26 +302,27 @@ class RGBImageField(Field): size), and 'proportion' (JPEG compress a random subset of the data with size specified by the ``compress_probability`` argument). By default: 'raw'. max_resolution : int, optional - If specified, will resize images to have maximum side length equal to - this value before saving, by default None + If specified, resize images to have maximum side length equal to this + value if maximum side length is larger. By default: None min_resolution : int, optional - If specified, will resize images to have minimum side length equal to - this value before saving, by default None + If specified, resize images to have minimum side length equal to this + value if minimum side length is larger. By default: None smart_threshold : int, optional When `write_mode='smart`, will compress an image if RAW byte size is larger than `smart_threshold`. jpeg_quality : int, optional - The quality parameter for JPEG encoding (ignored for - ``write_mode='raw'``), by default 90 + The quality parameter for JPEG encoding (ignored for ``write_mode='raw'``). + By default 90 compress_probability : float, optional Ignored unless ``write_mode='proportion'``; in the latter case it is the - probability with which image is JPEG-compressed, by default 0.5. + probability with which image is JPEG-compressed. By default 0.5. interpolation : optional - OpenCV interpolation flag for resizing images with OpenCV, by default INTER_AREA. + The OpenCV interpolation flag for resizing images with OpenCV. + By default INTER_AREA. resample : optional - Pillow resampling filter for resizing images with Pillow, by default LANCZOS. + The Pillow resampling filter for resizing images with Pillow. By default LANCZOS. pillow_resize : bool, optional - Use Pillow to resize images instead of OpenCV, by default False (OpenCV). + Use Pillow to resize images instead of OpenCV. By default False (OpenCV). """ def __init__(self, write_mode='raw', max_resolution: int = None, min_resolution: int = None, smart_threshold: int = None, From 3bbe9453af7b5dcd669543cf666317ac61961692 Mon Sep 17 00:00:00 2001 From: Korbinian Koch Date: Wed, 31 May 2023 13:06:28 +0200 Subject: [PATCH 13/19] Correct CIFAR10_MEAN and CIFAR10_STD in train_cifar.py For some reason, the CIFAR10_STD used in the example differed significantly from the actual standard deviation of the CIFAR10 train dataset. I corrected both the MEAN and STD with 3 decimal places accuracy. I calculated the values as seen here: https://gist.github.com/epistoteles/c35bd5154a036748651d8caca11a7efe --- examples/cifar/train_cifar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/cifar/train_cifar.py b/examples/cifar/train_cifar.py index 733465fe..de9430bd 100644 --- a/examples/cifar/train_cifar.py +++ b/examples/cifar/train_cifar.py @@ -74,8 +74,8 @@ def make_dataloaders(train_dataset=None, val_dataset=None, batch_size=None, num_ } start_time = time.time() - CIFAR_MEAN = [125.307, 122.961, 113.8575] - CIFAR_STD = [51.5865, 50.847, 51.255] + CIFAR_MEAN = [125.307, 122.950, 113.865] + CIFAR_STD = [62.993, 62.089, 66.705] loaders = {} for name in ['train', 'test']: From cdab5649e0d1c2f747428e8ea5e44b6792f8c8fa Mon Sep 17 00:00:00 2001 From: Wouter Zwerink Date: Wed, 28 Jun 2023 10:42:24 +0000 Subject: [PATCH 14/19] Add RandomVerticalFlip --- ffcv/transforms/__init__.py | 6 +++--- ffcv/transforms/flip.py | 37 +++++++++++++++++++++++++++++++++++++ tests/test_augmentations.py | 14 ++++++++++++-- 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/ffcv/transforms/__init__.py b/ffcv/transforms/__init__.py index 2636a447..ae2167a4 100644 --- a/ffcv/transforms/__init__.py +++ b/ffcv/transforms/__init__.py @@ -1,5 +1,5 @@ from .cutout import Cutout -from .flip import RandomHorizontalFlip +from .flip import RandomHorizontalFlip, RandomVerticalFlip from .ops import ToTensor, ToDevice, ToTorchImage, Convert, View from .common import Squeeze from .random_resized_crop import RandomResizedCrop @@ -15,7 +15,7 @@ 'ToTorchImage', 'NormalizeImage', 'Convert', 'Squeeze', 'View', 'RandomResizedCrop', 'RandomHorizontalFlip', 'RandomTranslate', - 'Cutout', 'ImageMixup', 'LabelMixup', 'MixupToOneHot', - 'Poison', 'ReplaceLabel', + 'RandomVerticalFlip', 'Cutout', 'ImageMixup', 'LabelMixup', + 'MixupToOneHot', 'Poison', 'ReplaceLabel', 'ModuleWrapper', 'RandomBrightness', 'RandomContrast', 'RandomSaturation'] diff --git a/ffcv/transforms/flip.py b/ffcv/transforms/flip.py index 63d4b1f9..d7fecd6c 100644 --- a/ffcv/transforms/flip.py +++ b/ffcv/transforms/flip.py @@ -44,3 +44,40 @@ def flip(images, dst): def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: return (replace(previous_state, jit_mode=True), AllocationQuery(previous_state.shape, previous_state.dtype)) + + +class RandomVerticalFlip(Operation): + """Flip the image vertically with probability flip_prob. + Operates on raw arrays (not tensors). + + Parameters + ---------- + flip_prob : float + The probability with which to flip each image in the batch + vertically. + """ + + def __init__(self, flip_prob: float = 0.5): + super().__init__() + self.flip_prob = flip_prob + + def generate_code(self) -> Callable: + my_range = Compiler.get_iterator() + flip_prob = self.flip_prob + + def flip(images, dst): + should_flip = rand(images.shape[0]) < flip_prob + for i in my_range(images.shape[0]): + if should_flip[i]: + dst[i] = images[i, ::-1, ...] + else: + dst[i] = images[i] + + return dst + + flip.is_parallel = True + return flip + + def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: + return (replace(previous_state, jit_mode=True), + AllocationQuery(previous_state.shape, previous_state.dtype)) diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index 01f502b6..07dc174c 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -81,14 +81,24 @@ def test_cutout(): ], comp, 'cutout') -def test_flip(): +def test_horizontal_flip(): for comp in [True, False]: run_test(100, [ SimpleRGBImageDecoder(), RandomHorizontalFlip(1.0), ToTensor(), ToTorchImage() - ], comp, 'flip') + ], comp, 'hflip') + + +def test_vertical_flip(): + for comp in [True, False]: + run_test(100, [ + SimpleRGBImageDecoder(), + RandomVerticalFlip(1.0), + ToTensor(), + ToTorchImage() + ], comp, 'vflip') def test_module_wrapper(): From c5e40a4915f37ca4530f60f792b696596f6a7243 Mon Sep 17 00:00:00 2001 From: Wouter Zwerink Date: Thu, 29 Jun 2023 10:00:57 +0000 Subject: [PATCH 15/19] Deepcopy mutable loader args --- ffcv/loader/loader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ffcv/loader/loader.py b/ffcv/loader/loader.py index 55bd6b1d..ac365518 100644 --- a/ffcv/loader/loader.py +++ b/ffcv/loader/loader.py @@ -9,6 +9,7 @@ from typing import Any, Callable, Mapping, Sequence, Type, Union, Literal from collections import defaultdict from collections.abc import Collection +from copy import deepcopy from enum import Enum, unique, auto from ffcv.fields.base import Field @@ -121,8 +122,8 @@ def __init__(self, 'order': order, 'distributed': distributed, 'seed': seed, - 'indices': indices, - 'pipelines': pipelines, + 'indices': deepcopy(indices), + 'pipelines': deepcopy(pipelines), 'drop_last': drop_last, 'batches_ahead': batches_ahead, 'recompile': recompile From ba96e7d0b5cfdc1922a7a35773953b5afdfd6efa Mon Sep 17 00:00:00 2001 From: Wouter Zwerink Date: Thu, 29 Jun 2023 10:31:45 +0000 Subject: [PATCH 16/19] Add warning for field length --- ffcv/writer.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ffcv/writer.py b/ffcv/writer.py index 1b70f74f..0ac2a2e2 100644 --- a/ffcv/writer.py +++ b/ffcv/writer.py @@ -1,20 +1,20 @@ +import ctypes +import warnings from functools import partial -from typing import Callable, List, Mapping +from multiprocessing import Process, Queue, Value, cpu_count, shared_memory from os import SEEK_END, path -import numpy as np from time import sleep -import ctypes -from multiprocessing import (shared_memory, cpu_count, Queue, Process, Value) +from typing import Callable, List, Mapping +import numpy as np from tqdm import tqdm from tqdm.contrib.concurrent import thread_map -from .utils import chunks, is_power_of_2 from .fields.base import Field from .memory_allocator import MemoryAllocator -from .types import (TYPE_ID_HANDLER, get_metadata_type, HeaderType, - FieldDescType, CURRENT_VERSION, ALLOC_TABLE_TYPE) - +from .types import (ALLOC_TABLE_TYPE, CURRENT_VERSION, TYPE_ID_HANDLER, + FieldDescType, HeaderType, get_metadata_type) +from .utils import chunks, is_power_of_2 MIN_PAGE_SIZE = 1 << 21 # 2MiB, which is the most common HugePage size MAX_PAGE_SIZE = 1 << 32 # Biggest page size that will not overflow uint32 @@ -151,7 +151,9 @@ def __init__(self, fname: str, fields: Mapping[str, Field], raise ValueError(f"page_size can't be lower than{MIN_PAGE_SIZE}") if page_size >= MAX_PAGE_SIZE: raise ValueError(f"page_size can't be bigger(or =) than{MAX_PAGE_SIZE}") - + for field_name in fields.keys(): + if len(field_name) > 16: + warnings.warn(f"Field name {field_name} will be cropped to {field_name[:16]}") self.page_size = page_size def prepare(self): From 53b9d3e973b4b079570a08c438fb3e12897654f8 Mon Sep 17 00:00:00 2001 From: wouterzwerink Date: Wed, 5 Jul 2023 11:24:05 +0200 Subject: [PATCH 17/19] Fix missing docs --- docs/benchmarks.rst | 1 + docs/ffcv_examples/custom_transforms.rst | 4 +-- ffcv/benchmarks/benchmark.py | 2 +- ffcv/fields/base.py | 10 ++++---- ffcv/fields/rgb_image.py | 2 +- ffcv/memory_managers/base.py | 2 +- ffcv/pipeline/graph.py | 31 +++++++++++++++--------- ffcv/pipeline/operation.py | 4 +-- ffcv/pipeline/state.py | 2 +- ffcv/transforms/mixup.py | 2 +- ffcv/transforms/random_resized_crop.py | 2 +- ffcv/traversal_order/base.py | 2 +- 12 files changed, 36 insertions(+), 28 deletions(-) diff --git a/docs/benchmarks.rst b/docs/benchmarks.rst index d03d4529..ef304a7c 100644 --- a/docs/benchmarks.rst +++ b/docs/benchmarks.rst @@ -89,6 +89,7 @@ We compare our results against existing data loading platforms: - `Pytorch DataLoader `_: This is the default option that comes with the Pytorch library and uses individual JPEG files as the source. - `Webdataset `_: This loader requires pre-processed files aggregated in multiple big `.tar` archives. - `DALI `_: Data loading pipeline developed by Nvidia. In this experiment we used the default file format which is the same as that of the Pytorch DataLoader. + The specific instantiation of DALI that we apply is the PyTorch ImageNet example DALI code found in the `NVIDIA DeepLearningExamples repository `_. We use the DGX-1 configuration and remove all the model optimization, benchmarking only the dataloader. diff --git a/docs/ffcv_examples/custom_transforms.rst b/docs/ffcv_examples/custom_transforms.rst index 7c4b4195..54acaa49 100644 --- a/docs/ffcv_examples/custom_transforms.rst +++ b/docs/ffcv_examples/custom_transforms.rst @@ -31,11 +31,11 @@ Doing so requires providing implementation for two functions: # Return the code to run this operation @abstractmethod def generate_code(self) -> Callable: - raise NotImplementedError + raise NotImplementedError() @abstractmethod def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: - raise NotImplementedError + raise NotImplementedError() Advancing state and pre-allocating memory ------------------------------------------ diff --git a/ffcv/benchmarks/benchmark.py b/ffcv/benchmarks/benchmark.py index 859f731d..a5489cbe 100644 --- a/ffcv/benchmarks/benchmark.py +++ b/ffcv/benchmarks/benchmark.py @@ -8,4 +8,4 @@ def __init__(self, **kwargs): @abstractmethod def run(self): - raise NotImplemented() \ No newline at end of file + raise NotImplementedError() \ No newline at end of file diff --git a/ffcv/fields/base.py b/ffcv/fields/base.py index 329275d1..9f8fe171 100644 --- a/ffcv/fields/base.py +++ b/ffcv/fields/base.py @@ -25,21 +25,21 @@ class Field(ABC): @property @abstractmethod def metadata_type(self) -> np.dtype: - raise NotImplemented + raise NotImplementedError() @staticmethod @abstractmethod def from_binary(binary: ARG_TYPE) -> Field: - raise NotImplementedError + raise NotImplementedError() @abstractmethod def to_binary(self) -> ARG_TYPE: - raise NotImplementedError + raise NotImplementedError() @abstractmethod def encode(field, metadata_destination, malloc): - raise NotImplementedError + raise NotImplementedError() @abstractmethod def get_decoder_class(self) -> Type[Operation]: - raise NotImplementedError + raise NotImplementedError() diff --git a/ffcv/fields/rgb_image.py b/ffcv/fields/rgb_image.py index 829ec8a3..87dd7707 100644 --- a/ffcv/fields/rgb_image.py +++ b/ffcv/fields/rgb_image.py @@ -238,7 +238,7 @@ def decode(batch_indices, my_storage, metadata, storage_state): @property @abstractmethod def get_crop_generator(): - raise NotImplementedError + raise NotImplementedError() class RandomResizedCropRGBImageDecoder(ResizedCropRGBImageDecoder): diff --git a/ffcv/memory_managers/base.py b/ffcv/memory_managers/base.py index 525833a0..2c450369 100644 --- a/ffcv/memory_managers/base.py +++ b/ffcv/memory_managers/base.py @@ -72,7 +72,7 @@ def schedule_epoch(self, batches: Sequence[Sequence[int]]) -> MemoryContext: @abstractmethod def compile_reader(self, address, size) -> Callable: - raise NotImplemented() + raise NotImplementedError() @property @abstractmethod diff --git a/ffcv/pipeline/graph.py b/ffcv/pipeline/graph.py index 05da7cee..9ff26fb0 100644 --- a/ffcv/pipeline/graph.py +++ b/ffcv/pipeline/graph.py @@ -1,6 +1,6 @@ -from distutils.log import warn import warnings import ast +import sys try: # Useful for debugging @@ -23,11 +23,18 @@ import torch as ch import numpy as np -# This is the starting state of the pipeline -INITIAL_STATE = State(jit_mode=True, - device=ch.device('cpu'), - dtype=np.dtype('u1'), - shape=None) +if "sphinx" in sys.modules: + # Sphinx fails on jit+gpu assert due to improper initialization of device + INITIAL_STATE = State(jit_mode=False, + device=ch.device('cpu'), + dtype=np.dtype('u1'), + shape=None) +else: + # This is the starting state of the pipeline + INITIAL_STATE = State(jit_mode=True, + device=ch.device('cpu'), + dtype=np.dtype('u1'), + shape=None) class Node(ABC): @@ -40,34 +47,34 @@ def __init__(self): @property @abstractmethod def is_jitted(self): - raise NotImplemented() + raise NotImplementedError() @property @abstractmethod def parent(self): - raise NotImplemented() + raise NotImplementedError() @property @abstractmethod def arg_id(self): - raise NotImplemented() + raise NotImplementedError() @property @abstractmethod def result_id(self): - raise NotImplemented() + raise NotImplementedError() @property @abstractmethod def result_id(self): - raise NotImplemented() + raise NotImplementedError() def get_shared_code_ast(self, done_ops): return ast.Pass() @abstractmethod def generate_code(self): - raise NotImplemented() + raise NotImplementedError() def recompile(self): self._code = self.generate_code() diff --git a/ffcv/pipeline/operation.py b/ffcv/pipeline/operation.py index 8ad947e8..b46257fc 100644 --- a/ffcv/pipeline/operation.py +++ b/ffcv/pipeline/operation.py @@ -28,7 +28,7 @@ def accept_globals(self, metadata, memory_read): # Return the code to run this operation @abstractmethod def generate_code(self) -> Callable: - raise NotImplementedError + raise NotImplementedError() def declare_shared_memory(self, previous_state: State) -> Optional[AllocationQuery]: return None @@ -38,4 +38,4 @@ def generate_code_for_shared_state(self) -> Optional[Callable]: @abstractmethod def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: - raise NotImplementedError + raise NotImplementedError() diff --git a/ffcv/pipeline/state.py b/ffcv/pipeline/state.py index a2e31dcc..0b553c5c 100644 --- a/ffcv/pipeline/state.py +++ b/ffcv/pipeline/state.py @@ -14,7 +14,7 @@ class State: # Assess the validity of a pipeline stage def __post_init__(self): - if self.jit_mode and self.device != ch.device('cpu'): + if self.jit_mode and self.device.type != 'cpu': raise AssertionError("Can't be in JIT mode and on the GPU") if self.jit_mode and isinstance(self.dtype, ch.dtype): raise AssertionError("Can't allocate a torch tensor in JIT mode") \ No newline at end of file diff --git a/ffcv/transforms/mixup.py b/ffcv/transforms/mixup.py index 53239b6f..724994d5 100644 --- a/ffcv/transforms/mixup.py +++ b/ffcv/transforms/mixup.py @@ -58,7 +58,7 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Option class LabelMixup(Operation): """Mixup for labels. Should be initialized in exactly the same way as - :cla:`ffcv.transforms.ImageMixup`. + :class:`ffcv.transforms.ImageMixup`. """ def __init__(self, alpha: float, same_lambda: bool): super().__init__() diff --git a/ffcv/transforms/random_resized_crop.py b/ffcv/transforms/random_resized_crop.py index 5a7405c5..7f311b42 100644 --- a/ffcv/transforms/random_resized_crop.py +++ b/ffcv/transforms/random_resized_crop.py @@ -14,7 +14,7 @@ class RandomResizedCrop(Operation): """Crop a random portion of image with random aspect ratio and resize it to a given size. Chances are you do not want to use this augmentation and instead want to include RRC as part of the decoder, by using the - :cla:`~ffcv.fields.rgb_image.ResizedCropRGBImageDecoder` class. + :class:`~ffcv.fields.rgb_image.ResizedCropRGBImageDecoder` class. Parameters ---------- diff --git a/ffcv/traversal_order/base.py b/ffcv/traversal_order/base.py index 74f1a70b..fcbb7f5e 100644 --- a/ffcv/traversal_order/base.py +++ b/ffcv/traversal_order/base.py @@ -17,4 +17,4 @@ def __init__(self, loader: 'Loader'): @abstractmethod def sample_order(self, epoch:int) -> Sequence[int]: - raise NotImplemented() + raise NotImplementedError() From 9299848950e405cb2c8cc5e891b8cf8cf4bda667 Mon Sep 17 00:00:00 2001 From: Andrew Ilyas Date: Mon, 6 May 2024 10:29:33 -0400 Subject: [PATCH 18/19] Update __init__.py --- ffcv/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ffcv/__init__.py b/ffcv/__init__.py index 541757b3..e2079883 100644 --- a/ffcv/__init__.py +++ b/ffcv/__init__.py @@ -1,5 +1,5 @@ from .loader import Loader from .writer import DatasetWriter -__version__ = '1.0.2' +__version__ = '1.1.0' __all__ = ['Loader'] From 39b2dd8a68716abb0205c3852ae5ec96d3c1d2fd Mon Sep 17 00:00:00 2001 From: Andrew Ilyas Date: Mon, 6 May 2024 10:31:45 -0400 Subject: [PATCH 19/19] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 433b78d2..5028e8fd 100644 --- a/setup.py +++ b/setup.py @@ -102,7 +102,7 @@ def pkgconfig(package, kw): **extension_kwargs) setup(name='ffcv', - version='1.0.1', + version='1.1.0', description=' FFCV: Fast Forward Computer Vision ', author='MadryLab', author_email='ffcv@mit.edu',