Skip to content

Commit

Permalink
more aggregation on dask ops
Browse files Browse the repository at this point in the history
  • Loading branch information
Emma Ai committed Aug 22, 2024
1 parent f1f2771 commit b404842
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 40 deletions.
2 changes: 1 addition & 1 deletion docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ odc-dscache>=0.2.3
odc-stac @ git+https://github.com/opendatacube/odc-stac@69bdf64

# odc-stac is in PyPI
odc-stats[ows] @ git+https://github.com/opendatacube/odc-stats@a9bdd82
odc-stats[ows] @ git+https://github.com/opendatacube/odc-stats@f1f2771

# For ML
tflite-runtime
Expand Down
88 changes: 49 additions & 39 deletions odc/stats/plugins/lc_treelite_cultivated.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import dask.array as da
import numexpr as ne

from odc.stats._algebra import expr_eval
from ._registry import register
from .lc_ml_treelite import StatsMLTree, mask_and_predict

Expand Down Expand Up @@ -233,6 +232,46 @@ def cultivated_predict(input_block, bands_indices):
return cc


def aggregate_results(input_block, cultivated_value, natural_value):
# if there are >= 2 images
# any is cultivated -> final class is cultivated
# any is valid -> final class is valid
# for each pixel

output_block = ne.evaluate(
"where(m<nodata, 1-m, m)",
local_dict={"m": input_block, "nodata": NODATA},
).astype("float32")

m_size = input_block.shape[-1]

if m_size > 1:
output_block = ne.evaluate(
"sum(m,axis=2)",
local_dict={"m": output_block},
).astype("float32")

else:
output_block = output_block.squeeze(axis=-1)

output_block = ne.evaluate(
"where((m/nodata)+0.5>=_l, nodata, m%nodata)",
local_dict={"m": output_block, "_l": m_size, "nodata": NODATA},
).astype("float32")

output_block = ne.evaluate(
"where((m>0.5)&(m<nodata), _u, m)",
local_dict={"m": output_block, "_u": cultivated_value, "nodata": NODATA},
).astype("float32")

output_block = ne.evaluate(
"where(m<0.5, _nu, m)",
local_dict={"m": output_block, "_nu": natural_value},
).astype("uint8")

return output_block


class StatsCultivatedClass(StatsMLTree):
NAME = "ga_ls_cultivated"
SHORT_NAME = NAME
Expand All @@ -258,49 +297,20 @@ def predict(self, input_array):
return cc

def aggregate_results_from_group(self, predict_output):
# if there are >= 2 images
# any is cultivated -> final class is cultivated
# any is valid -> final class is valid
# for each pixel
m_size = len(predict_output)
if m_size > 1:
predict_output = da.stack(predict_output)
predict_output = da.stack(predict_output, axis=-1)
else:
predict_output = predict_output[0]

predict_output = expr_eval(
"where(m<nodata, 1-m, m)",
{"m": predict_output},
name="invert_output",
dtype="float32",
**{"nodata": NODATA},
)
predict_output = predict_output[0][..., np.newaxis]

if m_size > 1:
predict_output = predict_output.sum(axis=0)

predict_output = expr_eval(
"where((m/nodata)+0.5>=_l, nodata, m%nodata)",
{"m": predict_output},
name="mark_nodata",
dtype="float32",
**{"_l": m_size, "nodata": NODATA},
)

predict_output = expr_eval(
"where((m>0.5)&(m<nodata), _u, m)",
{"m": predict_output},
name="output_classes_cultivated",
dtype="float32",
**{"_u": self.output_classes["cultivated"], "nodata": NODATA},
)

predict_output = expr_eval(
"where(m<0.5, _nu, m)",
{"m": predict_output},
name="output_classes_natural",
predict_output = da.map_blocks(
aggregate_results,
predict_output,
self.output_classes["cultivated"],
self.output_classes["natural"],
drop_axis=-1,
dtype="uint8",
**{"_nu": self.output_classes["natural"]},
name="aggregate_results",
)

return predict_output.rechunk(-1, -1)
Expand Down
1 change: 1 addition & 0 deletions tests/test_rf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ def test_cultivated_aggregate_results(
input_bands=cultivated_input_bands,
)
res = cultivated.aggregate_results_from_group([cultivated_results[0]])
print(res.compute())
assert (res.compute() == np.array([[112, 255], [111, 112]], dtype="uint8")).all()
res = cultivated.aggregate_results_from_group(cultivated_results)
assert (res.compute() == np.array([[111, 112], [111, 112]], dtype="uint8")).all()
Expand Down

0 comments on commit b404842

Please sign in to comment.