Skip to content

Commit

Permalink
Merge pull request #108 from TomasPetro/patch-2
Browse files Browse the repository at this point in the history
Update _widget_reader.py
  • Loading branch information
chrishavlin authored Nov 30, 2023
2 parents b4ec6fc + fd614a9 commit 9885ef5
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 48 deletions.
101 changes: 97 additions & 4 deletions src/yt_napari/_tests/test_widget_reader.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
import json

# note: the cache is disabled for all the tests in this file due to flakiness
# in github CI. It may be that loading from a true file, rather than the
# yt_ugrid_ds_fn fixture would fix that...
import os
from functools import partial
from unittest.mock import patch

import numpy as np

from yt_napari import _widget_reader as _wr
from yt_napari._data_model import InputModel
from yt_napari._ds_cache import dataset_cache

# import ReaderWidget, SelectionEntry, TimeSeriesReader
from yt_napari._special_loaders import _construct_ugrid_timeseries

# note: the cache is disabled for all the tests in this file due to flakiness
# in github CI. It may be that loading from a true file, rather than the
# yt_ugrid_ds_fn fixture would fix that...


def test_widget_reader_add_selections(make_napari_viewer, yt_ugrid_ds_fn):
viewer = make_napari_viewer()
Expand Down Expand Up @@ -45,6 +49,57 @@ def _rebuild_data(final_shape, data):
return np.random.random(final_shape) * data.mean()


def test_save_widget_reader(make_napari_viewer, yt_ugrid_ds_fn, tmp_path):
viewer = make_napari_viewer()
r = _wr.ReaderWidget(napari_viewer=viewer)
r.ds_container.filename.value = yt_ugrid_ds_fn
r.ds_container.store_in_cache.value = False
r.add_new_button.click()
sel = list(r.active_selections.values())[0]
assert isinstance(sel, _wr.SelectionEntry)

mgui_region = sel.selection_container_raw
mgui_region.fields.field_type.value = "enzo"
mgui_region.fields.field_name.value = "Density"
mgui_region.resolution.value = (400, 400, 400)

rebuild = partial(_rebuild_data, mgui_region.resolution.value)
r._post_load_function = rebuild

temp_file = tmp_path / "test.json"

with patch("PyQt5.QtWidgets.QFileDialog.exec_") as mock_exec, patch(
"PyQt5.QtWidgets.QFileDialog.selectedFiles"
) as mock_selectedFiles:
# Set the return values for the mocked functions
mock_exec.return_value = 1
mock_selectedFiles.return_value = [temp_file]

r.save_selection()

assert os.path.exists(temp_file)
with open(temp_file, "r") as json_file:
saved_data = json.load(json_file)

assert (
saved_data["datasets"][0]["selections"]["regions"][0]["fields"][0]["field_type"]
== "enzo"
)
assert (
saved_data["datasets"][0]["selections"]["regions"][0]["fields"][0]["field_name"]
== "Density"
)
assert saved_data["datasets"][0]["selections"]["regions"][0]["resolution"] == [
400,
400,
400,
]

# ensure that the saved json is a valid model
_ = InputModel.parse_obj(saved_data)
r.deleteLater()


def test_widget_reader(make_napari_viewer, yt_ugrid_ds_fn):
viewer = make_napari_viewer()
r = _wr.ReaderWidget(napari_viewer=viewer)
Expand Down Expand Up @@ -138,4 +193,42 @@ def test_timeseries_widget_reader(make_napari_viewer, tmp_path):
tsr.load_data()
assert len(viewer.layers) == 2

temp_file = tmp_path / "test.json"

# Use patch to replace the actual QFileDialog functions with mock functions
with patch("PyQt5.QtWidgets.QFileDialog.exec_") as mock_exec, patch(
"PyQt5.QtWidgets.QFileDialog.selectedFiles"
) as mock_selectedFiles:
# Set the return values for the mocked functions
mock_exec.return_value = 1 # Assuming QDialog::Accepted is 1
mock_selectedFiles.return_value = [temp_file]

# Call the save_selection method
tsr.save_selection()

assert os.path.exists(temp_file)
with open(temp_file, "r") as json_file:
saved_data = json.load(json_file)

assert (
saved_data["timeseries"][0]["selections"]["regions"][0]["fields"][0][
"field_type"
]
== "stream"
)
assert (
saved_data["timeseries"][0]["selections"]["regions"][0]["fields"][0][
"field_name"
]
== "density"
)
assert saved_data["timeseries"][0]["selections"]["regions"][0]["resolution"] == [
10,
10,
10,
]

# ensure that the saved json is a valid model
_ = InputModel.parse_obj(saved_data)

tsr.deleteLater()
150 changes: 106 additions & 44 deletions src/yt_napari/_widget_reader.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
import json
from collections import defaultdict
from typing import Callable, Optional

import napari
from magicgui import widgets
from napari.qt.threading import thread_worker
from qtpy import QtCore
from qtpy.QtWidgets import QComboBox, QHBoxLayout, QPushButton, QVBoxLayout, QWidget
from qtpy.QtWidgets import (
QComboBox,
QFileDialog,
QHBoxLayout,
QPushButton,
QVBoxLayout,
QWidget,
)

from yt_napari import _data_model, _gui_utilities, _model_ingestor
from yt_napari._ds_cache import dataset_cache
from yt_napari._schema_version import schema_name
from yt_napari.viewer import _check_for_reference_layer


Expand Down Expand Up @@ -68,7 +77,9 @@ def add_spatial_selection_widgets(self):
self.layout().addLayout(removal_group_layout)

def add_load_group_widgets(self):
pass
"""
add the widgets related to the Load button
"""

