Skip to content

Commit

Permalink
57 times write speed up and some other fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Psy-Fer committed Aug 1, 2023
1 parent 9e6921f commit 97d2ff3
Showing 1 changed file with 71 additions and 35 deletions.
106 changes: 71 additions & 35 deletions python/pyslow5.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import time
import logging
import copy
from itertools import chain
# from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
# I should replace these with the cpython methods above
from libc.stdlib cimport malloc, free
from libc.string cimport strdup
from libc.string cimport strdup, memcpy
cimport pyslow5
# Import the Python-level symbols of numpy
import numpy as np
Expand All @@ -22,6 +24,7 @@ np.import_array()
# the m attribute sets the read/write state, and the file extension of p sets the type
#


cdef class Open:
cdef pyslow5.slow5_file_t *s5
cdef pyslow5.slow5_rec_t *rec
Expand Down Expand Up @@ -124,12 +127,15 @@ cdef class Open:
cdef double *time_since_mux_change_val_array
cdef pyslow5.uint64_t *num_minknow_events_val_array

cdef np.int16_t[:] temp_array

cdef pyslow5.float total_time_slow5_get_next
cdef pyslow5.float total_time_yield_reads
cdef pyslow5.float total_single_write_time
cdef pyslow5.float total_multi_write_signal_time
cdef pyslow5.float total_multi_write_time


def __cinit__(self, pathname, mode, rec_press="zlib", sig_press="svb_zd", DEBUG=0):
# Set to default NULL type
self.s5 = NULL
Expand Down Expand Up @@ -201,32 +207,32 @@ cdef class Open:
self.num_minknow_events = strdup("num_minknow_events")
self.end_reason_labels = NULL
self.end_reason_labels_len = 0
channel_number_val = NULL
median_before_val = 0.0
read_number_val = 0
start_mux_val = 0
start_time_val = 0
end_reason_val = 0
tracked_scaling_shift_val = 0.0
tracked_scaling_scale_val = 0.0
predicted_scaling_shift_val = 0.0
predicted_scaling_scale_val = 0.0
num_reads_since_mux_change_val = 0
time_since_mux_change_val = 0.0
num_minknow_events_val = 0
channel_number_val_array = NULL
median_before_val_array = NULL
read_number_val_array = NULL
start_mux_val_array = NULL
start_time_val_array = NULL
end_reason_val_array = NULL
tracked_scaling_shift_val_array = NULL
tracked_scaling_scale_val_array = NULL
predicted_scaling_shift_val_array = NULL
predicted_scaling_scale_val_array = NULL
num_reads_since_mux_change_val_array = NULL
time_since_mux_change_val_array = NULL
num_minknow_events_val_array = NULL
self.channel_number_val = NULL
self.median_before_val = 0.0
self.read_number_val = 0
self.start_mux_val = 0
self.start_time_val = 0
self.end_reason_val = 0
self.tracked_scaling_shift_val = 0.0
self.tracked_scaling_scale_val = 0.0
self.predicted_scaling_shift_val = 0.0
self.predicted_scaling_scale_val = 0.0
self.num_reads_since_mux_change_val = 0
self.time_since_mux_change_val = 0.0
self.num_minknow_events_val = 0
self.channel_number_val_array = NULL
self.median_before_val_array = NULL
self.read_number_val_array = NULL
self.start_mux_val_array = NULL
self.start_time_val_array = NULL
self.end_reason_val_array = NULL
self.tracked_scaling_shift_val_array = NULL
self.tracked_scaling_scale_val_array = NULL
self.predicted_scaling_shift_val_array = NULL
self.predicted_scaling_scale_val_array = NULL
self.num_reads_since_mux_change_val_array = NULL
self.time_since_mux_change_val_array = NULL
self.num_minknow_events_val_array = NULL


self.total_time_slow5_get_next = 0.0
Expand Down Expand Up @@ -1949,10 +1955,23 @@ cdef class Open:

self.logger.debug("write_record: self.write processing raw_signal")
start_write_copy_signal = time.time()
# grabs the buffer of the numpy array so the for loop operates in C, not Python = super fast
memview = memoryview(checked_record["signal"])
for i in range(checked_record["len_raw_signal"]):
self.write.raw_signal[i] = memview[i]


if checked_record["signal"].data.contiguous:
self.temp_array = checked_record["signal"]
num_elements = checked_record["signal"].size
memcpy(self.write.raw_signal, &self.temp_array[0], num_elements * sizeof(int16_t))
else:
self.logger.warning("write_record: numpy array of signal is not contiguous, please check your numpy array with np.info(array)")
self.logger.warning("write_record: falling back to old memory view element by element contruction. This is ~57 times slower...")
memview = memoryview(checked_record["signal"])
for i in range(checked_record["len_raw_signal"]):
self.write.raw_signal[i] = memview[i]

# memview = memoryview(checked_record["signal"])
# for i in range(checked_record["len_raw_signal"]):
# self.write.raw_signal[i] = memview[i]

# for i in range(checked_record["len_raw_signal"]):
# self.write.raw_signal[i] = checked_record["signal"][i]
end_write_copy_signal = (time.time() - start_write_copy_signal)
Expand Down Expand Up @@ -2207,12 +2226,29 @@ cdef class Open:

self.logger.debug("write_record_batch: self.write processing raw_signal")
start_write_copy_signal = time.time()
# grabs the buffer of the numpy array so the for loop operates in C, not Python = super fast
memview = memoryview(checked_records[batch[idx]]["signal"])
for i in range(checked_records[batch[idx]]["len_raw_signal"]):
self.twrite[idx].raw_signal[i] = memview[i]

# Check that the data in the numpy array is contiguous in memory; if it is not,
# warn and fall back to the slower element-by-element copy for now.
# Need to know how often this actually happens.
# I think it shouldn't ever happen given we allocate the numpy array all at once, but
# users could do some weird stuff, so it is better to check.


if checked_records[batch[idx]]["signal"].data.contiguous:
self.temp_array = checked_records[batch[idx]]["signal"]
num_elements = checked_records[batch[idx]]["signal"].size
memcpy(self.twrite[idx].raw_signal, &self.temp_array[0], num_elements * sizeof(int16_t))
else:
self.logger.warning("write_record_batch: numpy array of signal is not contiguous, please check your numpy array with np.info(array)")
self.logger.warning("write_record_batch: falling back to old memory view element by element contruction. This is ~57 times slower...")
memview = memoryview(checked_records[batch[idx]]["signal"])
for i in range(checked_records[batch[idx]]["len_raw_signal"]):
self.twrite[idx].raw_signal[i] = memview[i]


# for i in range(checked_records[batch[idx]]["len_raw_signal"]):
# self.twrite[idx].raw_signal[i] = checked_records[batch[idx]]["signal"][i]


end_write_copy_signal = (time.time() - start_write_copy_signal)
self.total_multi_write_signal_time = self.total_multi_write_signal_time + end_write_copy_signal

Expand Down

0 comments on commit 97d2ff3

Please sign in to comment.