The negative influence of directly converting inputs to tensors and transferring back to CPU #5659

Open
1 task done
Arith2 opened this issue Oct 4, 2024 · 7 comments
Labels
question Further information is requested


@Arith2

Arith2 commented Oct 4, 2024

Describe the question.

Hi, I recently started using DALI to run some preprocessing pipelines on the GPU, and I ran into some problems that look very strange.

My pipeline is like this:

class SimCLR_DALIPipeline(Pipeline):

    def __init__(self, batch_size, num_threads, device_id, size, s=1, prefetch_queue_depth=1, py_num_workers=1, device='gpu'):
        super(SimCLR_DALIPipeline, self).__init__(batch_size, num_threads, device_id, seed=12, prefetch_queue_depth=prefetch_queue_depth, py_num_workers=py_num_workers)

        self.size = size
        self.s = s

        self.loader = load_binary_images(batch_size)
        self.input = fn.external_source(source=self.loader, device=device)

        self.resized_crop = fn.random_resized_crop(self.input, size=(size, size), device=device)

        flip_coin = fn.random.coin_flip(probability=0.5)
        self.horizontal_flip = fn.flip(self.resized_crop, horizontal=flip_coin, device=device)

        self.color_jitter = fn.color_twist(self.horizontal_flip, 
                                        brightness=0.8 * s,
                                        contrast=0.8 * s, 
                                        saturation=0.8 * s, 
                                        hue=0.2 * s, 
                                        device=device)

        self.random_grayscale = fn.color_space_conversion(self.color_jitter, image_type=types.RGB, output_type=types.GRAY, device=device)
        
        self.gaussian_blur = fn.gaussian_blur(self.random_grayscale, window_size=int(0.1 * size), device=device)
        self.to_tensor = fn.cast(self.gaussian_blur, dtype=types.FLOAT, device=device)

    def define_graph(self):
        return self.to_tensor
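As written, fn.color_twist receives constant values, and DALI treats them as fixed factors (brightness, contrast and saturation as multipliers, hue as a shift in degrees). Every image therefore gets the same 0.8x brightness/contrast/saturation change and a 0.2-degree hue shift, rather than the random jitter of the torchvision ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) that these values appear to mirror. The NumPy sketch below shows the per-image ranges that torchvision convention implies. sample_jitter_params is a hypothetical helper. Inside a DALI pipeline the same ranges could be passed as per-sample random arguments, for example brightness=fn.random.uniform(range=[0.2, 1.8]). Treat that wiring as an assumption to check against the DALI documentation.

```python
import numpy as np

def sample_jitter_params(rng, batch_size, s=1.0):
    """Per-image color-jitter parameters with torchvision-ColorJitter-style ranges.

    brightness/contrast/saturation: factor ~ U[max(0, 1 - 0.8*s), 1 + 0.8*s]
    hue: fraction of a full turn ~ U[-0.2*s, 0.2*s], converted to degrees
    because fn.color_twist takes the hue shift in degrees.
    """
    def factor(strength):
        return rng.uniform(max(0.0, 1.0 - strength), 1.0 + strength, size=batch_size)

    hue_fraction = rng.uniform(-0.2 * s, 0.2 * s, size=batch_size)
    return {
        "brightness": factor(0.8 * s),
        "contrast": factor(0.8 * s),
        "saturation": factor(0.8 * s),
        "hue": hue_fraction * 360.0,  # degrees
    }

rng = np.random.default_rng(12)
params = sample_jitter_params(rng, batch_size=4)
for name, values in params.items():
    print(name, np.round(values, 2))
```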

This is how I read the data out of the pipeline (preprocessed_images is not used afterwards):

    pipeline = SimCLR_DALIPipeline(batch_size=args.batch_size, num_threads=args.num_thread, device_id=args.gpu_index, size=96, prefetch_queue_depth=args.prefetch_queue_depth, py_num_workers=args.py_num_workers, device=args.device)

    pipeline.build()

    dali_iterator = DALIGenericIteratorWithViews([pipeline], size=args.dataset_size, n_views=args.n_views)

    # Iterate over DALI preprocessed images
    start_time = time.time()
    for epoch in range(args.epochs):
        for batch in tqdm(dali_iterator, total=len(dali_iterator), desc=f"Epoch {epoch + 1}"):
            for views in batch:
                for elem in views:
                    # preprocessed_images = elem['data'].cpu().numpy()
                    preprocessed_images = elem['data']
        dali_iterator.reset()

    stats = pipeline.executor_statistics()
    print(stats)   

The problem is that:

  1. If I remove all operators (converting self.input directly to self.to_tensor) and leave the final tensors on the GPU, performance is similar to the pipeline with all the operators (this is expected).
  2. If I remove all operators and transfer the final tensors back to the CPU, the time varies significantly with batch size and can be much larger than when the tensors stay on the GPU.
  3. If I keep the operators in the pipeline and transfer the final tensors back to the CPU, the time is somewhat lower than in case 2 but still higher than when the tensors stay on the GPU.
  4. pipeline.executor_statistics() does not work: it produces no output.

Here is the plot:
[Image: plot of total preprocessing time vs. batch size for the configurations above]

In these runs I keep the other parameters (num_threads, prefetch_queue_depth, py_num_workers) fixed.

I tested on an RTX 3090 on a local machine and on an NVIDIA V100 on Google Cloud, and observed the same behavior on both.

Check for duplicates

  • I have searched the open bugs/issues and have found no duplicates for this bug report
@Arith2 Arith2 added the question Further information is requested label Oct 4, 2024
@mzient
Contributor

mzient commented Oct 4, 2024

Hello @Arith2 ,
I don't know what DALIGenericIteratorWithViews is (it's not part of DALI) or what objects it returns, but judging by the cpu() member, I gather it's no longer a DALI batch but something coming from your ML framework (a Torch tensor?).
I'm not sure what kind of synchronization it implements; in the worst case, each .cpu() waits for the copy to finish. I can't explain more without an actual working example.

Meanwhile, if you use the latest release, 1.42, you can try doing the D2H (device-to-host) copy inside DALI with the new "dynamic" executor.
Just pass experimental_exec_dynamic=True to the pipeline and you can call .cpu() on the DataNodes.

        self.to_tensor = fn.cast(self.gaussian_blur, dtype=types.FLOAT, device=device).cpu()

Additionally, py_num_workers affects only the parallel external_source (parallel=True). You're not using one, so it has no effect.
The parameters you might want to tune are num_threads (typically num_cpu_cores/num_gpus - 1 works well) and prefetch_queue_depth (try 1, 2 or 3).
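The num_threads rule of thumb above can be written as a tiny helper. suggested_num_threads is a hypothetical name. os.cpu_count() counts logical CPUs, so on machines with SMT the value is only a starting point to tune together with prefetch_queue_depth.

```python
import os

def suggested_num_threads(num_cpu_cores=None, num_gpus=1):
    """Heuristic from the comment above: num_cpu_cores / num_gpus - 1, at least 1."""
    if num_cpu_cores is None:
        num_cpu_cores = os.cpu_count() or 1
    return max(1, num_cpu_cores // max(1, num_gpus) - 1)

print(suggested_num_threads(16, 1))  # 15
print(suggested_num_threads(16, 2))  # 7
print(suggested_num_threads())       # depends on this machine
```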

@Arith2
Author

Arith2 commented Oct 4, 2024

Hi @mzient, thanks very much for your quick response. I wrote load_binary_images() myself to read a binary file, and I feed its output to the pipeline with fn.external_source(). I am mainly interested in how the length of the pipeline affects overall performance, and in the side effects of transferring data back to the CPU, which I use to simulate transferring data to other accelerators.

This is the command I use to run the script.
time python run_DALI_preprocess_batch.py --dataset-size 500000 --prefetch-queue-depth 8 --num-thread 8 --py-num-workers 2 --device gpu

This is the full file of my script called run_DALI_preprocess_batch.py:

import argparse
import numpy as np
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIGenericIterator
import time
from tqdm import tqdm


# Argument parser
parser = argparse.ArgumentParser(description='PyTorch SimCLR with DALI')
parser.add_argument('-data', metavar='DIR', default='./datasets',
                    help='path to dataset')
parser.add_argument('-dataset-name', default='stl10',
                    help='dataset name', choices=['stl10', 'cifar10'])
parser.add_argument('--epochs', default=1, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=1, type=int,
                    metavar='N',
                    help='mini-batch size (default: 1; note that main() '
                         'sweeps its own list of batch sizes and ignores '
                         'this value)')
parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--disable-cuda', action='store_true',
                    help='Disable CUDA')
parser.add_argument('--fp16-precision', action='store_true',
                    help='Whether or not to use 16-bit precision GPU training.')
parser.add_argument('--out_dim', default=128, type=int,
                    help='feature dimension (default: 128)')
parser.add_argument('--log-every-n-steps', default=100, type=int,
                    help='Log every n steps')
parser.add_argument('--temperature', default=0.07, type=float,
                    help='softmax temperature (default: 0.07)')
parser.add_argument('--n-views', default=2, type=int, metavar='N',
                    help='Number of views for contrastive learning training.')
parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')
parser.add_argument('--num-thread', default=2, type=int,
                    help='number of threads in CPU for coordination')
parser.add_argument('--prefetch-queue-depth', default=1, type=int, help='Prefetch Queue Depth.')
parser.add_argument('--py-num-workers', default=1, type=int, help='Number of workers to load data from external source.')
parser.add_argument('--dataset-size', default=10, type=int, help='Number of images in use')
parser.add_argument('--device', default="gpu", type=str, help='Device for preprocessing.')


# Define the dimensions of your images in the binary file (e.g., for STL10)
IMAGE_HEIGHT = 96
IMAGE_WIDTH = 96
NUM_CHANNELS = 3


def load_binary_images(batch_size):
    file_path = './datasets/stl10_binary/output.bin'
    print("File path: ", file_path)
    image_size = IMAGE_HEIGHT * IMAGE_WIDTH * NUM_CHANNELS

    with open(file_path, 'rb') as f:
        while True:
            batch_data = []
            for _ in range(batch_size):
                data = f.read(image_size)
                if len(data) < image_size:
                    # EOF or a truncated trailing record: wrap around to the first image
                    f.seek(0)
                    data = f.read(image_size)
                    if len(data) < image_size:  # no complete image in the file: stop
                        return
                image = np.frombuffer(data, dtype=np.uint8).reshape(NUM_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH)
                image = np.transpose(image, (1, 2, 0))
                batch_data.append(image)
            yield np.array(batch_data)


# Modified pipeline to use `fn.external_source` for loading binary data
class SimCLR_DALIPipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, size, s=1, prefetch_queue_depth=1, py_num_workers=1, device='gpu'):
        super(SimCLR_DALIPipeline, self).__init__(batch_size, num_threads, device_id, seed=12, prefetch_queue_depth=prefetch_queue_depth, py_num_workers=py_num_workers)
        self.size = size
        self.s = s

        self.loader = load_binary_images(batch_size)
        self.input = fn.external_source(source=self.loader, device=device)

#        self.resized_crop = fn.random_resized_crop(self.input, size=(size, size), device=device)

#        flip_coin = fn.random.coin_flip(probability=0.5)
#        self.horizontal_flip = fn.flip(self.resized_crop, horizontal=flip_coin, device=device)

#        self.color_jitter = fn.color_twist(self.horizontal_flip, 
#                                        brightness=0.8 * s,
#                                        contrast=0.8 * s, 
#                                        saturation=0.8 * s, 
#                                        hue=0.2 * s, 
#                                        device=device)

#        self.random_grayscale = fn.color_space_conversion(self.color_jitter, image_type=types.RGB, output_type=types.GRAY, device=device)
        
#        self.gaussian_blur = fn.gaussian_blur(self.random_grayscale, window_size=int(0.1 * size), device=device)
#        self.gaussian_blur_3_channels = fn.cat(self.gaussian_blur, self.gaussian_blur, self.gaussian_blur, axis=0, device=device)
#        self.to_tensor = fn.cast(self.gaussian_blur, dtype=types.FLOAT, device=device)
        self.to_tensor = fn.cast(self.input, dtype=types.FLOAT, device=device)

    def define_graph(self):
        return self.to_tensor


class DALIGenericIteratorWithViews(DALIGenericIterator):
    def __init__(self, pipelines, size, n_views):
        super().__init__(pipelines, ['data'], size, last_batch_padded=True)
        self.n_views = n_views

    def __next__(self):
        data = super().__next__()  # Get the batch of preprocessed images
        # Note: every "view" is the same batch object, not an independently
        # augmented copy, so n_views=2 simply repeats the same tensors.
        views = [data] * self.n_views
        return views


def main():
    args = parser.parse_args()


    print("num_thread: ", args.num_thread)
    print("prefetch_queue_depth: ", args.prefetch_queue_depth)
    print("py_num_workers: ", args.py_num_workers)
    print("dataset_size: ", args.dataset_size)
    print("n_views: ", args.n_views)
    print("device: ", args.device)

    #print("batch_size: ", args.batch_size)
    batch_sizes = [64, 128, 256, 512, 1024, 2048]  # plain Python ints for batch_size
    for batch_size in batch_sizes:
        # Set up the DALI pipeline
        pipeline = SimCLR_DALIPipeline(batch_size=batch_size, num_threads=args.num_thread, device_id=args.gpu_index, size=96, prefetch_queue_depth=args.prefetch_queue_depth, py_num_workers=args.py_num_workers, device=args.device)
        pipeline.build()

        dali_iterator = DALIGenericIteratorWithViews([pipeline], size=args.dataset_size, n_views=args.n_views)

        start_time = time.time()
        for epoch in range(args.epochs):
            print(f"Epoch {epoch + 1}/{args.epochs}")

            for batch in tqdm(dali_iterator, total=len(dali_iterator), desc=f"Epoch {epoch + 1}"):
                for views in batch:
                    for elem in views:
                        # preprocessed_images = elem['data'].cpu()  # variant that copies back to the CPU
                        preprocessed_images = elem['data']
            dali_iterator.reset()


        end_time = time.time()
        print(f"Batch_size: {batch_size}, total preprocessing time: {end_time - start_time:.2f} seconds")


if __name__ == "__main__":
    main()
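The record layout that load_binary_images assumes can be checked on an in-memory file. The format is one image per 3x96x96 bytes in channel-major (CHW) order, converted to HWC for DALI. read_batches is a hypothetical, file-object-based variant written for this sketch. It wraps around at end of file like the original and also treats a truncated trailing record as end of file. If output.bin keeps the original STL-10 binary layout, the height and width axes would additionally be swapped: the STL-10 documentation describes that layout as column-major within each channel, so a transpose of (2, 1, 0) would be needed instead of (1, 2, 0). That does not affect the timing experiment.

```python
import io
import numpy as np

H, W, C = 96, 96, 3
IMAGE_SIZE = H * W * C

def read_batches(f, batch_size):
    """Yield uint8 batches of shape (batch_size, H, W, C) from CHW records,
    wrapping around at end of file (hypothetical file-object variant)."""
    while True:
        batch = []
        for _ in range(batch_size):
            data = f.read(IMAGE_SIZE)
            if len(data) < IMAGE_SIZE:      # EOF or truncated trailing record
                f.seek(0)
                data = f.read(IMAGE_SIZE)
                if len(data) < IMAGE_SIZE:  # file holds no complete image
                    return
            chw = np.frombuffer(data, dtype=np.uint8).reshape(C, H, W)
            batch.append(chw.transpose(1, 2, 0))  # CHW -> HWC
        yield np.stack(batch)

# Two synthetic images: every byte of channel c in image i equals 10*i + c.
records = [np.stack([np.full((H, W), 10 * i + c, np.uint8) for c in range(C)])
           for i in range(2)]
f = io.BytesIO(b"".join(r.tobytes() for r in records))
batch = next(read_batches(f, batch_size=3))
print(batch.shape)        # (3, 96, 96, 3)
print(batch[:, 0, 0, :])  # the third image wraps around to the first record
```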

My nvidia.dali version is 1.41.0, with cudatoolkit 11.0.
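According to the earlier comment, the dynamic executor arrives with release 1.42, while this environment reports 1.41.0. The hypothetical version gate below could decide whether to pass experimental_exec_dynamic=True. parse_version and supports_dynamic_executor are illustrative names. The installed version string would come from nvidia.dali.__version__, which is assumed here to be available.

```python
import re

def parse_version(version):
    """Turn '1.41.0' into (1, 41, 0); stops at the first non-numeric component."""
    numbers = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)

def supports_dynamic_executor(version, minimum=(1, 42)):
    """True if the version is at least the release named in the comment (1.42)."""
    return parse_version(version) >= minimum

print(supports_dynamic_executor("1.41.0"))  # False
print(supports_dynamic_executor("1.42.0"))  # True
```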

@szkarpinski
Collaborator

szkarpinski commented Oct 7, 2024

Hello, @Arith2 ! Thank you for bringing up this interesting issue. For now, I can only explain some part of the results that you see.

Let's list the factors that add up to the total time you are measuring:

  • Cost of the transformations: “GPU + transformation” and “GPU + no transformation” are very close on your plot, which makes me believe that this cost is negligible.

  • Constant overhead for each batch: DALI needs to do some work for each batch, independent of its size. This includes some initialization which needs to be done for each batch, launching the operators etc. This cost is amortized for bigger batch sizes. This is the reason why all variants are getting faster for batch sizes 8->16->32->64.

  • Cost of copying to the CPU: This cost is proportional to the size of an image. Your transformations make your images smaller: random_resized_crop outputs 96x96 images (the same size as the input), but color_space_conversion to types.GRAY reduces each image from 3 channels to 1. As a result, transformed images are faster to copy than untransformed ones. This is the reason why “GPU + transformation + back to CPU” is faster than “GPU + no transformation + back to CPU”.
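The copy-cost term can be made concrete. The pipeline casts its output to types.FLOAT (4 bytes per value) and the images are 96x96, so the data moved per batch follows directly. The 3-channel figures correspond to the untransformed output, and the 1-channel figures to the output after color_space_conversion to types.GRAY. batch_bytes is a hypothetical helper, and the numbers are plain arithmetic, not measurements.

```python
def batch_bytes(batch_size, height=96, width=96, channels=3, bytes_per_value=4):
    """Bytes moved per batch by a device-to-host copy of a dense float32 batch."""
    return batch_size * height * width * channels * bytes_per_value

for bs in (64, 128, 256, 512, 1024, 2048):
    rgb = batch_bytes(bs)               # no transformation: 3 channels
    gray = batch_bytes(bs, channels=1)  # after conversion to GRAY: 1 channel
    print(f"batch {bs:5d}: RGB {rgb / 2**20:7.1f} MiB, GRAY {gray / 2**20:7.1f} MiB")
```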

Still, this doesn't explain why the performance degrades for bigger batch sizes. With the model I explained above, I'd expect the results to look like this:

[Image: expected preprocessing time vs. batch size under this model]
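Under the two-factor model above (a constant overhead per batch plus a copy cost proportional to the data moved), total time can only fall and then flatten as batch size grows. It never rises. The sketch below uses illustrative, made-up constants. modeled_time and its parameters are hypothetical.

```python
def modeled_time(dataset_size, batch_size, per_batch_overhead, per_image_cost):
    """Two-term model: fixed overhead per batch + cost proportional to data copied."""
    num_batches = -(-dataset_size // batch_size)  # ceiling division
    return num_batches * per_batch_overhead + dataset_size * per_image_cost

# Made-up constants for illustration: 1 ms per batch, 20 us per image copied.
for bs in (8, 16, 32, 64, 128, 256, 512, 1024, 2048):
    print(bs, round(modeled_time(500_000, bs, 1e-3, 20e-6), 2))
```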

Thank you for providing the reproduction code. I'll experiment with it a bit more and try to figure out what causes this sudden growth of runtime for batches larger than 128. I'll get back to you once I have more answers. If you have any questions, or see something that doesn't fit the explanations above, please let me know.

@Arith2
Author

Arith2 commented Oct 10, 2024

@mzient @szkarpinski Hi Michal and Szymon, I tested my baseline many times, and one interesting observation is that the execution time is the same for online and offline training of ResNet50 (ImageNet, 13 GB of JPEGs).

  1. Online training: read data from disk (the data resides in CPU memory), preprocess it with DALI, and train on the same GPU. It uses about 100% of the compute resources.
  2. Offline training: preprocess the data in advance (about 15 GB in JPEG), store the results back to disk (they reside in CPU memory), then start and time pure training on the GPU. It uses about 100% of the compute resources.
  3. Pure preprocessing: running only the DALI preprocessing uses about 20% of the resources.

So these experiments imply that the cost of preprocessing with DALI is fully hidden behind training (as long as the training computation is not as light as AlexNet's).
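A simple pipelining model predicts that preprocessing gets hidden behind training. When DALI prepares batch i+1 while the GPU trains on batch i, each step costs roughly the maximum of the two stage times rather than their sum. The sketch below uses hypothetical stage times and the illustrative name epoch_time. Setting preprocessing to 20% of a training step is an assumption loosely based on the numbers above. The model ignores contention between the two stages for the same GPU, which is one way heavy preprocessing can still slow training.

```python
def epoch_time(num_steps, t_preprocess, t_train, overlapped=True):
    """Per-epoch time for a two-stage producer/consumer pipeline.

    overlapped=True: the first batch is prepared alone; after that each step
    costs max(t_preprocess, t_train) because both stages run concurrently.
    overlapped=False: the stages run back to back for every batch.
    """
    if overlapped:
        return t_preprocess + (num_steps - 1) * max(t_preprocess, t_train) + t_train
    return num_steps * (t_preprocess + t_train)

# Hypothetical stage times: preprocessing a batch takes 20% of a training step.
steps, t_train, t_pre = 1000, 1.0, 0.2
print("overlapped:", round(epoch_time(steps, t_pre, t_train), 1))                    # 1000.2
print("sequential:", round(epoch_time(steps, t_pre, t_train, overlapped=False), 1))  # 1200.0
```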

I am interested in a unified programming environment for both preprocessing and training. It looks like DALI costs nothing, and online training is as efficient as offline training, yet DALI is not as widely used as the default data loader as I would expect. If you are also interested in these questions, feel free to contact me via email [email protected].

@szkarpinski
Collaborator

Hello @Arith2 ! First, let me answer the unanswered question from my previous response, in which I couldn't explain why the performance degrades for bigger batch sizes. I still can't, but it turns out it's not DALI, but PyTorch ;)

DALIGenericIterator returns the data as a torch.Tensor, so the .cpu() method you're calling is a method of torch.Tensor and is no longer handled by DALI. I checked that .cpu() on its own shows the same slowdown for bigger batches:

import torch
import time

IMAGE_HEIGHT = 96
IMAGE_WIDTH = 96
NUM_CHANNELS = 3

DATASET_SIZE = 1024*16

for batch_size in [1,2,4,8,16,32,64,128,256,512,1024,2048]:
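    start_time = time.time()
    # Note: the tensor allocation below is inside the timed region, and the first
    # iteration (batch size 1) also pays for CUDA context initialization.
    mock_gpu_data = torch.zeros((batch_size, NUM_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH)).cuda()
    num_iterations = DATASET_SIZE // batch_size
    for _ in range(num_iterations):
        mock_gpu_data.cpu()  # blocking device-to-host copy into a new CPU tensor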

    end_time = time.time()
    print("Batch size", batch_size, "took", end_time - start_time)

yields:

Batch size 1 took 0.641491174697876
Batch size 2 took 0.3684566020965576
Batch size 4 took 0.3080744743347168
Batch size 8 took 0.27879929542541504
Batch size 16 took 0.20729494094848633
Batch size 32 took 0.17817330360412598
Batch size 64 took 0.1596205234527588
Batch size 128 took 0.15517973899841309
Batch size 256 took 0.15795373916625977
Batch size 512 took 1.0051357746124268 <--- !
Batch size 1024 took 1.0199494361877441 <--- !
Batch size 2048 took 1.052891492843628 <--- !

It might be that PyTorch has a fast path for smaller copies. You could look at its source code or ask the PyTorch maintainers on GitHub. In any case, with your model running on the GPU, you shouldn't need this GPU->CPU copy.
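In the benchmark above, the source tensor is allocated inside the timed region, and the batch-size-1 run also pays for CUDA context initialization, which may account for part of its higher time. A generic harness that warms up first and times only the copies can be sketched as follows. time_per_batch_size, make_batch and copy_fn are hypothetical names. The stand-in below copies a bytearray so the sketch runs without a GPU. For the real measurement, one could pass make_batch=lambda bs: torch.zeros((bs, 3, 96, 96), device="cuda") and copy_fn=lambda x: x.cpu(). Because .cpu() blocks until the copy finishes, wall-clock timing is meaningful here. Asynchronous GPU work would need torch.cuda.synchronize() before reading the clock.

```python
import time

def time_per_batch_size(make_batch, copy_fn, batch_sizes, dataset_size, warmup=3):
    """Return {batch_size: seconds spent copying dataset_size items in batches}.

    make_batch(bs) builds the source batch outside the timed region, and
    copy_fn(batch) performs one copy. The untimed warm-up calls run first so
    that one-time costs (context creation, allocator growth) are not measured.
    """
    results = {}
    for bs in batch_sizes:
        batch = make_batch(bs)
        for _ in range(warmup):
            copy_fn(batch)
        iterations = dataset_size // bs
        start = time.perf_counter()
        for _ in range(iterations):
            copy_fn(batch)
        results[bs] = time.perf_counter() - start
    return results

# CPU-only stand-in so the harness runs anywhere: bytes() copies the buffer.
timings = time_per_batch_size(
    make_batch=lambda bs: bytearray(bs * 96 * 96 * 3 * 4),
    copy_fn=bytes,
    batch_sizes=[1, 64, 512],
    dataset_size=1024,
)
for bs, seconds in timings.items():
    print(f"batch {bs:4d}: {seconds:.4f} s")
```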

@szkarpinski
Copy link
Collaborator

szkarpinski commented Oct 15, 2024

So these experiments imply that the cost of preprocessing with DALI is fully hidden behind training (as long as the training computation is not as light as AlexNet's).

I'm glad that your experiments confirmed that. As you say, for simple enough preprocessing pipelines, DALI's overhead should be minimal.

It looks like DALI costs nothing, and online training is as efficient as offline training

Also, offline preprocessing might not always be possible: in particular, when you perform many random augmentations, in most cases you want them to be different in each epoch. With offline-preprocessed data, you can't achieve that unless you store a separate copy for each epoch. We believe that in such cases DALI is the optimal solution for data processing in terms of performance.
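Random horizontal flips illustrate the difference. With offline preprocessing, the augmentation decisions are frozen when the data is written, while an online pipeline draws new ones every epoch. Below is a minimal NumPy sketch. flip_decisions is a hypothetical stand-in for any random augmentation; in a DALI pipeline, fn.random.coin_flip plays this role.

```python
import numpy as np

def flip_decisions(num_images, epoch, seed=12, offline=False):
    """Per-image horizontal-flip decisions (p = 0.5) for one epoch.

    offline=True: the decisions were fixed once, when the data was preprocessed,
    so every epoch sees the same ones. offline=False: a fresh draw per epoch,
    as an online pipeline would make.
    """
    rng = np.random.default_rng(seed if offline else [seed, epoch])
    return rng.random(num_images) < 0.5

for epoch in range(3):
    online = flip_decisions(8, epoch)
    offline = flip_decisions(8, epoch, offline=True)
    print(epoch, "online:", online.astype(int), "offline:", offline.astype(int))
```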

Thank you for your interest in DALI! Please feel free to reach out to us again if you see some other problems or things that need improvement!

@JanuszL
Contributor

JanuszL commented Oct 15, 2024

Hi @Arith2,

I am interested in a unified programming environment for both preprocessing and training. It looks like DALI costs nothing, and online training is as efficient as offline training, yet DALI is not as widely used as the default data loader as I would expect. If you are also interested in these questions, feel free to contact me via email [email protected].

Thank you for sharing your research interests. Feel free to reach us at [email protected]. When writing, please specify which aspects of collaboration you have in mind and how we can contribute.

4 participants