Speeding up predictions #4

Open
peterk87 opened this issue Aug 30, 2021 · 1 comment

@peterk87
Hello,

Thank you for developing BERTax! It looks like a really great tool for taxonomic classification of sequences that are typically difficult to place with tools relying on large databases.

I was interested in whether BERTax could be used to classify metagenomic sequencing reads, but it seems it would be quite a bit slower than k-mer-based methods (Centrifuge, Kraken2) even with GPU acceleration: 16 CPU threads (Intel(R) Core(TM) i9-9980HK CPU @ 2.40GHz) gave 6 reads/s, and an Nvidia Quadro RTX 5000 (Driver Version: 470.63.01; CUDA Version: 11.4) gave 20 reads/s.

Are there any plans to optimize BERTax for performing predictions on larger inputs?

I tried to modify the BERTax code to be a little more efficient on large inputs (reads in FASTQ) in PR peterk87#1, but I'm not familiar with Keras or TensorFlow, so I'm not sure how one would go about optimizing that code. The call to model.predict seems to be taking the most time by far.

For example, for a read of length 6092 split into 5 chunks:

  • seq2tokens: 0.792363 ms
  • process_bert_tokens_batch: 1.096281 ms
  • model.predict: 67.773608 ms
  • writing output: 1.32 ms

Total elapsed time: 70.986515 ms. Timings were obtained with time.time_ns. Although some optimizations may be possible for input processing and output formatting, most of the time (>95%) is spent running model.predict.
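For reference, here is a minimal sketch of how per-stage timings like these can be collected with time.time_ns (the `timed` helper is hypothetical, not from the PR; `seq2tokens`, `process_bert_tokens_batch`, and `model` are the BERTax names mentioned above):

```python
import time

def timed(label, fn, *args, **kwargs):
    # Hypothetical helper: run fn and report wall-clock time in ms via time.time_ns.
    start = time.time_ns()
    result = fn(*args, **kwargs)
    print(f"{label}: {(time.time_ns() - start) / 1e6:.6f} ms")
    return result

# Sketch of the per-read pipeline being timed; `seq` stands in for one input read.
tokens = timed("seq2tokens", seq2tokens, seq)
batch = timed("process_bert_tokens_batch", process_bert_tokens_batch, tokens)
preds = timed("model.predict", model.predict, batch)
```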

I noticed that in the bertax-visualize script the Keras model is converted into a PyTorch model:

https://github.com/f-kretschmer/bertax/blob/ae8cc568a2e66692e7663025906fda0016aa8b52/bertax/visualize.py#L29

I haven't tested whether using PyTorch with a converted model would help speed up predictions. Maybe the Keras model could instead be converted to a TensorFlow Lite model to reduce the per-call overhead of model.predict, as described in the following blog post:

https://micwurm.medium.com/using-tensorflow-lite-to-speed-up-predictions-a3954886eb98
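For what it's worth, the conversion described in that post boils down to something like the sketch below. This is an untested assumption for BERTax (whether the custom keras-bert layers convert cleanly is exactly the open question), and the input handling is a placeholder:

```python
import tensorflow as tf

# Convert the already-loaded Keras model to TensorFlow Lite.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Wrap the converted model in a TFLite interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

def predict_lite(inputs):
    # Rough drop-in for model.predict on a single batch. `inputs` is one
    # array per model input (keras-bert models take token ids and segment
    # ids); order, shapes, and dtypes must match `input_details`.
    for detail, arr in zip(input_details, inputs):
        interpreter.set_tensor(detail["index"], arr)
    interpreter.invoke()
    return interpreter.get_tensor(output_details[0]["index"])
```

One caveat: the TFLite interpreter is built around a fixed input shape, so varying batch sizes would need interpreter.resize_tensor_input (followed by allocate_tensors again) or per-sample calls.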

Unfortunately, I'm only familiar with NumPy, not Keras, TensorFlow, or PyTorch. I have some experience using Cython and Numba to accelerate Python code, but they may not be appropriate in this case.

Any speed-ups (or ideas for how to achieve them) would be extremely useful and appreciated, and would allow BERTax to be used on a wider range of datasets!

Thanks!
Peter

@f-kretschmer (Collaborator)
Hello Peter,

Many thanks for your tests and suggestions! I haven't looked into runtime optimization that much so far, so I think there are definitely improvements to be made. I didn't know about TensorFlow Lite; that seems like a promising starting point, although I'm not sure how well custom models (keras-bert) can be converted.
Thanks again, I'll look into it!
Fleming
