Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding covering grid support #120

Merged
merged 8 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 108 additions & 24 deletions docs/examples/ytnapari_scene_01_intro.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions docs/examples/ytnapari_scene_04_timeseries.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
"\n",
"One difference between `yt-napari` and `yt` proper is that when sampling a time series, you first specify a selection object **independently** from a dataset object to define the extents and field of selection. That selection is then applied across all specified timesteps.\n",
"\n",
"The currently available selection objects are a `Slice` or 3D gridded `Region`. The arguments follow the same convention as a usual `yt` dataset selection object (i.e., `ds.slice`, `ds.region`) for specifying the geometric bounds of the selection with the additional constraint that you must specify a single field and the resolution you want to sample at:"
"The currently available selection objects are a 2D `Slice` or 3D gridded region, either a `Region` of a `CoveringGrid`. The arguments follow the same convention as a usual `yt` dataset selection object (i.e., `ds.slice`, `ds.region`, `ds.covering_grid`) for specifying the geometric bounds of the selection with the additional constraint that you must specify a single field and the resolution you want to sample at:"
]
},
{
Expand Down Expand Up @@ -238,7 +238,7 @@
"id": "edd2babf-5aae-4d2f-8079-96a68b594b22",
"metadata": {},
"source": [
"Once you create a `Slice` or `Region`, you can pass that to `add_to_viewer` and it will be used to sample each timestep specified. \n",
"Once you create a `Slice`, `Region` or `CoveringGrid`, you can pass that to `add_to_viewer` and it will be used to sample each timestep specified. \n",
"\n",
"## Slices through a timeseries\n",
"\n",
Expand Down Expand Up @@ -1131,7 +1131,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.11.7"
}
},
"nbformat": 4,
Expand Down
23 changes: 23 additions & 0 deletions src/yt_napari/_data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,26 @@ class Region(_ytBaseModel):
rescale: bool = Field(False, description="rescale the final image between 0,1")


class CoveringGrid(_ytBaseModel):
fields: List[ytField] = Field(
None, description="list of fields to load for this selection"
)
left_edge: Left_Edge = Field(
None,
description="the left edge (min x, min y, min z)",
)
right_edge: Right_Edge = Field(
None,
description="the right edge (max x, max y, max z)",
)
level: int = Field(0, description="Grid level to sample at")
num_ghost_zones: int = Field(
0,
description="Number of ghost zones to include",
)
rescale: bool = Field(False, description="rescale the final image between 0,1")


