Skip to content

Commit

Permalink
add some plotting utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
FrederikSchnack committed Jul 14, 2023
1 parent d830ae1 commit a27a062
Showing 1 changed file with 142 additions and 60 deletions.
202 changes: 142 additions & 60 deletions psydac/feec/multipatch/plotting_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@
from sympy import lambdify

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm, colors
from mpl_toolkits import mplot3d
from collections import OrderedDict

from psydac.linalg.utilities import array_to_psydac
from psydac.fem.basic import FemField
from psydac.fem.vector import ProductFemSpace, VectorFemSpace
from psydac.utilities.utils import refine_array_1d
from psydac.feec.pull_push import push_2d_h1, push_2d_hcurl, push_2d_hdiv, push_2d_l2

matplotlib.rcParams['font.size'] = 15

#==============================================================================
def is_vector_valued(u):
# small utility function, only tested for FemFields in multi-patch spaces of the 2D grad-curl sequence
Expand Down Expand Up @@ -57,18 +59,18 @@ def get_grid_vals(u, etas, mappings_list, space_kind='hcurl'):
uk_field_0 = u[k]

# computing the pushed-fwd values on the grid
if space_kind == 'h1':
if space_kind == 'h1' or space_kind == 'V0':
assert not vector_valued
# todo (MCP): add 2d_hcurl_vector
push_field = lambda eta1, eta2: push_2d_h1(uk_field_0, eta1, eta2)
elif space_kind == 'hcurl':
elif space_kind == 'hcurl' or space_kind == 'V1':
# todo (MCP): specify 2d_hcurl_scalar in push functions
push_field = lambda eta1, eta2: push_2d_hcurl(uk_field_0, uk_field_1, eta1, eta2, mappings_list[k])
elif space_kind == 'hdiv':
push_field = lambda eta1, eta2: push_2d_hdiv(uk_field_0, uk_field_1, eta1, eta2, mappings_list[k])
push_field = lambda eta1, eta2: push_2d_hcurl(uk_field_0, uk_field_1, eta1, eta2, mappings_list[k].get_callable_mapping())
elif space_kind == 'hdiv' or space_kind == 'V2':
push_field = lambda eta1, eta2: push_2d_hdiv(uk_field_0, uk_field_1, eta1, eta2, mappings_list[k].get_callable_mapping())
elif space_kind == 'l2':
assert not vector_valued
push_field = lambda eta1, eta2: push_2d_l2(uk_field_0, eta1, eta2, mappings_list[k])
push_field = lambda eta1, eta2: push_2d_l2(uk_field_0, eta1, eta2, mappings_list[k].get_callable_mapping())
else:
raise ValueError('unknown value for space_kind = {}'.format(space_kind))

Expand All @@ -81,9 +83,9 @@ def get_grid_vals(u, etas, mappings_list, space_kind='hcurl'):

# always return a list, even for scalar-valued functions ?
if not vector_valued:
return np.array(u_vals_components[0])
return u_vals_components[0]
else:
return [np.array(a) for a in u_vals_components]
return u_vals_components

#------------------------------------------------------------------------------
def get_grid_quad_weights(etas, patch_logvols, mappings_list): #_obj):
Expand All @@ -102,9 +104,11 @@ def get_grid_quad_weights(etas, patch_logvols, mappings_list): #_obj):
N1 = eta_1.shape[1]

log_weight = patch_logvols[k]/(N0*N1)
Fk = mappings_list[k].get_callable_mapping()
for i, x1i in enumerate(eta_1[:, 0]):
for j, x2j in enumerate(eta_2[0, :]):
quad_weights[k][i, j] = push_2d_l2(one_field, x1i, x2j, mapping=mappings_list[k]) * log_weight
det_Fk_ij = Fk.metric_det(x1i, x2j)**0.5
quad_weights[k][i, j] = det_Fk_ij * log_weight

return quad_weights

Expand Down Expand Up @@ -169,12 +173,8 @@ def get_patch_knots_gridlines(Vh, N, mappings, plotted_patch=-1):
F = [M.get_callable_mapping() for d,M in mappings.items()]

if plotted_patch in range(len(mappings)):
space = Vh.spaces[plotted_patch]
if isinstance(space, (VectorFemSpace, ProductFemSpace)):
space = space.spaces[0]

grid_x1 = space.breaks[0]
grid_x2 = space.breaks[1]
grid_x1 = Vh.spaces[plotted_patch].spaces[0].breaks[0]
grid_x2 = Vh.spaces[plotted_patch].spaces[0].breaks[1]

x1 = refine_array_1d(grid_x1, N)
x2 = refine_array_1d(grid_x2, N)
Expand All @@ -192,13 +192,16 @@ def get_patch_knots_gridlines(Vh, N, mappings, plotted_patch=-1):
return gridlines_x1, gridlines_x2

