Skip to content

Commit

Permalink
feat: add ssl4eo-s12 support
Browse files Browse the repository at this point in the history
  • Loading branch information
kai-tub committed Jul 10, 2024
1 parent c3d0966 commit e4453a0
Show file tree
Hide file tree
Showing 63 changed files with 394 additions and 31 deletions.
10 changes: 10 additions & 0 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,17 @@
--prefix PATH : ${pkgs.lib.makeBinPath [pkgs.fd]}
'';
meta.mainProgram = "rico-hdl";
# The SSL4EO-S12 base folder is copied instead of the individual base directories
# as otherwise the directory would be prefixed with the hash of the directory
# and would result in an unpredictable LMDB key name, as the base directory name
# is used for the test.
checkPhase = ''
export PATH="$out/bin:$PATH"
export RICO_HDL_S1_PATH=${./integration_tests/tiffs/BigEarthNet/BigEarthNet-S1}
export RICO_HDL_S2_PATH=${./integration_tests/tiffs/BigEarthNet/BigEarthNet-S2}
export RICO_HDL_SSL4EO_S12_S1_PATH=${./integration_tests/tiffs/SSL4EO-S12}/s1
export RICO_HDL_SSL4EO_S12_S2_L1C_PATH=${./integration_tests/tiffs/SSL4EO-S12}/s2c
export RICO_HDL_SSL4EO_S12_S2_L2A_PATH=${./integration_tests/tiffs/SSL4EO-S12}/s2a
export RICO_HDL_HYSPECNET_PATH=${./integration_tests/tiffs/HySpecNet-11k}
export RICO_HDL_LMDB_REF_PATH=${./integration_tests/BigEarthNet_LMDB}
export RICO_HDL_UC_MERCED_PATH=${./integration_tests/tiffs/UCMerced_LandUse}
Expand Down Expand Up @@ -132,6 +139,9 @@
env.RICO_HDL_S2_PATH = "${config.env.DEVENV_ROOT}/integration_tests/tiffs/BigEarthNet/BigEarthNet-S2";
env.RICO_HDL_LMDB_REF_PATH = "${config.env.DEVENV_ROOT}/integration_tests/BigEarthNet_LMDB";
env.JUPYTER_PATH = "${pkgs.python3Packages.jupyterlab}/share/jupyter";
env.RICO_HDL_SSL4EO_S12_S1_PATH = "${config.env.DEVENV_ROOT}/integration_tests/tiffs/SSL4EO-S12/s1";
env.RICO_HDL_SSL4EO_S12_S2_L1C_PATH = "${config.env.DEVENV_ROOT}/integration_tests/tiffs/SSL4EO-S12/s2c";
env.RICO_HDL_SSL4EO_S12_S2_L2A_PATH = "${config.env.DEVENV_ROOT}/integration_tests/tiffs/SSL4EO-S12/s2a";
packages =
[
(mkPoetryEnv
Expand Down
203 changes: 192 additions & 11 deletions integration_tests/test_python_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,23 @@ def read_single_band_raster(path):


@pytest.fixture(scope="session")
def s1_root() -> Path:
def bigearthnet_s1_root() -> Path:
str_p = os.environ.get("RICO_HDL_S1_PATH") or "./tiffs/BigEarthNet/BigEarthNet-S1/"
p = Path(str_p)
assert p.exists()
assert p.is_dir()
return p


@pytest.fixture(scope="session")
def bigearthnet_s2_root() -> Path:
str_p = os.environ.get("RICO_HDL_S2_PATH") or "./tiffs/BigEarthNet/BigEarthNet-S2/"
p = Path(str_p)
assert p.exists()
assert p.is_dir()
return p


@pytest.fixture(scope="session")
def bigearthnet_lmdb_ref_path() -> Path:
str_p = os.environ.get("RICO_HDL_LMDB_REF_PATH") or "./BigEarthNet_LMDB/"
Expand All @@ -39,8 +48,30 @@ def bigearthnet_lmdb_ref_path() -> Path:


@pytest.fixture(scope="session")
def s2_root() -> Path:
str_p = os.environ.get("RICO_HDL_S2_PATH") or "./tiffs/BigEarthNet/BigEarthNet-S2/"
def ssl4eo_s12_s1_root() -> Path:
str_p = os.environ.get("RICO_HDL_SSL4EO_S12_S1_PATH") or "./tiffs/SSL4EO-S12/s1/"
p = Path(str_p)
assert p.exists()
assert p.is_dir()
return p


@pytest.fixture(scope="session")
def ssl4eo_s12_s2_l1c_root() -> Path:
str_p = (
os.environ.get("RICO_HDL_SSL4EO_S12_S2_L1C_PATH") or "./tiffs/SSL4EO-S12/s2c/"
)
p = Path(str_p)
assert p.exists()
assert p.is_dir()
return p


@pytest.fixture(scope="session")
def ssl4eo_s12_s2_l2a_root() -> Path:
str_p = (
os.environ.get("RICO_HDL_SSL4EO_S12_S2_L2A_PATH") or "./tiffs/SSL4EO-S12/s2a/"
)
p = Path(str_p)
assert p.exists()
assert p.is_dir()
Expand Down Expand Up @@ -76,15 +107,35 @@ def eurosat_ms_root() -> Path:

# https://docs.pytest.org/en/6.2.x/tmpdir.html#[email protected](scope="session")
@pytest.fixture
def encoded_bigearthnet_s1_s2_path(s1_root, s2_root, tmpdir_factory) -> Path:
def encoded_bigearthnet_s1_s2_path(
bigearthnet_s1_root, bigearthnet_s2_root, tmpdir_factory
) -> Path:
tmp_path = tmpdir_factory.mktemp("lmdb")
# This should make it easier to separately test different versions of the binary and the appimage as well
subprocess.run(
[
"rico-hdl",
"bigearthnet",
f"--bigearthnet-s1-dir={s1_root}",
f"--bigearthnet-s2-dir={s2_root}",
f"--bigearthnet-s1-dir={bigearthnet_s1_root}",
f"--bigearthnet-s2-dir={bigearthnet_s2_root}",
f"--target-dir={tmp_path}",
],
check=True,
)
return Path(tmp_path)


@pytest.fixture
def encoded_ssl4eo_s12_path(
ssl4eo_s12_s1_root, ssl4eo_s12_s2_l1c_root, ssl4eo_s12_s2_l2a_root, tmpdir_factory
) -> Path:
tmp_path = tmpdir_factory.mktemp("lmdb")
subprocess.run(
[
"rico-hdl",
"ssl4eo-s12",
f"--s1-dir={ssl4eo_s12_s1_root}",
f"--s2-l1c-dir={ssl4eo_s12_s2_l1c_root}",
f"--s2-l2a-dir={ssl4eo_s12_s2_l2a_root}",
f"--target-dir={tmp_path}",
],
check=True,
Expand Down Expand Up @@ -138,10 +189,19 @@ def encoded_eurosat_ms_path(eurosat_ms_root, tmpdir_factory) -> Path:


def test_reproducibility_and_data_consistency(
s1_root, s2_root, encoded_bigearthnet_s1_s2_path, bigearthnet_lmdb_ref_path
bigearthnet_s1_root,
bigearthnet_s2_root,
encoded_bigearthnet_s1_s2_path,
bigearthnet_lmdb_ref_path,
):
s1_data = {file: read_single_band_raster(file) for file in s1_root.glob("**/*.tif")}
s2_data = {file: read_single_band_raster(file) for file in s2_root.glob("**/*.tif")}
s1_data = {
file: read_single_band_raster(file)
for file in bigearthnet_s1_root.glob("**/*.tif")
}
s2_data = {
file: read_single_band_raster(file)
for file in bigearthnet_s2_root.glob("**/*.tif")
}
source_data = {**s1_data, **s2_data}
env = lmdb.open(str(encoded_bigearthnet_s1_s2_path), readonly=True)

Expand Down Expand Up @@ -175,7 +235,9 @@ def test_reproducibility_and_data_consistency(


def test_bigearthnet_integration(
s1_root, s2_root, encoded_bigearthnet_s1_s2_path, bigearthnet_lmdb_ref_path
bigearthnet_s1_root,
bigearthnet_s2_root,
encoded_bigearthnet_s1_s2_path,
):
env = lmdb.open(str(encoded_bigearthnet_s1_s2_path), readonly=True)

Expand Down Expand Up @@ -244,6 +306,125 @@ def test_bigearthnet_integration(
)


def test_ssl4eo_s12_integration(
ssl4eo_s12_s1_root,
ssl4eo_s12_s2_l1c_root,
ssl4eo_s12_s2_l2a_root,
encoded_ssl4eo_s12_path,
):
env = lmdb.open(str(encoded_ssl4eo_s12_path), readonly=True)

with env.begin(write=False) as txn:
cur = txn.cursor()
decoded_lmdb_data = {k.decode("utf-8"): load(v) for (k, v) in cur}

assert decoded_lmdb_data.keys() == set(
[
"s1_0000200_S1A_IW_GRDH_1SDV_20200607T010800_20200607T010825_032904_03CFBA_D457",
"s1_0000200_S1A_IW_GRDH_1SDV_20200903T131212_20200903T131237_034195_03F8F5_AC1C",
"s2a_0000200_20200604T054639_20200604T054831_T43RCP",
"s2a_0000200_20200813T054639_20200813T054952_T43RCP",
"s2c_0000200_20200604T054639_20200604T054831_T43RCP",
"s2c_0000200_20200823T054639_20200823T055618_T43RCP",
]
)

sample_s1_safetensors_dict = decoded_lmdb_data.get(
"s1_0000200_S1A_IW_GRDH_1SDV_20200607T010800_20200607T010825_032904_03CFBA_D457"
)
sample_s2_l1c_safetensors_dict = decoded_lmdb_data.get(
"s2c_0000200_20200604T054639_20200604T054831_T43RCP"
)
sample_s2_l2a_safetensors_dict = decoded_lmdb_data.get(
"s2a_0000200_20200604T054639_20200604T054831_T43RCP"
)
safetensors_s1_keys = sample_s1_safetensors_dict.keys()
safetensors_s2_l1c_keys = sample_s2_l1c_safetensors_dict.keys()
safetensors_s2_l2a_keys = sample_s2_l2a_safetensors_dict.keys()
assert (
set(
[
"B1",
"B2",
"B3",
"B4",
"B5",
"B6",
"B7",
"B8",
"B8A",
"B9",
"B10",
"B11",
"B12",
]
)
== safetensors_s2_l1c_keys
)
assert (
set(
[
"B1",
"B2",
"B3",
"B4",
"B5",
"B6",
"B7",
"B8",
"B8A",
"B9",
"B11",
"B12",
]
)
== safetensors_s2_l2a_keys
)
assert (
set(
[
"VV",
"VH",
]
)
== safetensors_s1_keys
)

# IMPORTANT!
# The SSL4EO-S12 authors didn't pay attention to the resulting size of the patches!
# Some have an extra row/column of pixels!
# This assertion does NOT hold over the entire dataset!
assert all(arr.shape == (264, 264) for arr in sample_s1_safetensors_dict.values())
assert all(arr.dtype == "float32" for arr in sample_s1_safetensors_dict.values())

assert all(arr.dtype == "uint16" for arr in sample_s2_l1c_safetensors_dict.values())
assert all(
sample_s2_l1c_safetensors_dict[key].shape == (264, 264)
for key in ["B2", "B3", "B4", "B8"]
)
assert all(
sample_s2_l1c_safetensors_dict[key].shape == (132, 132)
for key in ["B5", "B6", "B7", "B8A", "B11", "B12"]
)
assert all(
sample_s2_l1c_safetensors_dict[key].shape == (44, 44)
for key in ["B1", "B9", "B10"]
)

assert all(arr.dtype == "uint16" for arr in sample_s2_l2a_safetensors_dict.values())
assert all(
sample_s2_l2a_safetensors_dict[key].shape == (264, 264)
for key in ["B2", "B3", "B4", "B8"]
)
assert all(
sample_s2_l2a_safetensors_dict[key].shape == (132, 132)
for key in ["B5", "B6", "B7", "B8A", "B11", "B12"]
)
assert all(
sample_s2_l2a_safetensors_dict[key].shape == (44, 44) for key in ["B1", "B9"]
)


def test_hyspecnet_integration(hyspecnet_root, encoded_hyspecnet_path):
env = lmdb.open(str(encoded_hyspecnet_path), readonly=True)

Expand Down
Loading

0 comments on commit e4453a0

Please sign in to comment.