Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat(plot_centers): add plot_centers support to PlotMapView and PlotCrossSection #2318

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions autotest/test_plot_cross_section.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,52 @@ def test_plot_limits():
raise AssertionError("PlotMapView auto extent setting not working")

plt.close(fig)


def test_plot_centers():
from matplotlib.collections import PathCollection

nlay = 1
nrow = 10
ncol = 10

delc = np.ones((nrow,))
delr = np.ones((ncol,))
top = np.ones((nrow, ncol))
botm = np.zeros((nlay, nrow, ncol))
idomain = np.ones(botm.shape, dtype=int)

idomain[0, :, 0:3] = 0

grid = flopy.discretization.StructuredGrid(
delc=delc, delr=delr, top=top, botm=botm, idomain=idomain
)

line = {"line": [(0, 0), (10, 10)]}
active_xc_cells = 7

pxc = flopy.plot.PlotCrossSection(modelgrid=grid, line=line)
pc = pxc.plot_centers()

if not isinstance(pc, PathCollection):
raise AssertionError(
"plot_centers() not returning PathCollection object"
)

verts = pc._offsets
if not verts.shape[0] == active_xc_cells:
raise AssertionError(
"plot_centers() not properly masking inactive cells"
)

center_dict = pxc.projctr
edge_dict = pxc.projpts

for node, center in center_dict.items():
verts = np.array(edge_dict[node]).T
xmin = np.min(verts[0])
xmax = np.max(verts[0])
if xmax < center < xmin:
raise AssertionError(
"Cell center not properly drawn on cross-section"
)
41 changes: 41 additions & 0 deletions autotest/test_plot_map_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,44 @@ def test_plot_limits():
raise AssertionError("PlotMapView auto extent setting not working")

plt.close(fig)


def test_plot_centers():
nlay = 1
nrow = 10
ncol = 10

delc = np.ones((nrow,))
delr = np.ones((ncol,))
top = np.ones((nrow, ncol))
botm = np.zeros((nlay, nrow, ncol))
idomain = np.ones(botm.shape, dtype=int)

idomain[0, :, 0:3] = 0
active_cells = np.count_nonzero(idomain)

grid = flopy.discretization.StructuredGrid(
delc=delc, delr=delr, top=top, botm=botm, idomain=idomain
)

xcenters = grid.xcellcenters.ravel()
ycenters = grid.ycellcenters.ravel()
xycenters = list(zip(xcenters, ycenters))

pmv = flopy.plot.PlotMapView(modelgrid=grid)
pc = pmv.plot_centers()
if not isinstance(pc, PathCollection):
raise AssertionError(
"plot_centers() not returning PathCollection object"
)

verts = pc._offsets
if not verts.shape[0] == active_cells:
raise AssertionError(
"plot_centers() not properly masking inactive cells"
)

for vert in verts:
vert = tuple(vert)
if vert not in xycenters:
raise AssertionError("center location not properly plotted")
138 changes: 133 additions & 5 deletions flopy/plot/crosssection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,13 @@ class PlotCrossSection:
(xmin, xmax, ymin, ymax) will be used to specify axes limits. If None
then these will be calculated based on grid, coordinates, and rotation.
geographic_coords : bool
boolean flag to allow the user to plot cross section lines in
geographic coordinates. If False (default), cross section is plotted
as the distance along the cross section line.

boolean flag to allow the user to plot cross-section lines in
geographic coordinates. If False (default), cross-section is plotted
as the distance along the cross-section line.
min_segment_length : float
minimum width of a grid cell polygon to be plotted. Cells with a
cross-sectional width less than min_segment_length will be ignored
and not included in the plot. Default is 1e-02.
"""

def __init__(
Expand All @@ -53,6 +56,7 @@ def __init__(
line=None,
extent=None,
geographic_coords=False,
min_segment_length=1e-02,
):
self.ax = ax
self.geographic_coords = geographic_coords
Expand Down Expand Up @@ -180,6 +184,22 @@ def __init__(
self.pts, self.xvertices, self.yvertices
)

self.xypts = plotutil.UnstructuredPlotUtilities.filter_line_segments(
self.xypts, threshold=min_segment_length
)
# need to ensure that the ordering of verticies in xypts is correct
# based on the projection. In certain cases vertices need to be sorted
# for the specific "projection"
for node, points in self.xypts.items():
if self.direction == "y":
if points[0][-1] < points[1][-1]:
points = points[::-1]
else:
if points[0][0] > points[1][0]:
points = points[::-1]

self.xypts[node] = points

if len(self.xypts) < 2:
if len(list(self.xypts.values())[0]) < 2:
s = (
Expand Down Expand Up @@ -238,6 +258,7 @@ def __init__(
self.idomain = np.ones(botm.shape, dtype=int)

self.projpts = self.set_zpts(None)
self.projctr = None

# Create cross-section extent
if extent is None:
Expand Down Expand Up @@ -926,6 +947,111 @@ def plot_bc(

return patches

def plot_centers(
self, a=None, s=None, masked_values=None, inactive=False, **kwargs
):
"""
Method to plot cell centers on cross-section using matplotlib
scatter. This method accepts an optional data array(s) for
coloring and scaling the cell centers. Cell centers in inactive
nodes are not plotted by default

