From b4048424683d750694cace4915c08fd6d51b3f0a Mon Sep 17 00:00:00 2001 From: Emma Ai Date: Thu, 22 Aug 2024 04:39:51 +0000 Subject: [PATCH] more aggregation on dask ops --- docker/requirements.txt | 2 +- odc/stats/plugins/lc_treelite_cultivated.py | 88 ++++++++++++--------- tests/test_rf_models.py | 1 + 3 files changed, 51 insertions(+), 40 deletions(-) diff --git a/docker/requirements.txt b/docker/requirements.txt index 757e188..286a6ad 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -11,7 +11,7 @@ odc-dscache>=0.2.3 odc-stac @ git+https://github.com/opendatacube/odc-stac@69bdf64 # odc-stac is in PyPI -odc-stats[ows] @ git+https://github.com/opendatacube/odc-stats@a9bdd82 +odc-stats[ows] @ git+https://github.com/opendatacube/odc-stats@f1f2771 # For ML tflite-runtime diff --git a/odc/stats/plugins/lc_treelite_cultivated.py b/odc/stats/plugins/lc_treelite_cultivated.py index 76ceeb2..7327230 100644 --- a/odc/stats/plugins/lc_treelite_cultivated.py +++ b/odc/stats/plugins/lc_treelite_cultivated.py @@ -8,7 +8,6 @@ import dask.array as da import numexpr as ne -from odc.stats._algebra import expr_eval from ._registry import register from .lc_ml_treelite import StatsMLTree, mask_and_predict @@ -233,6 +232,46 @@ def cultivated_predict(input_block, bands_indices): return cc +def aggregate_results(input_block, cultivated_value, natural_value): + # if there are >= 2 images + # any is cultivated -> final class is cultivated + # any is valid -> final class is valid + # for each pixel + + output_block = ne.evaluate( + "where(m 1: + output_block = ne.evaluate( + "sum(m,axis=2)", + local_dict={"m": output_block}, + ).astype("float32") + + else: + output_block = output_block.squeeze(axis=-1) + + output_block = ne.evaluate( + "where((m/nodata)+0.5>=_l, nodata, m%nodata)", + local_dict={"m": output_block, "_l": m_size, "nodata": NODATA}, + ).astype("float32") + + output_block = ne.evaluate( + "where((m>0.5)&(m= 2 images - # any is cultivated -> final class is cultivated - # any is valid -> final class is valid - # for each pixel m_size = len(predict_output) if m_size > 1: - predict_output = da.stack(predict_output) + predict_output = da.stack(predict_output, axis=-1) else: - predict_output = predict_output[0] - - predict_output = expr_eval( - "where(m 1: - predict_output = predict_output.sum(axis=0) - - predict_output = expr_eval( - "where((m/nodata)+0.5>=_l, nodata, m%nodata)", - {"m": predict_output}, - name="mark_nodata", - dtype="float32", - **{"_l": m_size, "nodata": NODATA}, - ) - - predict_output = expr_eval( - "where((m>0.5)&(m