Parallel building/running of pipelines #5618

Open
1 task done
kvablack opened this issue Sep 3, 2024 · 2 comments
Labels
enhancement New feature or request JAX Issues related to DALI and JAX integration

Comments

kvablack commented Sep 3, 2024

Is this a new feature, an improvement, or a change to existing functionality?

Improvement

How would you describe the priority of this feature request

Should have (e.g. Adoption is possible, but the performance shortcomings make the solution inferior).

Please provide a clear description of problem this feature solves

I'm using DALI with the JAX plugin. From what I can tell, every plugin builds multiple pipelines in sequence.

for pipe in pipelines:
    pipe.build()

Each pipeline contains a fairly beefy external source, so this process can take a very long time, especially with 8 GPUs.

Feature Description

I recently wrote my own code to bypass the JAX plugin, and I parallelized the building process with a thread pool:

import concurrent.futures
# One worker thread per pipeline, so every build can start immediately.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=len(pipelines))
concurrent.futures.wait([executor.submit(lambda p: p.build(), pipe) for pipe in pipelines])

This sped up initialization by a significant amount. Similarly, I parallelized running the pipelines:

run_futures = [executor.submit(lambda p: p.run(), pipe) for pipe in pipelines]
outputs = [future.result() for future in run_futures]

This also sped up my maximum throughput by a modest amount, although the above code is a simplification -- I also do some other things after each run, like copying buffers into JAX memory. So I'm not sure if this would apply to every plugin.

It would be nice if this could be upstreamed into DALI, especially the building part!
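
The build and run steps described above can be sketched as a self-contained example. `FakePipeline`, `build_all`, and `run_all` are hypothetical names used only for illustration; `FakePipeline` merely mimics the `build()` and `run()` methods with a sleep and is not DALI code. One deliberate difference from the snippets above: `concurrent.futures.wait` does not re-raise exceptions from the submitted calls, so this sketch calls `result()` on each future to surface build failures.

```python
import concurrent.futures
import threading
import time


class FakePipeline:
    """Hypothetical stand-in for a pipeline; only build() and run() are mimicked."""

    def __init__(self, device_id, build_seconds=0.2):
        self.device_id = device_id
        self.build_seconds = build_seconds
        self.built = False

    def build(self):
        time.sleep(self.build_seconds)  # simulates slow worker start-up
        self.built = True

    def run(self):
        if not self.built:
            raise RuntimeError("pipeline was not built")
        return (self.device_id, threading.current_thread().name)


def build_all(executor, pipelines):
    """Build every pipeline concurrently and re-raise the first failure."""
    futures = [executor.submit(p.build) for p in pipelines]
    for future in futures:
        future.result()  # blocks until done; raises if build() raised


def run_all(executor, pipelines):
    """Run every pipeline concurrently; results keep the order of `pipelines`."""
    futures = [executor.submit(p.run) for p in pipelines]
    return [future.result() for future in futures]


pipelines = [FakePipeline(device_id) for device_id in range(8)]
with concurrent.futures.ThreadPoolExecutor(max_workers=len(pipelines)) as executor:
    start = time.perf_counter()
    build_all(executor, pipelines)
    build_elapsed = time.perf_counter() - start
    outputs = run_all(executor, pipelines)
```

Threads only help here if the real `build()` spends its time outside the Python interpreter lock (for example, waiting on worker processes to start). The sketch assumes that; it does not verify it for DALI.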

Describe your ideal solution

See above

Describe any alternatives you have considered

No response

Additional context

No response

Check for duplicates

  • I have searched the open bugs/issues and have found no duplicates for this bug report
@kvablack kvablack added the enhancement New feature or request label Sep 3, 2024
@awolant awolant assigned awolant and unassigned mdabek-nvidia Sep 4, 2024
Contributor

awolant commented Sep 4, 2024

Hello @kvablack,

Thank you for creating the issue. The idea looks very interesting. Could you share some of your performance results? Could you also tell us more about the pipeline you are trying to build? Is it very complex?

@awolant awolant added the JAX Issues related to DALI and JAX integration label Sep 4, 2024
Author

kvablack commented Sep 4, 2024

Sure thing. I just tested it with 8 GPUs:

  • Parallel: ~100s to build each pipeline, 108s total
  • Sequential: ~50s to build each pipeline, 432s (7 minutes) total
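
As a quick check on what these totals imply, the arithmetic can be written out. The variable names are illustrative; the numbers are the ones reported above.

```python
num_gpus = 8
sequential_total_s = 432
parallel_total_s = 108
parallel_per_pipeline_s = 100
sequential_per_pipeline_s = 50

# Overall speedup from building in parallel.
speedup = sequential_total_s / parallel_total_s

# Fraction of the ideal (num_gpus-fold) speedup that was achieved.
efficiency = speedup / num_gpus

# How much slower each individual build became when run concurrently.
per_build_slowdown = parallel_per_pipeline_s / sequential_per_pipeline_s
```

With these figures the speedup is 4x rather than the ideal 8x, which is consistent with each individual build taking roughly twice as long when all eight run at once.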

The pipeline is not very complex. It loads from a parallel ExternalSource and does image decoding. I suspect the overhead comes from creating the workers (6 workers per GPU). I followed the recommendation in the documentation and put the heavy setup in the __setstate__ method of the ExternalSource; however, there is still some data that I need to send to each worker, which I just measured at 84 MB when pickled.
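
The `__setstate__` pattern mentioned above can be sketched in plain Python. `LazySource` and its attributes are hypothetical and stand in for a source callable; the point is only that `__getstate__` controls what is pickled and sent to a worker, while `__setstate__` rebuilds the heavy state after unpickling. Whether the 84 MB payload in this issue could be reduced this way depends on the data and is not claimed here.

```python
import pickle


class LazySource:
    """Hypothetical source callable that defers heavy setup to unpickling time."""

    def __init__(self, num_items):
        self.num_items = num_items  # small description, cheap to pickle
        self.table = None           # heavy state, built only after unpickling

    def _build_table(self):
        # Placeholder for expensive setup such as opening files or indexing data.
        return list(range(self.num_items))

    def __getstate__(self):
        # Only the lightweight description is pickled.
        return {"num_items": self.num_items}

    def __setstate__(self, state):
        # Called by pickle on the receiving side, i.e. inside the worker.
        self.num_items = state["num_items"]
        self.table = self._build_table()

    def __call__(self, i):
        return self.table[i % self.num_items]


source = LazySource(num_items=100_000)
payload = pickle.dumps(source)        # what would be shipped to a worker
worker_copy = pickle.loads(payload)   # heavy setup happens here

payload_size = len(payload)
table_size = len(pickle.dumps(worker_copy.table))
```

Measuring `len(pickle.dumps(...))` as in the last lines is one way to see how much data each worker would receive.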