def add_a_selection(self):
selection_type = self.new_selection_type.currentText()
Expand Down Expand Up @@ -106,6 +117,26 @@ def add_load_group_widgets(self):
load_group.addWidget(cc.native)
self.layout().addLayout(load_group)

ss = widgets.PushButton(text="Save Selection")
ss.clicked.connect(self.save_selection)
load_group.addWidget(ss.native)

def save_selection(self):
py_kwargs = {}
py_kwargs = self._validate_data_model()

file_dialog = QFileDialog()
file_dialog.setFileMode(QFileDialog.AnyFile)
file_dialog.setAcceptMode(QFileDialog.AcceptSave)
file_dialog.setNameFilter("JSON Files (*.json);;All Files (*)")

if file_dialog.exec_():
file_path = file_dialog.selectedFiles()[0]
if file_path:
# Save the JSON data to the selected file
with open(file_path, "w") as json_file:
json.dump(py_kwargs, json_file, indent=4)

def clear_cache(self):
dataset_cache.rm_all()

Expand All @@ -114,7 +145,29 @@ def load_data(self):
# instantiate pydantic objects, which are then handed off to the
# same data ingestion function as the json loader.

# first, get the pydantic args for each selection type, embed in lists
py_kwargs = {}
py_kwargs = self._validate_data_model()
model = _data_model.InputModel.parse_obj(py_kwargs)

# process each layer
layer_list, _ = _model_ingestor._process_validated_model(model)

# align all layers after checking for or setting the reference layer
ref_layer = _check_for_reference_layer(self.viewer.layers)
if ref_layer is None:
ref_layer = _model_ingestor._choose_ref_layer(layer_list)
layer_list = ref_layer.align_sanitize_layers(layer_list)

for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
if self._post_load_function is not None:
im_arr = self._post_load_function(im_arr)

# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)

def _validate_data_model(self):
# this function save json data
selections_by_type = defaultdict(list)
for selection in self.active_selections.values():
py_kwargs = selection.get_current_pydantic_kwargs()
Expand All @@ -129,34 +182,17 @@ def load_data(self):
py_kwargs,
ignore_attrs="selections",
)

# add selections in
py_kwargs["selections"] = selections_by_type

# now ready to instantiate the base model
py_kwargs = {
"$schema": schema_name,
"datasets": [
py_kwargs,
]
],
}
model = _data_model.InputModel.parse_obj(py_kwargs)

# process each layer
layer_list, _ = _model_ingestor._process_validated_model(model)

# align all layers after checking for or setting the reference layer
ref_layer = _check_for_reference_layer(self.viewer.layers)
if ref_layer is None:
ref_layer = _model_ingestor._choose_ref_layer(layer_list)
layer_list = ref_layer.align_sanitize_layers(layer_list)

for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
if self._post_load_function is not None:
im_arr = self._post_load_function(im_arr)

# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)
return py_kwargs


class SelectionEntry(QWidget):
Expand Down Expand Up @@ -223,7 +259,50 @@ def add_load_group_widgets(self):
load_group.addWidget(pb.native)
self.layout().addLayout(load_group)

ss = widgets.PushButton(text="Save Selection")
ss.clicked.connect(self.save_selection)
load_group.addWidget(ss.native)

def save_selection(self):
py_kwargs = {}
py_kwargs = self._validate_data_model()
# model = _data_model.InputModel.parse_obj(py_kwargs)

file_dialog = QFileDialog()
file_dialog.setFileMode(QFileDialog.AnyFile)
file_dialog.setAcceptMode(QFileDialog.AcceptSave)
file_dialog.setNameFilter("JSON Files (*.json);;All Files (*)")

if file_dialog.exec_():
file_path = file_dialog.selectedFiles()[0]
if file_path:
# Save the JSON data to the selected file
with open(file_path, "w") as json_file:
json.dump(py_kwargs, json_file, indent=4)

def load_data(self):
py_kwargs = {}
py_kwargs = self._validate_data_model()
model = _data_model.InputModel.parse_obj(py_kwargs)

if _use_threading: # pragma: no cover
worker = time_series_load(model)
worker.returned.connect(self.process_timeseries_layers)
worker.start()
else:
_, layer_list = _model_ingestor._process_validated_model(model)
self.process_timeseries_layers(layer_list)

def process_timeseries_layers(self, layer_list):
for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
# probably can remove since the _special_loaders can be used
# if self._post_load_function is not None:
# im_arr = self._post_load_function(im_arr)
# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)

def _validate_data_model(self):
# first, get the pydantic args for each selection type, embed in lists
selections_by_type = defaultdict(list)
for selection in self.active_selections.values():
Expand Down Expand Up @@ -254,32 +333,15 @@ def load_data(self):

# now ready to instantiate the base model
py_kwargs = {
"$schema": schema_name,
"timeseries": [
py_kwargs,
]
],
}

model = _data_model.InputModel.parse_obj(py_kwargs)

if _use_threading:
worker = time_series_load(model)
worker.returned.connect(self.process_timeseries_layers)
worker.start()
else:
_, layer_list = _model_ingestor._process_validated_model(model)
self.process_timeseries_layers(layer_list)

def process_timeseries_layers(self, layer_list):
for new_layer in layer_list:
im_arr, im_kwargs, _ = new_layer
# probably can remove since the _special_loaders can be used
# if self._post_load_function is not None:
# im_arr = self._post_load_function(im_arr)
# add the new layer
self.viewer.add_image(im_arr, **im_kwargs)
return py_kwargs


@thread_worker(progress=True)
def time_series_load(model):
def time_series_load(model): # pragma: no cover
_, layer_list = _model_ingestor._process_validated_model(model)
return layer_list

0 comments on commit 9885ef5

Please sign in to comment.