class Slice(_ytBaseModel):
fields: List[ytField] = Field(
None, description="list of fields to load for this selection"
Expand Down Expand Up @@ -93,6 +113,9 @@ class Slice(_ytBaseModel):
class SelectionObject(_ytBaseModel):
regions: List[Region] = Field(None, description="a list of regions to load")
slices: List[Slice] = Field(None, description="a list of slices to load")
covering_grids: List[CoveringGrid] = Field(
None, description="a list of covering grids to load"
)


class DataContainer(_ytBaseModel):
Expand Down
40 changes: 37 additions & 3 deletions src/yt_napari/_gui_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,16 @@ def get_widget_instance(self, pydantic_model, field: str):
func, args, kwargs = self.registry[pydantic_model][field]["magicgui"]
return func(*args, **kwargs)

def get_pydantic_attr(self, pydantic_model, field: str, widget_instance):
def get_pydantic_attr(
self, pydantic_model, field: str, widget_instance, required: bool = True
):
# given a widget instance, return an object that can be used to set a
# pydantic field
if self.is_registered(pydantic_model, field, required=True):
if self.is_registered(pydantic_model, field, required=required):
func, args, kwargs = self.registry[pydantic_model][field]["pydantic"]
return func(widget_instance, *args, **kwargs)
else:
raise RuntimeError("Could not retrieve pydantic attribute.")

def add_pydantic_to_container(
self,
Expand Down Expand Up @@ -214,6 +218,15 @@ def get_filename(file_widget: widgets.FileEdit):
return str(file_widget.value)


def get_int_box_widget(*args, **kwargs):
# could remove the need for this if the model uses pathlib.Path for typing
return widgets.IntText(*args, **kwargs)


def get_int_val(int_box: widgets.IntText):
return int(int_box.value)


def get_magicguidefault(field_name: str, field_def: pydantic.fields.Field):
# returns an instance of the default widget selected by magicgui
# returns an instance of the default widget selected by magicgui
Expand All @@ -229,6 +242,10 @@ def get_magicguidefault(field_name: str, field_def: pydantic.fields.Field):
options=opts_dict,
raise_on_unknown=False,
)

if new_widget_cls == widgets.TupleEdit:
ops["options"] = {"min": -1e12, "max": 1e12}

return new_widget_cls(**ops)


Expand All @@ -255,8 +272,10 @@ def _get_pydantic_model_field(
_models_to_embed_in_list = (
(_data_model.Slice, "fields"),
(_data_model.Region, "fields"),
(_data_model.CoveringGrid, "fields"),
(_data_model.DataContainer, "selections"),
(_data_model.SelectionObject, "regions"),
(_data_model.SelectionObject, "covering_grids"),
(_data_model.SelectionObject, "slices"),
)

Expand Down Expand Up @@ -297,6 +316,21 @@ def _register_yt_data_model(translator: MagicPydanticRegistry):
pydantic_attr_factory=handle_str_list_edit,
)

translator.register(
_data_model.CoveringGrid,
"level",
magicgui_factory=get_int_box_widget,
magicgui_kwargs={"name": "level"},
pydantic_attr_factory=get_int_val,
)
translator.register(
_data_model.CoveringGrid,
"num_ghost_zones",
magicgui_factory=get_int_box_widget,
magicgui_kwargs={"name": "num_ghost_zones"},
pydantic_attr_factory=get_int_val,
)


translator = MagicPydanticRegistry()
_register_yt_data_model(translator)
Expand All @@ -318,7 +352,7 @@ def get_yt_data_container(
return data_container


_valid_selections = ("Region", "Slice")
_valid_selections = ("Region", "Slice", "CoveringGrid")


def get_yt_selection_container(selection_type: str, return_native: bool = False):
Expand Down
50 changes: 40 additions & 10 deletions src/yt_napari/_model_ingestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from yt_napari import _special_loaders
from yt_napari._data_model import (
CoveringGrid,
DataContainer,
InputModel,
MetadataModel,
Expand All @@ -29,6 +30,30 @@ def _le_re_to_cen_wid(
return center, width


def _get_covering_grid(
ds, left_edge, right_edge, level, num_ghost_zones, test_dims=None
):
# returns a covering grid instance and the resolution of the covering grid
if test_dims is None:
test_dims = (4, 4, 4)
nghostzones = num_ghost_zones
temp_cg = ds.covering_grid(level, left_edge, test_dims, num_ghost_zones=nghostzones)
effective_dds = temp_cg.dds
dims = (right_edge - left_edge) / effective_dds
# get the actual covering grid
frb = ds.covering_grid(level, left_edge, dims, num_ghost_zones=nghostzones)
return frb, dims


def _get_region_frb(ds, LE, RE, res):
frb = ds.r[
LE[0] : RE[0] : complex(0, res[0]), # noqa: E203
LE[1] : RE[1] : complex(0, res[1]), # noqa: E203
LE[2] : RE[2] : complex(0, res[2]), # noqa: E203
]
return frb


class LayerDomain:
# container for domain info for a single layer
# left_edge, right_edge, resolution, n_d are all self explanatory.
Expand Down Expand Up @@ -434,7 +459,13 @@ def _load_3D_regions(
layer_list: list,
timeseries_container: Optional[TimeseriesContainer] = None,
) -> list:
for sel in selections.regions:

sels = []
for seltype in ("regions", "covering_grids"):
if getattr(selections, seltype) is not None:
sels += [sel for sel in getattr(selections, seltype)]

for sel in sels:
# get the left, right edge as a unitful array, initialize the layer
# domain tracking for this layer and update the global domain extent
if sel.left_edge is None:
Expand All @@ -446,16 +477,15 @@ def _load_3D_regions(
RE = ds.domain_right_edge
else:
RE = ds.arr(sel.right_edge.value, sel.right_edge.unit)
res = sel.resolution
layer_domain = LayerDomain(left_edge=LE, right_edge=RE, resolution=res)

# create the fixed resolution buffer
frb = ds.r[
LE[0] : RE[0] : complex(0, res[0]), # noqa: E203
LE[1] : RE[1] : complex(0, res[1]), # noqa: E203
LE[2] : RE[2] : complex(0, res[2]), # noqa: E203
]
if isinstance(sel, Region):
res = sel.resolution
frb = _get_region_frb(ds, LE, RE, res)
elif isinstance(sel, CoveringGrid):
frb, dims = _get_covering_grid(ds, LE, RE, sel.level, sel.num_ghost_zones)
res = dims

layer_domain = LayerDomain(left_edge=LE, right_edge=RE, resolution=res)
for field_container in sel.fields:
field = (field_container.field_type, field_container.field_name)

Expand Down Expand Up @@ -600,7 +630,7 @@ def _load_selections_from_ds(
layer_list: List[SpatialLayer],
timeseries_container: Optional[TimeseriesContainer] = None,
) -> List[SpatialLayer]:
if selections.regions is not None:
if selections.regions is not None or selections.covering_grids is not None:
layer_list = _load_3D_regions(
ds, selections, layer_list, timeseries_container=timeseries_container
)
Expand Down
45 changes: 45 additions & 0 deletions src/yt_napari/_tests/test_covering_grid_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from yt_napari._data_model import InputModel
from yt_napari._model_ingestor import _choose_ref_layer, _process_validated_model
from yt_napari._schema_version import schema_name

jdicts = []
jdicts.append(
{
"$schema": schema_name,
"datasets": [
{
"filename": "_ytnapari_load_grid",
"selections": {
"covering_grids": [
{
"fields": [{"field_name": "density", "field_type": "gas"}],
"left_edge": {"value": (0.4, 0.4, 0.4)},
"right_edge": {"value": (0.5, 0.5, 0.5)},
"level": 0,
"rescale": 1,
}
]
},
}
],
}
)


@pytest.mark.parametrize("jdict", jdicts)
def test_covering_grid_validation(jdict):
_ = InputModel.model_validate(jdict)


@pytest.mark.parametrize("jdict", jdicts)
def test_slice_load(yt_ugrid_ds_fn, jdict):
im = InputModel.model_validate(jdict)
layer_lists, _ = _process_validated_model(im)
ref_layer = _choose_ref_layer(layer_lists)
_ = ref_layer.align_sanitize_layers(layer_lists)

im_data = layer_lists[0][0]
assert im_data.min() == 0
assert im_data.max() == 1
2 changes: 0 additions & 2 deletions src/yt_napari/_tests/test_ds_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def test_ds_cache(caplog):

dataset_cache.rm_ds(ds_name)
assert dataset_cache.exists(ds_name) is False
assert len(dataset_cache.available) == 0

ds_none = dataset_cache.get_ds("doesnotexist")
assert ds_none is None
Expand All @@ -35,7 +34,6 @@ def test_ds_cache(caplog):
dataset_cache.add_ds(ds, ds_name)
assert dataset_cache.exists(ds_name)
dataset_cache.rm_all()
assert len(dataset_cache.available) == 0
assert dataset_cache.most_recent is None


Expand Down
4 changes: 4 additions & 0 deletions src/yt_napari/_tests/test_gui_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def get_value_from_nested(container_widget, extra_string):
pyvalue = reg.get_pydantic_attr(Model, "field_1", widget_instance)
assert pyvalue == "2_testxyz"

with pytest.raises(RuntimeError, match="Could not retrieve pydantic attribute."):
reg.get_pydantic_attr(
Model, "field_does_not_exist", widget_instance, required=False
)
widget_instance.close()


Expand Down
7 changes: 7 additions & 0 deletions src/yt_napari/_tests/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ def test_region(yt_ds_0):
assert np.all(np.log10(data4) == data)


def test_covering_grid(yt_ds_0):
cg = ts.CoveringGrid(_field)
data = cg.sample_ds(yt_ds_0)
# sampled at level 0 for full domain, so should get out the base dimensions
assert data.shape == tuple(yt_ds_0.domain_dimensions)


def test_slice(yt_ds_0):
sample_res = (20, 20)
slc = ts.Slice(_field, "x", resolution=sample_res)
Expand Down
7 changes: 7 additions & 0 deletions src/yt_napari/_tests/test_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def test_viewer(make_napari_viewer, yt_ds, caplog):
expected_layers += 1
assert len(viewer.layers) == expected_layers

LE = yt_ds.domain_left_edge
dds = yt_ds.domain_width / yt_ds.domain_dimensions
RE = yt_ds.arr(LE + dds * 10)
sc.add_covering_grid(viewer, yt_ds, ("gas", "density"), left_edge=LE, right_edge=RE)
expected_layers += 1
assert len(viewer.layers) == expected_layers

# build a new scene so it builds from prior
sc = Scene()
sc.add_region(viewer, yt_ds, ("gas", "density"))
Expand Down
27 changes: 26 additions & 1 deletion src/yt_napari/_tests/test_widget_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _rebuild_data(final_shape, data):
# the yt file thats being loaded from the pytest fixture is a saved
# dataset created from an in-memory uniform grid, and the re-loaded
# dataset will not have the full functionality of a ds. so here, we
# inject a correctly shaped random array here. If we start using full
# inject a correctly shaped random array. If we start using full
# test datasets from yt in testing, this should be changed.
return np.random.random(final_shape) * data.mean()

Expand Down Expand Up @@ -236,3 +236,28 @@ def test_timeseries_widget_reader(make_napari_viewer, tmp_path):
_ = InputModel.model_validate(saved_data)

tsr.deleteLater()


def test_covering_grid_selection(make_napari_viewer, yt_ugrid_ds_fn):
viewer = make_napari_viewer()
r = _wr.ReaderWidget(napari_viewer=viewer)
r.ds_container.filename.value = yt_ugrid_ds_fn
r.ds_container.store_in_cache.value = False
r.new_selection_type.setCurrentIndex(2)
r.add_new_button.click()
assert len(r.active_selections) == 1
sel = list(r.active_selections.values())[0]
assert isinstance(sel, _wr.SelectionEntry)
assert sel.selection_type == "CoveringGrid"

mgui_region = sel.selection_container_raw
mgui_region.fields.field_type.value = "gas"
mgui_region.fields.field_name.value = "density"
mgui_region.level.value = 0

mgui_region.left_edge.value.value = (-1.5,) * 3
mgui_region.right_edge.value.value = (1.5,) * 3
rebuild = partial(_rebuild_data, (64, 64, 64))
r._post_load_function = rebuild
r.load_data()
r.deleteLater()
Loading