diff --git a/lidbox/data/steps.py b/lidbox/data/steps.py index 2b246aa..33072cd 100644 --- a/lidbox/data/steps.py +++ b/lidbox/data/steps.py @@ -5,6 +5,7 @@ """ import collections import io +import json import logging import os import shutil @@ -391,6 +392,28 @@ def cache(ds, directory=None, batch_size=1, cache_key=None): .unbatch()) +def validate_cache(dataframe, path, cache_key): + """ + Validate any existing cache. Validation is based on saving the keys and shape of the given + dataframe to JSON-format. + """ + cache_file = f"{path}/{cache_key}_meta.json" + if os.path.exists(cache_file): + with open(cache_file, 'r', encoding='utf-8') as infile: + existing_values = json.load(infile) + new_keys = dataframe.columns.to_list() + assert existing_values["keys"] == new_keys, \ + f"Cache validation failed, old keys {existing_values['keys']} vs. new {new_keys}" + assert existing_values["shape"] == list(dataframe.shape), \ + f"Cache validation failed, old shape {existing_values['shape']} vs. new {dataframe.shape}" + logger.info("Cache validation passed.") + else: + values = {"keys": dataframe.columns.to_list(), "shape": dataframe.shape} + logger.info(f"Previous cache does not exist. Saving dataframe keys and shape to {cache_file} for validation.") + with open(cache_file, 'w', encoding='utf-8') as outfile: + json.dump(values, outfile, indent=2) + + def compute_rms_vad(ds, strength, vad_frame_length_ms, min_non_speech_length_ms=0): """ Compute root mean square based voice activity detection.