#------------------------------------------------------------------------------
def plot_field(fem_field=None, stencil_coeffs=None, numpy_coeffs=None, Vh=None, domain=None, space_kind=None, title=None, filename='dummy_plot.png', subtitles=None, hide_plot=True):
def plot_field(
fem_field=None, stencil_coeffs=None, numpy_coeffs=None, Vh=None, domain=None, surface_plot=False, cb_min=None, cb_max=None,
plot_type='amplitude', cmap='hsv', space_kind=None, title=None, filename='dummy_plot.png', subtitles=None, N_vis=20, vf_skip=2, hide_plot=True):
"""
plot a discrete field (given as a FemField or by its coeffs in numpy or stencil format) on the given domain
:param Vh: Fem space needed if v is given by its coeffs
:param space_kind: type of the push-forward defining the physical Fem Space
:param subtitles: in case one would like to have several subplots # todo: then v should be given as a list of fields...
:param N_vis: nb of visualization points per patch (per dimension)
"""
if not space_kind in ['h1', 'hcurl', 'l2']:
raise ValueError('invalid value for space_kind = {}'.format(space_kind))
Expand All @@ -212,30 +215,89 @@ def plot_field(fem_field=None, stencil_coeffs=None, numpy_coeffs=None, Vh=None,

mappings = OrderedDict([(P.logical_domain, P.mapping) for P in domain.interior])
mappings_list = list(mappings.values())
etas, xx, yy = get_plotting_grid(mappings, N=20)
etas, xx, yy = get_plotting_grid(mappings, N=N_vis)
grid_vals = lambda v: get_grid_vals(v, etas, mappings_list, space_kind=space_kind)

vh_vals = grid_vals(vh)
if is_vector_valued(vh):
# then vh_vals[d] contains the values of the d-component of vh (as a patch-indexed list)
vh_abs_vals = [np.sqrt(abs(v[0])**2 + abs(v[1])**2) for v in zip(vh_vals[0],vh_vals[1])]
if plot_type == 'vector_field' and not is_vector_valued(vh):
print("WARNING [plot_field]: vector_field plot is not possible with a scalar field, plotting the amplitude instead")
plot_type = 'amplitude'

if plot_type == 'vector_field':
if is_vector_valued(vh):
my_small_streamplot(
title=title,
vals_x=vh_vals[0],
vals_y=vh_vals[1],
skip=vf_skip,
xx=xx,
yy=yy,
amp_factor=2,
save_fig=filename,
hide_plot=hide_plot,
dpi = 200,
)

else:
# then vh_vals just contains the values of vh (as a patch-indexed list)
vh_abs_vals = np.abs(vh_vals)

my_small_plot(
title=title,
vals=[vh_abs_vals],
titles=subtitles,
xx=xx,
yy=yy,
surface_plot=False,
save_fig=filename,
save_vals = True,
hide_plot=hide_plot,
cmap='hsv',
dpi = 400,
)
# computing plot_vals_list: may have several elements for several plots
if plot_type=='amplitude':

if is_vector_valued(vh):
# then vh_vals[d] contains the values of the d-component of vh (as a patch-indexed list)
plot_vals = [np.sqrt(abs(v[0])**2 + abs(v[1])**2) for v in zip(vh_vals[0],vh_vals[1])]
else:
# then vh_vals just contains the values of vh (as a patch-indexed list)
plot_vals = np.abs(vh_vals)
plot_vals_list = [plot_vals]

elif plot_type=='components':
if is_vector_valued(vh):
# then vh_vals[d] contains the values of the d-component of vh (as a patch-indexed list)
plot_vals_list = vh_vals
if subtitles is None:
subtitles = ['x-component', 'y-component']
else:
# then vh_vals just contains the values of vh (as a patch-indexed list)
plot_vals_list = [vh_vals]
else:
raise ValueError(plot_type)

my_small_plot(
title=title,
vals=plot_vals_list,
titles=subtitles,
xx=xx,
yy=yy,
surface_plot=surface_plot,
cb_min=cb_min,
cb_max=cb_max,
save_fig=filename,
save_vals = False,
hide_plot=hide_plot,
cmap=cmap,
dpi = 300,
)

# if is_vector_valued(vh):
# # then vh_vals[d] contains the values of the d-component of vh (as a patch-indexed list)
# vh_abs_vals = [np.sqrt(abs(v[0])**2 + abs(v[1])**2) for v in zip(vh_vals[0],vh_vals[1])]
# else:
# # then vh_vals just contains the values of vh (as a patch-indexed list)
# vh_abs_vals = np.abs(vh_vals)

# my_small_plot(
# title=title,
# vals=[vh_abs_vals],
# titles=subtitles,
# xx=xx,
# yy=yy,
# surface_plot=False,
# save_fig=filename,
# save_vals=False,
# hide_plot=hide_plot,
# cmap='hsv',
# dpi = 400,
# )

#------------------------------------------------------------------------------
def my_small_plot(
Expand All @@ -245,6 +307,8 @@ def my_small_plot(
gridlines_x2=None,
surface_plot=False,
cmap='viridis',
cb_min=None,
cb_max=None,
save_fig=None,
save_vals = False,
hide_plot=False,
Expand All @@ -257,46 +321,49 @@ def my_small_plot(
assert xx and yy
n_plots = len(vals)
if n_plots > 1:
assert n_plots == len(titles)
if titles is None or n_plots != len(titles):
titles = n_plots*[title]
else:
if titles:
print('Warning [my_small_plot]: will discard argument titles for a single plot')
titles = [title]

n_patches = len(xx)
assert n_patches == len(yy)

if save_vals:
# saving as vals.npz
np.savez('vals', xx=xx, yy=yy, vals=vals)

fig = plt.figure(figsize=(2.6+4.8*n_plots, 4.8))
fig.suptitle(title, fontsize=14)

for i in range(n_plots):
vmin = np.min(vals[i])
vmax = np.max(vals[i])
if cb_min is None:
vmin = np.min(vals[i])
else:
vmin = cb_min
if cb_max is None:
vmax = np.max(vals[i])
else:
vmax = cb_max
cnorm = colors.Normalize(vmin=vmin, vmax=vmax)
assert n_patches == len(vals[i])

ax = fig.add_subplot(1, n_plots, i+1)
for k in range(n_patches):
ax.contourf(xx[k], yy[k], vals[i][k], 50, norm=cnorm, cmap=cmap) #, extend='both')
ax.contourf(xx[k], yy[k], vals[i][k], 50, norm=cnorm, cmap=cmap, zorder=-10) #, extend='both')
ax.set_rasterization_zorder(0)
cbar = fig.colorbar(cm.ScalarMappable(norm=cnorm, cmap=cmap), ax=ax, pad=0.05)

if gridlines_x1 is not None and gridlines_x2 is not None:
if isinstance(gridlines_x1[0], (list,tuple)):
for x1,x2 in zip(gridlines_x1,gridlines_x2):
if x1 is None or x2 is None:continue
kwargs = {'lw': 0.5}
ax.plot(*x1, color='k', **kwargs)
ax.plot(*x2, color='k', **kwargs)
else:
ax.plot(*gridlines_x1, color='k')
ax.plot(*gridlines_x2, color='k')

if gridlines_x1 is not None:
ax.plot(*gridlines_x1, color='k')
ax.plot(*gridlines_x2, color='k')
if show_xylabel:
ax.set_xlabel( r'$x$', rotation='horizontal' )
ax.set_ylabel( r'$y$', rotation='horizontal' )
if n_plots > 1:
ax.set_title ( titles[i] )
ax.set_aspect('equal')

if save_fig:
print('saving contour plot in file '+save_fig)
Expand All @@ -310,8 +377,14 @@ def my_small_plot(
fig.suptitle(title+' -- surface', fontsize=14)

for i in range(n_plots):
vmin = np.min(vals[i])
vmax = np.max(vals[i])
if cb_min is None:
vmin = np.min(vals[i])
else:
vmin = cb_min
if cb_max is None:
vmax = np.max(vals[i])
else:
vmax = cb_max
cnorm = colors.Normalize(vmin=vmin, vmax=vmax)
assert n_patches == len(vals[i])
ax = fig.add_subplot(1, n_plots, i+1, projection='3d')
Expand All @@ -331,7 +404,8 @@ def my_small_plot(
save_fig_surf = save_fig[:-4]+'_surf'+ext
print('saving surface plot in file '+save_fig_surf)
plt.savefig(save_fig_surf, bbox_inches='tight', dpi=dpi)
else:

if not hide_plot:
plt.show()

#------------------------------------------------------------------------------
Expand All @@ -341,6 +415,7 @@ def my_small_streamplot(
amp_factor=1,
save_fig=None,
hide_plot=False,
show_xylabel=True,
dpi='figure',
):
"""
Expand All @@ -349,7 +424,10 @@ def my_small_streamplot(
n_patches = len(xx)
assert n_patches == len(yy)

fig = plt.figure(figsize=(2.6+4.8, 4.8))
# fig = plt.figure(figsize=(2.6+4.8, 4.8))

fig, ax = plt.subplots(1,1, figsize=(2.6+4.8, 4.8))

fig.suptitle(title, fontsize=14)

delta = 0.25
Expand All @@ -359,14 +437,18 @@ def my_small_streamplot(
#print('max_val = {}'.format(max_val))
vf_amp = amp_factor/max_val
for k in range(n_patches):
plt.quiver(xx[k][::skip, ::skip], yy[k][::skip, ::skip], vals_x[k][::skip, ::skip], vals_y[k][::skip, ::skip],
ax.quiver(xx[k][::skip, ::skip], yy[k][::skip, ::skip], vals_x[k][::skip, ::skip], vals_y[k][::skip, ::skip],
scale=1/(vf_amp*0.05), width=0.002) # width=) units='width', pivot='mid',

if show_xylabel:
ax.set_xlabel( r'$x$', rotation='horizontal' )
ax.set_ylabel( r'$y$', rotation='horizontal' )

ax.set_aspect('equal')

if save_fig:
print('saving vector field (stream) plot in file '+save_fig)
plt.savefig(save_fig, bbox_inches='tight', dpi=dpi)

if not hide_plot:
plt.show()


0 comments on commit a27a062

Please sign in to comment.