Parameters
----------
a : None, np.ndarray
optional numpy nd.array of size modelgrid.nnodes
s : None, float, numpy array
optional point size parameter
masked_values : None, iteratable
optional list, tuple, or np array of array (a) values to mask
inactive : bool
boolean flag to include inactive cell centers in the plot.
Default is False
**kwargs :
matplotlib ax.scatter() keyword arguments

Returns
-------
matplotlib ax.scatter() object
"""
ax = kwargs.pop("ax", self.ax)

projpts = self.projpts
nodes = list(projpts.keys())
xcs = self.mg.xcellcenters.ravel()
ycs = self.mg.ycellcenters.ravel()
projctr = {}

if not self.geographic_coords:
xcs, ycs = geometry.transform(
xcs,
ycs,
self.mg.xoffset,
self.mg.yoffset,
self.mg.angrot_radians,
inverse=True,
)

for node, points in self.xypts.items():
projpt = projpts[node]
d0 = np.min(np.array(projpt).T[0])

xc_dist = geometry.project_point_onto_xc_line(
points[:2], [xcs[node], ycs[node]], d0=d0, calc_dist=True
)
projctr[node] = xc_dist

else:
projctr = {}
for node in nodes:
if self.direction == "x":
projctr[node] = xcs[node]
else:
projctr[node] = ycs[node]

# pop off any centers that are outside the "visual field"
# for a given cross-section.
removed = {}
for node, points in projpts.items():
center = projctr[node]
points = np.array(points[:2]).T
if np.min(points[0]) > center or np.max(points[0]) < center:
removed[node] = (np.min(points[0]), center, np.max(points[0]))
projctr.pop(node)

# filter out inactive cells
if not inactive:
idomain = self.mg.idomain.ravel()
for node, points in projpts.items():
if idomain[node] == 0:
if node in projctr:
projctr.pop(node)

self.projctr = projctr
nodes = list(projctr.keys())
xcenters = list(projctr.values())
zcenters = [np.mean(np.array(projpts[node]).T[1]) for node in nodes]

if a is not None:
if not isinstance(a, np.ndarray):
a = np.array(a)
a = a.ravel().astype(float)

if masked_values is not None:
self._masked_values.extend(list(masked_values))

for mval in self._masked_values:
a[a == mval] = np.nan

a = a[nodes]

if s is not None:
if not isinstance(s, (int, float)):
s = s[nodes]
print(len(xcenters))
scat = ax.scatter(xcenters, zcenters, c=a, s=s, **kwargs)
return scat

def plot_vector(
self,
vx,
Expand Down Expand Up @@ -1350,6 +1476,7 @@ def plot_endpoint(
self.xvertices,
self.yvertices,
self.direction,
self._ncpl,
method=method,
starting=istart,
)
Expand All @@ -1362,15 +1489,16 @@ def plot_endpoint(
self.xypts,
self.direction,
self.mg,
self._ncpl,
self.geographic_coords,
starting=istart,
)

arr = []
c = []
for node, epl in sorted(epdict.items()):
c.append(cd[node])
for xy in epl:
c.append(cd[node])
arr.append(xy)

arr = np.array(arr)
Expand Down
61 changes: 61 additions & 0 deletions flopy/plot/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,67 @@ def plot_shapes(self, obj, **kwargs):
ax = self._set_axes_limits(ax)
return patch_collection

def plot_centers(
self, a=None, s=None, masked_values=None, inactive=False, **kwargs
):
"""
Method to plot cell centers on cross-section using matplotlib
scatter. This method accepts an optional data array(s) for
coloring and scaling the cell centers. Cell centers in inactive
nodes are not plotted by default

Parameters
----------
a : None, np.ndarray
optional numpy nd.array of size modelgrid.nnodes
s : None, float, numpy array
optional point size parameter
masked_values : None, iteratable
optional list, tuple, or np array of array (a) values to mask
inactive : bool
boolean flag to include inactive cell centers in the plot.
Default is False
**kwargs :
matplotlib ax.scatter() keyword arguments

Returns
-------
matplotlib ax.scatter() object
"""
ax = kwargs.pop("ax", self.ax)

xcenters = self.mg.get_xcellcenters_for_layer(self.layer).ravel()
ycenters = self.mg.get_ycellcenters_for_layer(self.layer).ravel()
idomain = self.mg.get_plottable_layer_array(
self.mg.idomain, self.layer
).ravel()

active_ixs = list(range(len(xcenters)))
if not inactive:
active_ixs = np.where(idomain != 0)[0]

xcenters = xcenters[active_ixs]
ycenters = ycenters[active_ixs]

if a is not None:
a = self.mg.get_plottable_layer_array(a).ravel()

if masked_values is not None:
self._masked_values.extend(list(masked_values))

for mval in self._masked_values:
a[a == mval] = np.nan

a = a[active_ixs]

if s is not None:
if not isinstance(s, (int, float)):
s = self.mg.get_plottable_layer_array(s).ravel()
s = s[active_ixs]

scat = ax.scatter(xcenters, ycenters, c=a, s=s, **kwargs)
return scat

def plot_vector(
self,
vx,
Expand Down
Loading
Loading