add water seasonality in module a0
Emma Ai committed Sep 23, 2024
1 parent b960c1d commit 898306b
Showing 2 changed files with 74 additions and 18 deletions.
odc/stats/plugins/lc_fc_wo_a0.py (43 additions, 18 deletions)
@@ -29,7 +29,7 @@ class StatsVegCount(StatsPluginInterface):
VERSION = "0.0.1"
PRODUCT_FAMILY = "lccs"

BAD_BITS_MASK = dict(cloud=(1 << 6), cloud_shadow=(1 << 5))
BAD_BITS_MASK = {"cloud": (1 << 6), "cloud_shadow": (1 << 5)}

def __init__(
self,
@@ -80,10 +80,10 @@ def native_transform(self, xx):
# get valid wo pixels, both dry and wet
data = expr_eval(
"where(a|b, a, _nan)",
dict(a=wet.data, b=valid.data),
{"a": wet.data, "b": valid.data},
name="get_valid_pixels",
dtype="float32",
**dict(_nan=np.nan),
**{"_nan": np.nan},
)

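For orientation (illustrative only, not part of the commit), the expr_eval call above amounts to a simple element-wise selection. A minimal NumPy sketch, assuming wet and valid are 1/0 per-pixel flags:

import numpy as np

# hypothetical flags for three pixels: wet, dry-but-clear, not observed
wet = np.array([1.0, 0.0, 0.0], dtype="float32")
valid = np.array([0.0, 1.0, 0.0], dtype="float32")

# roughly what "where(a|b, a, _nan)" computes: keep the wet/dry flag wherever
# the pixel was observed (wet or valid), otherwise mark it NaN
data = np.where(wet.astype(bool) | valid.astype(bool), wet, np.float32(np.nan))
print(data)  # [ 1.  0. nan]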
# Pick out the fc pixels that have an unmixing error of less than the threshold
@@ -111,30 +111,49 @@ def _veg_or_not(self, xx: xr.Dataset):
# otherwise 0
data = expr_eval(
"where((a>b)|(c>b), 1, 0)",
dict(a=xx["pv"].data, c=xx["npv"].data, b=xx["bs"].data),
{"a": xx["pv"].data, "c": xx["npv"].data, "b": xx["bs"].data},
name="get_veg",
dtype="uint8",
)

# mark nans
data = expr_eval(
"where(a!=a, nodata, b)",
dict(a=xx["pv"].data, b=data),
{"a": xx["pv"].data, "b": data},
name="get_veg",
dtype="uint8",
**dict(nodata=int(NODATA)),
**{"nodata": int(NODATA)},
)

# mark water freq >= 0.5 as 0
data = expr_eval(
"where(a>0, 0, b)",
dict(a=xx["wet"].data, b=data),
{"a": xx["wet"].data, "b": data},
name="get_veg",
dtype="uint8",
)

return data

def _water_or_not(self, xx: xr.Dataset):
# mark water freq > 0.5 as 1
data = expr_eval(
"where(a>0.5, 1, 0)",
{"a": xx["wet"].data},
name="get_water",
dtype="uint8",
)

# mark nans
data = expr_eval(
"where(a!=a, nodata, b)",
{"a": xx["wet"].data, "b": data},
name="get_water",
dtype="uint8",
**{"nodata": int(NODATA)},
)
return data

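As an aside (a sketch, not the plugin code), the two classifiers above reduce to element-wise rules on the monthly medians. A plain-NumPy version, assuming pv, npv, bs are fractional-cover bands, wet is the water frequency, and 255 is the uint8 NODATA marker used in the tests:

import numpy as np

NODATA = 255  # assumed nodata marker for the uint8 outputs

def veg_or_not(pv, npv, bs, wet):
    # 1 where either green or dry vegetation exceeds bare soil, else 0
    out = np.where((pv > bs) | (npv > bs), 1, 0).astype("uint8")
    # NaN inputs (the "a != a" test in the numexpr form) become nodata
    out = np.where(np.isnan(pv), NODATA, out)
    # any water observation overrides the vegetation flag
    return np.where(wet > 0, 0, out).astype("uint8")

def water_or_not(wet):
    # 1 where water frequency exceeds 0.5, else 0
    out = np.where(wet > 0.5, 1, 0).astype("uint8")
    # NaN inputs become nodata
    return np.where(np.isnan(wet), NODATA, out).astype("uint8")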
def _max_consecutive_months(self, data, nodata):
nan_mask = da.ones(data.shape[1:], chunks=data.chunks[1:], dtype="bool")
tmp = da.zeros(data.shape[1:], chunks=data.chunks[1:], dtype="uint8")
@@ -144,44 +163,44 @@ def _max_consecutive_months(self, data, nodata):
# +1 if not nodata
tmp = expr_eval(
"where(a==nodata, b, a+b)",
dict(a=t, b=tmp),
{"a": t, "b": tmp},
name="compute_consecutive_month",
dtype="uint8",
**dict(nodata=nodata),
**{"nodata": nodata},
)

# save the max
max_count = expr_eval(
"where(a>b, a, b)",
dict(a=max_count, b=tmp),
{"a": max_count, "b": tmp},
name="compute_consecutive_month",
dtype="uint8",
)

# reset if not veg
tmp = expr_eval(
"where((a<=0), 0, b)",
dict(a=t, b=tmp),
{"a": t, "b": tmp},
name="compute_consecutive_month",
dtype="uint8",
)

# mark nodata
nan_mask = expr_eval(
"where(a==nodata, b, False)",
dict(a=t, b=nan_mask),
{"a": t, "b": nan_mask},
name="mark_nodata",
dtype="bool",
**dict(nodata=nodata),
**{"nodata": nodata},
)

# mark nodata
max_count = expr_eval(
"where(a, nodata, b)",
dict(a=nan_mask, b=max_count),
{"a": nan_mask, "b": max_count},
name="mark_nodata",
dtype="uint8",
**dict(nodata=int(nodata)),
**{"nodata": int(nodata)},
)
return max_count

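The per-month updates above implement a running count with a reset, kept lazy with dask. The same idea as a standalone NumPy sketch (illustrative, not the plugin code), for a (month, y, x) stack of 1/0/nodata layers:

import numpy as np

NODATA = 255  # assumed uint8 nodata marker

def max_consecutive_months(data, nodata=NODATA):
    tmp = np.zeros(data.shape[1:], dtype="uint8")        # current run length
    max_count = np.zeros(data.shape[1:], dtype="uint8")  # longest run so far
    nan_mask = np.ones(data.shape[1:], dtype="bool")     # True until a valid month is seen

    for t in data:                                         # iterate over months
        tmp = np.where(t == nodata, tmp, t + tmp)          # extend the run unless nodata
        max_count = np.maximum(max_count, tmp)             # remember the longest run
        tmp = np.where(t <= 0, 0, tmp)                     # reset where the class is absent
        nan_mask = np.where(t == nodata, nan_mask, False)  # at least one valid month seen

    return np.where(nan_mask, nodata, max_count).astype("uint8")

# a single pixel over six months: 1, 1, 0, 1, 1, 1 -> longest run is 3
months = np.array([1, 1, 0, 1, 1, 1], dtype="uint8").reshape(6, 1, 1)
print(max_consecutive_months(months)[0, 0])  # 3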
@@ -190,14 +209,20 @@ def reduce(self, xx: xr.Dataset) -> xr.Dataset:
xx = xx.groupby("time.month").map(median_ds, dim="spec")

data = self._veg_or_not(xx)
max_count = self._max_consecutive_months(data, NODATA)
max_count_veg = self._max_consecutive_months(data, NODATA)

data = self._water_or_not(xx)
max_count_water = self._max_consecutive_months(data, NODATA)

attrs = xx.attrs.copy()
attrs["nodata"] = int(NODATA)
data_vars = {
"veg_frequency": xr.DataArray(
max_count, dims=xx["wet"].dims[1:], attrs=attrs
)
max_count_veg, dims=xx["wet"].dims[1:], attrs=attrs
),
"water_frequency": xr.DataArray(
max_count_water, dims=xx["wet"].dims[1:], attrs=attrs
),
}
coords = dict((dim, xx.coords[dim]) for dim in xx["wet"].dims[1:])
return xr.Dataset(data_vars=data_vars, coords=coords, attrs=xx.attrs)
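For readers unfamiliar with the output, reduce ends by packing the two consecutive-month counts into one dataset: veg_frequency and water_frequency each hold, per pixel, the longest run of consecutive months classified as vegetation or water, stored as uint8 with a nodata attribute. A toy xarray construction with hypothetical values (not the real geobox or coordinates):

import numpy as np
import xarray as xr

NODATA = 255
attrs = {"nodata": int(NODATA)}

# pretend these came out of _max_consecutive_months on a 2 x 3 tile
max_count_veg = np.array([[3, 0, NODATA], [1, 2, 0]], dtype="uint8")
max_count_water = np.array([[0, 2, NODATA], [0, 0, 1]], dtype="uint8")

ds = xr.Dataset(
    {
        "veg_frequency": xr.DataArray(max_count_veg, dims=("y", "x"), attrs=attrs),
        "water_frequency": xr.DataArray(max_count_water, dims=("y", "x"), attrs=attrs),
    },
    coords={"y": np.arange(2), "x": np.arange(3)},
)
print(ds["water_frequency"].values)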
tests/test_landcover_plugin_a0.py (31 additions, 0 deletions)
@@ -381,6 +381,23 @@ def test_veg_or_not(fc_wo_dataset):
i += 1


def test_water_or_not(fc_wo_dataset):
stats_veg = StatsVegCount()
xx = stats_veg.native_transform(fc_wo_dataset)
xx = xx.groupby("solar_day").map(partial(StatsVegCount.fuser, None))
yy = stats_veg._water_or_not(xx).compute()
valid_index = (
np.array([0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2]),
np.array([1, 1, 3, 5, 6, 2, 6, 0, 0, 2, 2, 3, 5, 6]),
np.array([0, 3, 2, 1, 3, 5, 6, 0, 2, 1, 4, 2, 5, 6]),
)
expected_value = np.array([0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0])
i = 0
for idx in zip(*valid_index):
assert yy[idx] == expected_value[i]
i += 1


def test_reduce(fc_wo_dataset):
stats_veg = StatsVegCount()
xx = stats_veg.native_transform(fc_wo_dataset)
@@ -400,6 +417,20 @@ def test_reduce(fc_wo_dataset):

assert (xx.veg_frequency.data == expected_value).all()

expected_value = np.array(
[
[0, 255, 1, 255, 255, 255, 255],
[0, 255, 255, 0, 255, 255, 255],
[255, 1, 255, 255, 0, 0, 255],
[255, 255, 0, 255, 255, 255, 255],
[255, 255, 255, 255, 255, 255, 255],
[255, 0, 255, 255, 255, 0, 255],
[255, 255, 255, 0, 255, 255, 1],
]
)

assert (xx.water_frequency.data == expected_value).all()


def test_consecutive_month(consecutive_count):
stats_veg = StatsVegCount()
