From c479ca74cf9fdea67cc28e339e4f53ee9193c20f Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Sat, 14 Oct 2023 20:44:58 +0100 Subject: [PATCH 1/7] Initial refactor to class allowing lazy loading. --- pywavesurfer/ws.py | 380 +++++++++++++++++++++++++++------------------ 1 file changed, 225 insertions(+), 155 deletions(-) diff --git a/pywavesurfer/ws.py b/pywavesurfer/ws.py index 746ab6f..cd5016d 100644 --- a/pywavesurfer/ws.py +++ b/pywavesurfer/ws.py @@ -6,192 +6,262 @@ from packaging.version import parse as parse_version from packaging.specifiers import SpecifierSet +# TODO JZ +# header = self.data_file_as_dict["header"] # TODO: header["AcquisitionSampleRate"] +# why is this 2D +# num_samples = header["SweepDuration"] * header["AcquisitionSampleRate"] # TODO: +# what about these conditional set header["Acquisition"]["SampleRate"] +# investigate AreSweepsContinuous +# TODO: figre out timing on Neo +# TODO: ask SI about digital input channels (not probe) + # the latest version pywavesurfer was tested against _latest_version = 0.982 _over_version_1 = SpecifierSet(">=1.0") + # from pywavesurfer.ws import * will only import loadDataFile __all__ = ['loadDataFile'] def loadDataFile(filename, format_string='double'): - """ Loads Wavesurfer data file. + wavesurfer = PyWaveSurferData(filename, format_string) + data_file_as_dict = wavesurfer.load_all_data() + return data_file_as_dict - :param filename: File to load (has to be with .h5 extension) - :param format_string: optional: the return type of the data, defaults to double. Could be 'single', or 'raw'. - :return: dictionary with a structure array with one element per sweep in the data file. - """ - # Check that file exists - if not os.path.isfile(filename): - raise IOError("The file %s does not exist." % filename) - - # Check that file has proper extension - (_, ext) = os.path.splitext(filename) - if ext != ".h5": - raise RuntimeError("File must be a WaveSurfer-generated HDF5 (.h5) file.") - - # Extract dataset at each group level, recursively. - with h5py.File(filename, mode='r') as file: - data_file_as_dict = crawl_h5_group(file) # an h5py File is also a Group - - # Correct the samples rates for files that were generated by versions - # of WS which didn't coerce the sampling rate to an allowed rate. - header = data_file_as_dict["header"] - if "VersionString" in header: - version_numpy = header["VersionString"] # this is a scalar numpy array with a weird datatype - version = version_numpy.tobytes().decode("utf-8") - parsed_version = parse_version(version) - if parsed_version in _over_version_1: - if parsed_version > parse_version(str(_latest_version)): + +class PyWaveSurferData: + + def __init__(self, filename, format_string="double"): # TODO: move format string? + """ Loads Wavesurfer data file. + + :param filename: File to load (has to be with .h5 extension) + :param format_string: optional: the return type of the data, defaults to double. Could be 'single', or 'raw'. + :return: dictionary with a structure array with one element per sweep in the data file. + """ + self.format_string = format_string + + # Check that file exists + if not os.path.isfile(filename): + raise IOError("The file %s does not exist." 
% filename) + + # Check that file has proper extension + (_, ext) = os.path.splitext(filename) + if ext != ".h5": + raise RuntimeError("File must be a WaveSurfer-generated HDF5 (.h5) file.") + + self.file = h5py.File(filename, mode='r') + + + self.data_file_as_dict = self.get_metadata_dict() + + # Correct the samples rates for files that were generated by versions + # of WS which didn't coerce the sampling rate to an allowed rate. + header = self.data_file_as_dict["header"] + if "VersionString" in header: + version_numpy = header["VersionString"] # this is a scalar numpy array with a weird datatype + version = version_numpy.tobytes().decode("utf-8") + parsed_version = parse_version(version) + if parsed_version in _over_version_1: + if parsed_version > parse_version(str(_latest_version)): + warnings.warn('You are reading a WaveSurfer file version this module was not tested with: ' + 'file version %s, latest version tested: %s' + % (parsed_version.public, parse_version(str(_latest_version)).public), RuntimeWarning) + elif float(version) > _latest_version: warnings.warn('You are reading a WaveSurfer file version this module was not tested with: ' - 'file version %s, latest version tested: %s' - % (parsed_version.public, parse_version(str(_latest_version)).public), RuntimeWarning) - elif float(version) > _latest_version: - warnings.warn('You are reading a WaveSurfer file version this module was not tested with: ' - 'file version %s, latest version tested: %s' % (version, _latest_version), RuntimeWarning) - else: - # If no VersionsString field, the file is from an old old version - parsed_version = parse_version('0.0') - # version 0.912 has the problem, version 0.913 does not - if parsed_version not in _over_version_1 and parsed_version.release is not None: - version_string = str(parsed_version.release[1]) - ver_len = len(version_string) - if int(version_string[0]) < 9 or (ver_len >= 2 and int(version_string[1]) < 1) or \ - (ver_len >= 3 and int(version_string[1]) <= 1 and int(version_string[2]) <= 2): - # Fix the acquisition sample rate, if needed - nominal_acquisition_sample_rate = float(header["Acquisition"]["SampleRate"]) - nominal_n_timebase_ticks_per_sample = 100.0e6 / nominal_acquisition_sample_rate - if nominal_n_timebase_ticks_per_sample != round( - nominal_n_timebase_ticks_per_sample): # should use the python round, not numpy round - actual_acquisition_sample_rate = 100.0e6 / math.floor( - nominal_n_timebase_ticks_per_sample) # sic: the boards floor() for acq, but round() for stim - header["Acquisition"]["SampleRate"] = np.array(actual_acquisition_sample_rate) - data_file_as_dict["header"] = header - # Fix the stimulation sample rate, if needed - nominal_stimulation_sample_rate = float(header["Stimulation"]["SampleRate"]) - nominal_n_timebase_ticks_per_sample = 100.0e6 / nominal_stimulation_sample_rate - if nominal_n_timebase_ticks_per_sample != round(nominal_n_timebase_ticks_per_sample): - actual_stimulation_sample_rate = 100.0e6 / round( - nominal_n_timebase_ticks_per_sample) # sic: the boards floor() for acq, but round() for stim - header["Stimulation"]["SampleRate"] = np.array(actual_stimulation_sample_rate) - data_file_as_dict["header"] = header - - # If needed, use the analog scaling coefficients and scales to convert the - # analog scans from counts to experimental units. 
- if "NAIChannels" in header: - n_a_i_channels = header["NAIChannels"] - else: - acq = header["Acquisition"] - if "AnalogChannelScales" in acq: - all_analog_channel_scales = acq["AnalogChannelScales"] + 'file version %s, latest version tested: %s' % (version, _latest_version), RuntimeWarning) else: - # This is presumably a very old file, from before we supported digital inputs - all_analog_channel_scales = acq["ChannelScales"] - n_a_i_channels = all_analog_channel_scales.size # element count - if format_string.lower() != "raw" and n_a_i_channels > 0: - try: - if "AIChannelScales" in header: - # Newer files have this field, and lack header.Acquisition.AnalogChannelScales - all_analog_channel_scales = header["AIChannelScales"] - else: - # Fallback for older files - all_analog_channel_scales = header["Acquisition"]["AnalogChannelScales"] - except KeyError: - raise KeyError("Unable to read channel scale information from file.") - try: - if "IsAIChannelActive" in header: - # Newer files have this field, and lack header.Acquisition.AnalogChannelScales - is_active = header["IsAIChannelActive"].astype(bool) + # If no VersionsString field, the file is from an old old version + parsed_version = parse_version('0.0') + # version 0.912 has the problem, version 0.913 does not + if parsed_version not in _over_version_1 and parsed_version.release is not None: + version_string = str(parsed_version.release[1]) + ver_len = len(version_string) + if int(version_string[0]) < 9 or (ver_len >= 2 and int(version_string[1]) < 1) or \ + (ver_len >= 3 and int(version_string[1]) <= 1 and int(version_string[2]) <= 2): + # Fix the acquisition sample rate, if needed + nominal_acquisition_sample_rate = float(header["Acquisition"]["SampleRate"]) + nominal_n_timebase_ticks_per_sample = 100.0e6 / nominal_acquisition_sample_rate + if nominal_n_timebase_ticks_per_sample != round( + nominal_n_timebase_ticks_per_sample): # should use the python round, not numpy round + actual_acquisition_sample_rate = 100.0e6 / math.floor( + nominal_n_timebase_ticks_per_sample) # sic: the boards floor() for acq, but round() for stim + header["Acquisition"]["SampleRate"] = np.array(actual_acquisition_sample_rate) + self.data_file_as_dict["header"] = header + # Fix the stimulation sample rate, if needed + nominal_stimulation_sample_rate = float(header["Stimulation"]["SampleRate"]) + nominal_n_timebase_ticks_per_sample = 100.0e6 / nominal_stimulation_sample_rate + if nominal_n_timebase_ticks_per_sample != round(nominal_n_timebase_ticks_per_sample): + actual_stimulation_sample_rate = 100.0e6 / round( + nominal_n_timebase_ticks_per_sample) # sic: the boards floor() for acq, but round() for stim + header["Stimulation"]["SampleRate"] = np.array(actual_stimulation_sample_rate) + self.data_file_as_dict["header"] = header + + # Fill in channel header + if "NAIChannels" in header: + self.n_a_i_channels = header["NAIChannels"] + else: + acq = header["Acquisition"] + if "AnalogChannelScales" in acq: + all_analog_channel_scales = acq["AnalogChannelScales"] else: - # Fallback for older files - is_active = header["Acquisition"]["IsAnalogChannelActive"].astype(bool) - except KeyError: - raise KeyError("Unable to read active/inactive channel information from file.") - analog_channel_scales = all_analog_channel_scales[is_active] + # This is presumably a very old file, from before we supported digital inputs + all_analog_channel_scales = acq["ChannelScales"] + self.n_a_i_channels = all_analog_channel_scales.size # element count - # read the scaling coefficients - 
try: - if "AIScalingCoefficients" in header: - analog_scaling_coefficients = header["AIScalingCoefficients"] + if self.format_string.lower() != "raw" and self.n_a_i_channels > 0: + try: + if "AIChannelScales" in header: + # Newer files have this field, and lack header.Acquisition.AnalogChannelScales + all_analog_channel_scales = header["AIChannelScales"] + else: + # Fallback for older files + all_analog_channel_scales = header["Acquisition"]["AnalogChannelScales"] + except KeyError: + raise KeyError("Unable to read channel scale information from file.") + try: + if "IsAIChannelActive" in header: + # Newer files have this field, and lack header.Acquisition.AnalogChannelScales + is_active = header["IsAIChannelActive"].astype(bool) + else: + # Fallback for older files + is_active = header["Acquisition"]["IsAnalogChannelActive"].astype(bool) + except KeyError: + raise KeyError("Unable to read active/inactive channel information from file.") + self.analog_channel_scales = all_analog_channel_scales[is_active] # TODO + + # read the scaling coefficients + try: + if "AIScalingCoefficients" in header: + self.analog_scaling_coefficients = header["AIScalingCoefficients"] # TODO + else: + self.analog_scaling_coefficients = header["Acquisition"]["AnalogScalingCoefficients"] + except KeyError: + raise KeyError("Unable to read channel scaling coefficients from file.") + + def get_traces(self, segment_index, start_frame, end_frame, return_scaled=True): + field_names = [name for name in self.data_file_as_dict if name[0:5] in ["sweep", "trial"]] + + sweep_nums = [int(ele[6:]) for ele in field_names] + ordered_field_names = [field_names[num - 1] for num in sweep_nums] + + sweep_name = ordered_field_names[segment_index] + + if sweep_name[0:5] == "sweep": + analog_data_as_counts = self.file[sweep_name]["analogScans"][:, start_frame:end_frame] # +1? + else: + analog_data_as_counts = self.file[sweep_name][:, start_frame:end_frame] # +1? + + if not return_scaled: + return analog_data_as_counts + + if self.format_string.lower() != "raw" and self.n_a_i_channels > 0: # TODO: fix double conditional + does_user_want_single = (self.format_string.lower() == "single") + if does_user_want_single: + scaled_analog_data = scaled_single_analog_data_from_raw( + analog_data_as_counts, + self.analog_channel_scales, + self.analog_scaling_coefficients) else: - analog_scaling_coefficients = header["Acquisition"]["AnalogScalingCoefficients"] - except KeyError: - raise KeyError("Unable to read channel scaling coefficients from file.") + scaled_analog_data = scaled_double_analog_data_from_raw( + analog_data_as_counts, + self.analog_channel_scales, + self.analog_scaling_coefficients) + return scaled_analog_data + else: + return analog_data_as_counts + + def load_all_data(self): # investigate AreSweepsContinuous + """""" + idx = 0 + for field_name in self.data_file_as_dict: - does_user_want_single = (format_string.lower() == "single") - for field_name in data_file_as_dict: - # field_names = field_namess{i} if len(field_name) >= 5 and (field_name[0:5] == "sweep" or field_name[0:5] == "trial"): - # We check for "trial" for backward-compatibility with - # data files produced by older versions of WS. 
- analog_data_as_counts = data_file_as_dict[field_name]["analogScans"] - if does_user_want_single: - scaled_analog_data = scaled_single_analog_data_from_raw(analog_data_as_counts, - analog_channel_scales, - analog_scaling_coefficients) + + if field_name[0:5] == "sweep": + num_samples = self.file[field_name]["analogScans"].size else: - scaled_analog_data = scaled_double_analog_data_from_raw(analog_data_as_counts, - analog_channel_scales, - analog_scaling_coefficients) - data_file_as_dict[field_name]["analogScans"] = scaled_analog_data + num_samples = self.file[field_name].size - return data_file_as_dict + scaled_analog_data = self.get_traces(segment_index=idx, start_frame=0, end_frame=num_samples) + if field_name[0:5] == "sweep": + self.data_file_as_dict[field_name]["analogScans"] = scaled_analog_data + else: + self.data_file_as_dict[field_name] = scaled_analog_data + return self.data_file_as_dict -def crawl_h5_group(group): - result = dict() + def close_file(self): + if not self.file.closed(): + self.file.close() - item_names = list(group.keys()) + def __enter__(self): + return self - for item_name in item_names: - item = group[item_name] - if isinstance(item, h5py.Group): - field_name = field_name_from_hdf_name(item_name) - result[field_name] = crawl_h5_group(item) - elif isinstance(item, h5py.Dataset): - field_name = field_name_from_hdf_name(item_name) - result[field_name] = item[()] - else: - pass + def __exit__(self): + self.close_file() - return result + def get_metadata_dict(self): + return self.recursive_crawl_h5_group(self.file) -def field_name_from_hdf_name(hdf_name): - # Convert the name of an HDF dataset/group to something that is a legal - # Matlab struct field name. We do this even in Python, just to be consistent. - try: - # the group/dataset name seems to be a number. If it's an integer, we can deal, so check that. - hdf_name_as_double = float(hdf_name) - if hdf_name_as_double == round(hdf_name_as_double): - # If get here, group name is an integer, so we prepend with an "n" to get a valid field name - field_name = "n{:%s}".format(hdf_name) - else: - # Not an integer. Give up. - raise RuntimeError("Unable to convert group/dataset name {:%s} to a valid field name.".format(hdf_name)) - except ValueError: - # This is actually a good thing, b/c it means the groupName is not - # simply a number, which would be an illegal field name - field_name = hdf_name + def recursive_crawl_h5_group(self, group): + result = dict() + + item_names = list(group.keys()) + for item_name in item_names: + item = group[item_name] + if isinstance(item, h5py.Group): + field_name = self.field_name_from_hdf_name(item_name) + result[field_name] = self.recursive_crawl_h5_group(item) + elif isinstance(item, h5py.Dataset): + field_name = self.field_name_from_hdf_name(item_name) + if item_name[0:5] == "trial": + result[item_name] = {"analogScans": np.array([])} + elif item_name != "analogScans": + result[field_name] = item[()] + else: + pass - return field_name + return result + + + def field_name_from_hdf_name(self, hdf_name): + # Convert the name of an HDF dataset/group to something that is a legal + # Matlab struct field name. We do this even in Python, just to be consistent. + try: + # the group/dataset name seems to be a number. If it's an integer, we can deal, so check that. 
+            hdf_name_as_double = float(hdf_name)
+            if hdf_name_as_double == round(hdf_name_as_double):
+                # If get here, group name is an integer, so we prepend with an "n" to get a valid field name
+                field_name = "n{:%s}".format(hdf_name)
+            else:
+                # Not an integer. Give up.
+                raise RuntimeError("Unable to convert group/dataset name {:%s} to a valid field name.".format(hdf_name))
+        except ValueError:
+            # This is actually a good thing, b/c it means the groupName is not
+            # simply a number, which would be an illegal field name
+            field_name = hdf_name
+
+        return field_name
 
 
 def scaled_double_analog_data_from_raw(data_as_ADC_counts, channel_scales, scaling_coefficients):
-    # Function to convert raw ADC data as int16s to doubles, taking to the
-    # per-channel scaling factors into account.
-    #
-    # data_as_ADC_counts: n_channels x nScans int16 array
-    # channel_scales: double vector of length n_channels, each element having
-    # (implicit) units of V/(native unit), where each
-    # channel has its own native unit.
-    # scaling_coefficients: n_channels x nCoefficients double array,
-    # contains scaling coefficients for converting
-    # ADC counts to volts at the ADC input.
-    #
-    # scaled_data: nScans x n_channels double array containing the scaled
-    # data, each channel with it's own native unit.
+    """
+    Function to convert raw ADC data as int16s to doubles, taking the
+    per-channel scaling factors into account.
 
+    data_as_ADC_counts: n_channels x nScans int16 array
+    channel_scales: double vector of length n_channels, each element having
+                    (implicit) units of V/(native unit), where each
+                    channel has its own native unit.
+    scaling_coefficients: n_channels x nCoefficients double array,
+                          contains scaling coefficients for converting
+                          ADC counts to volts at the ADC input.
+
+    scaled_data: nScans x n_channels double array containing the scaled
+                 data, each channel with its own native unit.
+    """
     inverse_channel_scales = 1.0 / channel_scales  # if some channel scales are zero, this will lead to nans and/or infs
     n_channels = channel_scales.size
     scaled_data = np.empty(data_as_ADC_counts.shape)

From 2d5723338db2594f18fdd45173b1a9883069c93b Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Sun, 15 Oct 2023 17:30:11 +0100
Subject: [PATCH 2/7] Add tests.

---
 tests/test_ws.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 73 insertions(+)

diff --git a/tests/test_ws.py b/tests/test_ws.py
index f10f841..3ed5c82 100644
--- a/tests/test_ws.py
+++ b/tests/test_ws.py
@@ -361,3 +361,76 @@ def test_arbitrary_on_matrix_two_coeffs_single():
     y = ws.scaled_single_analog_data_from_raw(x, channel_scale, adc_coefficients)
     assert y.dtype == 'float32'
     assert (y_theoretical == y).all()
+
+
+# --------------------------------------------------------------------------------------
+# Testing Lazy Loading
+# --------------------------------------------------------------------------------------
+
+@pytest.mark.parametrize("filename", ["30_kHz_sampling_rate_0p912_0001.h5", "30_kHz_sampling_rate_0p913_0001.h5",
+                                      "29997_Hz_sampling_rate_0p912_0001.h5", "29997_Hz_sampling_rate_0p913_0001.h5",
+                                      "ws_0p933_data_0001.h5", "ws_v1p0p2_data.h5",
+                                      "ws_0p74_data_0001.h5", "ws_0p97_data_0001.h5"]
+                         )
+@pytest.mark.parametrize("indices", [(0, 1), (0, 500), (250, 1250), (653, "max")])
+def test_lazy_loading_one_sweep(filename, indices):
+    """
+    Check that data indexed with `get_traces` is correct for all
+    one-sweep (or trial) test files. Scaling is not tested here.
+    """
+    this_file_path = os.path.realpath(__file__)
+    this_dir_name = os.path.dirname(this_file_path)
+    file_name = os.path.join(this_dir_name, filename)
+
+    start_idx, end_idx = indices
+
+    data = ws.PyWaveSurferData(file_name, format_string="raw")
+
+    if filename in ["ws_0p74_data_0001.h5"]:
+        true_data = np.array(data.file["trial_0001"])
+    else:
+        true_data = np.array(data.file["sweep_0001"]["analogScans"])
+
+    if end_idx == "max":
+        end_idx = true_data.size
+
+    test_data = data.get_traces(0, start_idx, end_idx, return_scaled=False)
+    assert np.array_equal(true_data[:, start_idx:end_idx], test_data)
+
+
+@pytest.mark.parametrize("indices", [(0, 1), (0, 500), (250, 1250), (653, "max")])
+@pytest.mark.parametrize("format", ["raw", "single", "double"])
+def test_lazy_loading_three_sweeps(indices, format):
+    """
+    Check that the `test2.h5` data file is indexed correctly with
+    `get_traces()` across all three sweeps. Also test scaling of the data.
+    Note the coefficients themselves are taken from `PyWaveSurferData`
+    directly and are not tested here, as they are tested elsewhere.
+    """
+    this_file_path = os.path.realpath(__file__)
+    this_dir_name = os.path.dirname(this_file_path)
+    file_name = os.path.join(this_dir_name, "test2.h5")
+
+    start_idx, end_idx = indices
+
+    return_scaled = False if format == "raw" else True
+
+    for sweep_idx, sweep_id in enumerate(["sweep_0001", "sweep_0002", "sweep_0003"]):
+
+        data = ws.PyWaveSurferData(file_name, format_string=format)
+        true_data = np.array(data.file[sweep_id]["analogScans"])
+
+        if format != "raw":
+            scaling_func = (ws.scaled_double_analog_data_from_raw
+                            if format == "double"
+                            else ws.scaled_single_analog_data_from_raw)
+            true_data = scaling_func(true_data, data.analog_channel_scales, data.analog_scaling_coefficients)
+
+        if end_idx == "max":
+            end_idx = true_data.size
+
+        test_data = data.get_traces(sweep_idx, start_idx, end_idx, return_scaled=return_scaled)
+
+        assert np.array_equal(true_data[:, start_idx:end_idx], test_data)
+
+

From 517ee58e93e1470a888a89a63414a07355a18e80 Mon Sep 17 00:00:00 2001
From: JoeZiminski
Date: Sun, 15 Oct 2023 20:59:52 +0100
Subject: [PATCH 3/7] Add some documentation and refactoring.

---
 pywavesurfer/ws.py | 276 ++++++++++++++++++++++++++++-----------------
 1 file changed, 171 insertions(+), 105 deletions(-)

diff --git a/pywavesurfer/ws.py b/pywavesurfer/ws.py
index cd5016d..d37f8ce 100644
--- a/pywavesurfer/ws.py
+++ b/pywavesurfer/ws.py
@@ -6,15 +6,6 @@
 from packaging.version import parse as parse_version
 from packaging.specifiers import SpecifierSet
 
-# TODO JZ
-# header = self.data_file_as_dict["header"]  # TODO: header["AcquisitionSampleRate"]
-# why is this 2D
-# num_samples = header["SweepDuration"] * header["AcquisitionSampleRate"]  # TODO:
-# what about these conditional set header["Acquisition"]["SampleRate"]
-# investigate AreSweepsContinuous
-# TODO: figre out timing on Neo
-# TODO: ask SI about digital input channels (not probe)
-
 # the latest version pywavesurfer was tested against
 _latest_version = 0.982
 _over_version_1 = SpecifierSet(">=1.0")
@@ -25,19 +16,26 @@
 
 
 def loadDataFile(filename, format_string='double'):
-    wavesurfer = PyWaveSurferData(filename, format_string)
-    data_file_as_dict = wavesurfer.load_all_data()
+    """ Return the PyWaveSurfer data as a dictionary. This
+    convenience function returns the entire loaded dataset.
+    Used for backwards compatibility with versions prior
+    to the introduction of lazy loading.
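+
+    A minimal usage sketch (the path is illustrative, and the 'sweep_0001'
+    key assumes a sweep-based rather than trial-based file):
+
+        from pywavesurfer import ws
+        data_as_dict = ws.loadDataFile("path/to/file.h5", format_string="double")
+        first_sweep = data_as_dict["sweep_0001"]["analogScans"]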
+    """
+    with PyWaveSurferData(filename, format_string) as wavesurfer:
+        data_file_as_dict = wavesurfer.load_all_data()
     return data_file_as_dict
 
 
 class PyWaveSurferData:
 
-    def __init__(self, filename, format_string="double"): # TODO: move format string?
+    def __init__(self, filename, format_string="double"):
         """ Loads Wavesurfer data file.
 
         :param filename: File to load (has to be with .h5 extension)
-        :param format_string: optional: the return type of the data, defaults to double. Could be 'single', or 'raw'.
-        :return: dictionary with a structure array with one element per sweep in the data file.
+        :param format_string: optional: the return type of the data, defaults to double.
+            Can be 'single' or 'raw'.
+        :return: dictionary with a structure array with one element per sweep in the
+            data file.
         """
         self.format_string = format_string
 
@@ -52,12 +50,84 @@ def __init__(self, filename, format_string="double"): # TODO: move format strin
 
         self.file = h5py.File(filename, mode='r')
 
-
         self.data_file_as_dict = self.get_metadata_dict()
 
-        # Correct the samples rates for files that were generated by versions
-        # of WS which didn't coerce the sampling rate to an allowed rate.
-        header = self.data_file_as_dict["header"]
+        self.analog_channel_scales, self.analog_scaling_coefficients, self.n_a_i_channels = self.get_scaling_coefficients()
+
+    def close_file(self):
+        if not self.file.closed():
+            self.file.close()
+
+    def __enter__(self):
+        """ This and `__exit__` ensure the class can be
+        used in a `with` statement.
+        """
+        return self
+
+    def __exit__(self):
+        self.close_file()
+
+    # ----------------------------------------------------------------------------------
+    # Fill Metadata Dict
+    # ----------------------------------------------------------------------------------
+
+    def get_metadata_dict(self):
+        data_file_as_dict = self.recursive_crawl_h5_group(self.file)
+        data_file_as_dict = self.fix_sampling_rate_for_older_versions(data_file_as_dict)
+
+        return data_file_as_dict
+
+    def recursive_crawl_h5_group(self, group):
+        """ Recursively store the header information from the .h5 file
+        into a dictionary.
+
+        The entry 'analogScans' holds the data from the 'sweep_xxxx'
+        keys, whereas in older versions these were stored directly
+        in 'trial_xxxx' keys. For lazy loading, the raw data in 'sweep'
+        or 'trial' keys is not loaded at this stage.
+        """
+        result = dict()
+
+        item_names = list(group.keys())
+        for item_name in item_names:
+            item = group[item_name]
+            if isinstance(item, h5py.Group):
+                field_name = self.field_name_from_hdf_name(item_name)
+                result[field_name] = self.recursive_crawl_h5_group(item)
+            elif isinstance(item, h5py.Dataset):
+                field_name = self.field_name_from_hdf_name(item_name)
+                if item_name != "analogScans" and item_name[0:5] != "trial":
+                    result[field_name] = item[()]
+            else:
+                pass
+
+        return result
+
+    def field_name_from_hdf_name(self, hdf_name):
+        """ Convert the name of an HDF dataset/group to something that is a legal
+        Matlab struct field name. We do this even in Python, just to be consistent.
+        """
+        try:
+            # the group/dataset name seems to be a number. If it's an integer, we can deal, so check that.
+            hdf_name_as_double = float(hdf_name)
+            if hdf_name_as_double == round(hdf_name_as_double):
+                # If get here, group name is an integer, so we prepend with an "n" to get a valid field name
+                field_name = "n{:%s}".format(hdf_name)
+            else:
+                # Not an integer. Give up.
+                raise RuntimeError("Unable to convert group/dataset name {:%s} to a valid field name.".format(hdf_name))
+        except ValueError:
+            # This is actually a good thing, b/c it means the groupName is not
+            # simply a number, which would be an illegal field name
+            field_name = hdf_name
+
+        return field_name
+
+    def fix_sampling_rate_for_older_versions(self, data_file_as_dict):
+        """ Correct the sample rates for files that were generated by versions
+        of WS which didn't coerce the sampling rate to an allowed rate.
+        """
+        header = data_file_as_dict["header"]
         if "VersionString" in header:
             version_numpy = header["VersionString"]  # this is a scalar numpy array with a weird datatype
             version = version_numpy.tobytes().decode("utf-8")
@@ -73,6 +143,7 @@ def __init__(self, filename, format_string="double"): # TODO: move format strin
         else:
             # If no VersionsString field, the file is from an old old version
             parsed_version = parse_version('0.0')
+
             # version 0.912 has the problem, version 0.913 does not
             if parsed_version not in _over_version_1 and parsed_version.release is not None:
                 version_string = str(parsed_version.release[1])
@@ -87,7 +158,7 @@ def __init__(self, filename, format_string="double"): # TODO: move format strin
                         actual_acquisition_sample_rate = 100.0e6 / math.floor(
                             nominal_n_timebase_ticks_per_sample)  # sic: the boards floor() for acq, but round() for stim
                         header["Acquisition"]["SampleRate"] = np.array(actual_acquisition_sample_rate)
-                        self.data_file_as_dict["header"] = header
+                        data_file_as_dict["header"] = header
                 # Fix the stimulation sample rate, if needed
                 nominal_stimulation_sample_rate = float(header["Stimulation"]["SampleRate"])
                 nominal_n_timebase_ticks_per_sample = 100.0e6 / nominal_stimulation_sample_rate
@@ -95,86 +166,131 @@ def __init__(self, filename, format_string="double"): # TODO: move format strin
                         actual_stimulation_sample_rate = 100.0e6 / round(
                             nominal_n_timebase_ticks_per_sample)  # sic: the boards floor() for acq, but round() for stim
                         header["Stimulation"]["SampleRate"] = np.array(actual_stimulation_sample_rate)
-                        self.data_file_as_dict["header"] = header
+                        data_file_as_dict["header"] = header
+
+        return data_file_as_dict
+
+    # ----------------------------------------------------------------------------------
+    # Get gain and scaling coefficients
+    # ----------------------------------------------------------------------------------
+
+    def get_scaling_coefficients(self):
+        """ Get the correct scale and gain coefficients based on
+        the file version.
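+
+        Returns a tuple (analog_channel_scales, analog_scaling_coefficients,
+        n_a_i_channels); the scales and coefficients are None when
+        `format_string` is 'raw' or no analog input channels are present.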
+        """
+        header = self.data_file_as_dict["header"]
 
-        # Fill in channel header
         if "NAIChannels" in header:
-            self.n_a_i_channels = header["NAIChannels"]
+            n_a_i_channels = header["NAIChannels"]
         else:
             acq = header["Acquisition"]
             if "AnalogChannelScales" in acq:
                 all_analog_channel_scales = acq["AnalogChannelScales"]
             else:
-                # This is presumably a very old file, from before we supported digital inputs
+                # This is presumably a very old file, from before we supported
+                # digital inputs
                 all_analog_channel_scales = acq["ChannelScales"]
-            self.n_a_i_channels = all_analog_channel_scales.size  # element count
+            n_a_i_channels = all_analog_channel_scales.size  # element count
 
-        if self.format_string.lower() != "raw" and self.n_a_i_channels > 0:
+        if self.format_string.lower() != "raw" and n_a_i_channels > 0:
             try:
                 if "AIChannelScales" in header:
-                    # Newer files have this field, and lack header.Acquisition.AnalogChannelScales
+                    # Newer files have this field, and lack
+                    # header.Acquisition.AnalogChannelScales
                     all_analog_channel_scales = header["AIChannelScales"]
                 else:
                     # Fallback for older files
-                    all_analog_channel_scales = header["Acquisition"]["AnalogChannelScales"]
+                    all_analog_channel_scales = header["Acquisition"][
+                        "AnalogChannelScales"]
             except KeyError:
                 raise KeyError("Unable to read channel scale information from file.")
             try:
                 if "IsAIChannelActive" in header:
-                    # Newer files have this field, and lack header.Acquisition.AnalogChannelScales
+                    # Newer files have this field, and lack
+                    # header.Acquisition.AnalogChannelScales
                     is_active = header["IsAIChannelActive"].astype(bool)
                 else:
                     # Fallback for older files
-                    is_active = header["Acquisition"]["IsAnalogChannelActive"].astype(bool)
+                    is_active = header["Acquisition"]["IsAnalogChannelActive"].astype(
+                        bool)
             except KeyError:
-                raise KeyError("Unable to read active/inactive channel information from file.")
-            self.analog_channel_scales = all_analog_channel_scales[is_active]  # TODO
+                raise KeyError(
+                    "Unable to read active/inactive channel information from file.")
+            analog_channel_scales = all_analog_channel_scales[is_active]  # TODO
 
             # read the scaling coefficients
             try:
                 if "AIScalingCoefficients" in header:
-                    self.analog_scaling_coefficients = header["AIScalingCoefficients"]  # TODO
+                    analog_scaling_coefficients = header[
+                        "AIScalingCoefficients"]  # TODO
                 else:
-                    self.analog_scaling_coefficients = header["Acquisition"]["AnalogScalingCoefficients"]
+                    analog_scaling_coefficients = header["Acquisition"][
+                        "AnalogScalingCoefficients"]
             except KeyError:
                 raise KeyError("Unable to read channel scaling coefficients from file.")
 
-    def get_traces(self, segment_index, start_frame, end_frame, return_scaled=True):
-        field_names = [name for name in self.data_file_as_dict if name[0:5] in ["sweep", "trial"]]
+        else:
+            analog_channel_scales = analog_scaling_coefficients = None
 
-        sweep_nums = [int(ele[6:]) for ele in field_names]
-        ordered_field_names = [field_names[num - 1] for num in sweep_nums]
+        return analog_channel_scales, analog_scaling_coefficients, n_a_i_channels
 
+    # ----------------------------------------------------------------------------------
+    # Data Getters
+    # ----------------------------------------------------------------------------------
+
+    def get_traces(self, segment_index, start_frame, end_frame, return_scaled=True):
+        """
+        Get traces for a segment (i.e. a specific 'sweep' or 'trial'
+        number) indexed between `start_frame` and `end_frame`.
+
+        If `return_scaled` is `True`, data will be scaled according to
+        the `format_string` argument passed during class construction.
+        """
+        ordered_sweep_names = self.get_ordered_sweep_names()
         sweep_name = ordered_field_names[segment_index]
 
+        # Index out the data and scale if required.
         if sweep_name[0:5] == "sweep":
-            analog_data_as_counts = self.file[sweep_name]["analogScans"][:, start_frame:end_frame]  # +1?
+            analog_data_as_counts = self.file[sweep_name]["analogScans"][:, start_frame:end_frame]
         else:
-            analog_data_as_counts = self.file[sweep_name][:, start_frame:end_frame]  # +1?
+            analog_data_as_counts = self.file[sweep_name][:, start_frame:end_frame]
 
-        if not return_scaled:
-            return analog_data_as_counts
+        if return_scaled and self.format_string.lower() == "raw":
+            raise ValueError("`return_scaled` cannot be `True` if `format_string` is 'raw'.")
 
-        if self.format_string.lower() != "raw" and self.n_a_i_channels > 0:  # TODO: fix double conditional
-            does_user_want_single = (self.format_string.lower() == "single")
-            if does_user_want_single:
-                scaled_analog_data = scaled_single_analog_data_from_raw(
+        if return_scaled and self.n_a_i_channels > 0:
+            if self.format_string.lower() == "single":
+                traces = scaled_single_analog_data_from_raw(
                     analog_data_as_counts,
                     self.analog_channel_scales,
                     self.analog_scaling_coefficients)
             else:
-                scaled_analog_data = scaled_double_analog_data_from_raw(
+                traces = scaled_double_analog_data_from_raw(
                     analog_data_as_counts,
                     self.analog_channel_scales,
                     self.analog_scaling_coefficients)
-            return scaled_analog_data
         else:
-            return analog_data_as_counts
+            traces = analog_data_as_counts
+
+        return traces
 
-    def load_all_data(self):  # investigate AreSweepsContinuous
-        """"""
+    def get_ordered_sweep_names(self):
+        """ Take the data field names (e.g. sweep_0001, sweep_0002), ensure they
+        are in the correct order and index according to `segment_index`.
+        """
+        field_names = [name for name in self.file if name[0:5] in ["sweep", "trial"]]
+        sweep_nums = [int(ele[6:]) for ele in field_names]
+        ordered_field_names = [field_names[num - 1] for num in sweep_nums]
+
+        return ordered_field_names
+
+    def load_all_data(self):
+        """
+        A convenience function to load into the `data_file_as_dict`
+        all data in the file.
+ """ idx = 0 - for field_name in self.data_file_as_dict: + for field_name in self.file: if len(field_name) >= 5 and (field_name[0:5] == "sweep" or field_name[0:5] == "trial"): @@ -184,6 +300,7 @@ def load_all_data(self): # investigate AreSweepsContinuous num_samples = self.file[field_name].size scaled_analog_data = self.get_traces(segment_index=idx, start_frame=0, end_frame=num_samples) + if field_name[0:5] == "sweep": self.data_file_as_dict[field_name]["analogScans"] = scaled_analog_data else: @@ -191,60 +308,9 @@ def load_all_data(self): # investigate AreSweepsContinuous return self.data_file_as_dict - def close_file(self): - if not self.file.closed(): - self.file.close() - - def __enter__(self): - return self - - def __exit__(self): - self.close_file() - - def get_metadata_dict(self): - - return self.recursive_crawl_h5_group(self.file) - - def recursive_crawl_h5_group(self, group): - result = dict() - - item_names = list(group.keys()) - for item_name in item_names: - item = group[item_name] - if isinstance(item, h5py.Group): - field_name = self.field_name_from_hdf_name(item_name) - result[field_name] = self.recursive_crawl_h5_group(item) - elif isinstance(item, h5py.Dataset): - field_name = self.field_name_from_hdf_name(item_name) - if item_name[0:5] == "trial": - result[item_name] = {"analogScans": np.array([])} - elif item_name != "analogScans": - result[field_name] = item[()] - else: - pass - - return result - - - def field_name_from_hdf_name(self, hdf_name): - # Convert the name of an HDF dataset/group to something that is a legal - # Matlab struct field name. We do this even in Python, just to be consistent. - try: - # the group/dataset name seems to be a number. If it's an integer, we can deal, so check that. - hdf_name_as_double = float(hdf_name) - if hdf_name_as_double == round(hdf_name_as_double): - # If get here, group name is an integer, so we prepend with an "n" to get a valid field name - field_name = "n{:%s}".format(hdf_name) - else: - # Not an integer. Give up. - raise RuntimeError("Unable to convert group/dataset name {:%s} to a valid field name.".format(hdf_name)) - except ValueError: - # This is actually a good thing, b/c it means the groupName is not - # simply a number, which would be an illegal field name - field_name = hdf_name - - return field_name - +# ---------------------------------------------------------------------------------- +# Scaling Functions +# ---------------------------------------------------------------------------------- def scaled_double_analog_data_from_raw(data_as_ADC_counts, channel_scales, scaling_coefficients): """ From 99e145050b48e51d59d75e0680d53bcba1e0b660 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Sun, 15 Oct 2023 21:11:53 +0100 Subject: [PATCH 4/7] Small fixes. --- pywavesurfer/ws.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pywavesurfer/ws.py b/pywavesurfer/ws.py index d37f8ce..66c1585 100644 --- a/pywavesurfer/ws.py +++ b/pywavesurfer/ws.py @@ -54,19 +54,18 @@ def __init__(self, filename, format_string="double"): self.analog_channel_scales, self.analog_scaling_coefficients, self.n_a_i_channels = self.get_scaling_coefficients() - def close_file(self): - if not self.file.closed(): - self.file.close() - def __enter__(self): """ This and `__exit__` ensure the class can be used in a `with` statement. 
""" return self - def __exit__(self): + def __exit__(self, exception_type, exception_value, traceback): self.close_file() + def close_file(self): + self.file.close() + # ---------------------------------------------------------------------------------- # Fill Metadata Dict # ---------------------------------------------------------------------------------- @@ -247,7 +246,7 @@ def get_traces(self, segment_index, start_frame, end_frame, return_scaled=True): the `format_string` argument passed during class construction. """ ordered_sweep_names = self.get_ordered_sweep_names() - sweep_name = ordered_field_names[segment_index] + sweep_name = ordered_sweep_names[segment_index] # Index out the data and scale if required. if sweep_name[0:5] == "sweep": @@ -276,7 +275,8 @@ def get_traces(self, segment_index, start_frame, end_frame, return_scaled=True): def get_ordered_sweep_names(self): """ Take the data field names (e.g. sweep_0001, sweep_0002), ensure they - are in the correct order and index according to `segment_index`. + are in the correct order and index according to `segment_index`. Note + this function will treat 'sweep' or 'trial' as the same. """ field_names = [name for name in self.file if name[0:5] in ["sweep", "trial"]] sweep_nums = [int(ele[6:]) for ele in field_names] @@ -289,6 +289,7 @@ def load_all_data(self): A convenience function to load into the `data_file_as_dict` all data in the file. """ + return_scaled = False if self.format_string == "raw" else True idx = 0 for field_name in self.file: @@ -299,7 +300,7 @@ def load_all_data(self): else: num_samples = self.file[field_name].size - scaled_analog_data = self.get_traces(segment_index=idx, start_frame=0, end_frame=num_samples) + scaled_analog_data = self.get_traces(segment_index=idx, start_frame=0, end_frame=num_samples, return_scaled=return_scaled) if field_name[0:5] == "sweep": self.data_file_as_dict[field_name]["analogScans"] = scaled_analog_data From e0288a9ff141306d2664ea611a288b92fe45c8a0 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 16 Oct 2023 19:21:51 +0100 Subject: [PATCH 5/7] Add example of lazy loading to README.md --- README.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index cc93b7a..6d0fb00 100644 --- a/README.rst +++ b/README.rst @@ -23,7 +23,10 @@ Example usage data_as_dict = ws.loadDataFile(filename='path/to/file.h5', format_string='single' ) # to get the raw analog channels in int16: data_as_dict = ws.loadDataFile(filename='path/to/file.h5', format_string='raw' ) - + # lazy-loading of subsets of the data is also possible + with ws.PyWaveSurferData(path, format_string="double") as wavesurfer_file: + sweep_1_subset = wavesurfer_file.get_traces(segment_index=0, start_frame=0, end_frame=500) + sweep_2_subset = wavesurfer_file.get_traces(segment_index=1, start_frame=1, end_frame=500) Description of the content can be found in the documentation `here `_. 
From d17393d8115f52771b921b2222b052be1bbb4fdd Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 16 Oct 2023 19:22:25 +0100 Subject: [PATCH 6/7] Fix example in README.md --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 6d0fb00..2b83775 100644 --- a/README.rst +++ b/README.rst @@ -24,7 +24,7 @@ Example usage # to get the raw analog channels in int16: data_as_dict = ws.loadDataFile(filename='path/to/file.h5', format_string='raw' ) # lazy-loading of subsets of the data is also possible - with ws.PyWaveSurferData(path, format_string="double") as wavesurfer_file: + with ws.PyWaveSurferData('path/to/file.h5', format_string="double") as wavesurfer_file: sweep_1_subset = wavesurfer_file.get_traces(segment_index=0, start_frame=0, end_frame=500) sweep_2_subset = wavesurfer_file.get_traces(segment_index=1, start_frame=1, end_frame=500) From 7ad3e220dd1ad3732185f368cfebbcc6d8f7f3f6 Mon Sep 17 00:00:00 2001 From: Joe Ziminski <55797454+JoeZiminski@users.noreply.github.com> Date: Mon, 16 Oct 2023 19:22:52 +0100 Subject: [PATCH 7/7] Fix README.md again. --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 2b83775..cbf6f6b 100644 --- a/README.rst +++ b/README.rst @@ -26,7 +26,7 @@ Example usage # lazy-loading of subsets of the data is also possible with ws.PyWaveSurferData('path/to/file.h5', format_string="double") as wavesurfer_file: sweep_1_subset = wavesurfer_file.get_traces(segment_index=0, start_frame=0, end_frame=500) - sweep_2_subset = wavesurfer_file.get_traces(segment_index=1, start_frame=1, end_frame=500) + sweep_2_subset = wavesurfer_file.get_traces(segment_index=1, start_frame=0, end_frame=500) Description of the content can be found in the documentation `here `_.
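
For reference, the count-to-units conversion that `get_traces(..., return_scaled=True)` applies via `scaled_double_analog_data_from_raw` can be sketched as below. Per the docstring, each channel's ADC counts pass through a per-channel calibration polynomial to give volts at the ADC input, and dividing by the channel scale (V per native unit) then yields the channel's native unit. `scaled_from_counts` is a hypothetical stand-in, and the ascending (lowest-order-first) coefficient order is an assumption drawn from the docstring, not a guarantee of the library's storage order:

    import numpy as np

    def scaled_from_counts(data_as_adc_counts, channel_scales, scaling_coefficients):
        # data_as_adc_counts: (n_channels, n_scans) int16 ADC counts.
        # channel_scales: (n_channels,) with units V/(native unit).
        # scaling_coefficients: (n_channels, n_coefficients), assumed
        # lowest-order first: volts = c0 + c1 * x + c2 * x**2 + ...
        counts = data_as_adc_counts.astype(np.float64)
        scaled = np.empty(counts.shape)
        for i in range(counts.shape[0]):
            # Evaluate the per-channel calibration polynomial (counts -> volts).
            volts = np.polynomial.polynomial.polyval(counts[i], scaling_coefficients[i])
            # Volts -> native unit: divide by the V/(native unit) scale.
            scaled[i] = volts / channel_scales[i]
        return scaled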