Skip to content

Commit

Permalink
avoid lambda for mp
Browse files Browse the repository at this point in the history
  • Loading branch information
kai-tub committed Jul 1, 2024
1 parent 4fb89e6 commit a766006
Showing 1 changed file with 32 additions and 16 deletions.
48 changes: 32 additions & 16 deletions rico_hdl/rico_hdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import structlog
from more_itertools import chunked
from tqdm import tqdm
from concurrent.futures import as_completed, ProcessPoolExecutor
from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp
import blosc2

Expand Down Expand Up @@ -142,6 +142,18 @@ def safetensor_generator_s1(patch_path: str) -> bytes:
return save(data, metadata=None)


def optional_compressed_safetensor_generator(
safetensor_generator, compress: bool = False
):
if not compress:
return safetensor_generator

def wrapper(x):
return zstd_compressor(safetensor_generator(x))

return wrapper


@app.command()
def hyspecnet_11k(
target_dir: TargetDir,
Expand Down Expand Up @@ -266,35 +278,34 @@ def bigearthnet(
# Otherwise an error in the latter CLI argument could produce an incomplete LMDB
env = open_lmdb(target_dir)

safetensor_gen_s1 = (
(lambda x: zstd_compressor(safetensor_generator_s1(x)))
if compress
else safetensor_generator_s1
)
safetensor_gen_s2 = (
(lambda x: zstd_compressor(safetensor_generator_s2(x)))
if compress
else safetensor_generator_s2
)

if bigearthnet_s1_dir is not None:
log.debug("Writing BigEarthNet-S1 data into LMDB")
lmdb_writer(
env, s1_patch_paths, bigearthnet_lmdb_key_extractor, safetensor_gen_s1
env,
s1_patch_paths,
bigearthnet_lmdb_key_extractor,
safetensor_generator_s1,
compress,
)

if bigearthnet_s2_dir is not None:
log.debug("Writing BigEarthNet-S2 data into LMDB")
lmdb_writer(
env, s2_patch_paths, bigearthnet_lmdb_key_extractor, safetensor_gen_s2
env,
s2_patch_paths,
bigearthnet_lmdb_key_extractor,
safetensor_generator_s2,
compress,
)


def bigearthnet_lmdb_key_extractor(path: str) -> bytes:
return str(Path(path).stem).encode()


def lmdb_writer(env, paths, lmdb_key_extractor_func, safetensor_generator):
def lmdb_writer(
env, paths, lmdb_key_extractor_func, safetensor_generator, compress: bool = False
):
# insertion order is important for reproducibility!
paths.sort()
log.debug("About to serialize data in chunks")
Expand All @@ -317,9 +328,14 @@ def lmdb_writer(env, paths, lmdb_key_extractor_func, safetensor_generator):
# To ensure deterministic output, write in order
# i.e., cannot use `as_completed(futures_to_path)` !
for future in futures_to_path:
data = (
future.result()
if not compress
else zstd_compressor(future.result())
)
if not txn.put(
lmdb_key_extractor_func(futures_to_path[future]),
future.result(),
data,
overwrite=False,
):
sys.exit(
Expand Down

0 comments on commit a766006

Please sign in to comment.