Skip to content

Commit

Permalink
Merge pull request #90 from lifewatch/fix/plot_spectrum_mean
Browse files Browse the repository at this point in the history
Fix/plot spectrum mean
  • Loading branch information
cparcerisas authored Oct 6, 2023
2 parents 67f67b0 + 12b1ab0 commit 5d4f6b5
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 12 deletions.
81 changes: 69 additions & 12 deletions pypam/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,69 @@ def plot_spectrum_per_chunk(ds, data_var, log=True, save_path=None, show=True):
plt.show()


def plot_spectrum_mean(ds, data_var, log=True, save_path=None, ax=None, show=True):
def plot_multiple_spectrum_mean(ds_dict, data_var, percentiles='default', frequency_coord='frequency', time_coord='id',
log=True, save_path=None, ax=None, show=True, **kwargs):
"""
Same than plot_spectrum_mean but instead of one ds you can pass a dictionary of label: ds so they are all plot
on one figure.
Parameters
----------
ds_dict : xarray DataSet
Dataset to plot
data_var : string
Name of the data variable to use
percentiles: Tuple or 'default'
list or tuple with (min_percentile, max_percentile). If set to 'default' it will be [10, 90]
time_coord: str
Name of the coordinate representing time
frequency_coord: str
Name of the coordinate representing frequency
log : boolean
If set to True, y-axis in logarithmic scale
save_path : string or Path
Where to save the output graph. If None, it is not saved
ax : matplotlib.axes class or None
ax to plot on
show : bool
set to True to show the plot
Returns
-------
matplotlib.axes
"""
if ax is None:
fig, ax = plt.subplots()
for label, ds in ds_dict.items():
kwargs.update({'label': label})
plot_spectrum_mean(ds, data_var, percentiles=percentiles, frequency_coord=frequency_coord,
time_coord=time_coord, log=log, save_path=None, ax=ax, show=False, **kwargs)

if save_path is not None:
plt.savefig(save_path)
if show:
plt.show()

return ax


def plot_spectrum_mean(ds, data_var, percentiles='default', frequency_coord='frequency', time_coord='id',
log=True, save_path=None, ax=None, show=True, **kwargs):
"""
Plot the mean spectrum
Parameters
----------
ds : xarray DataSet
Output of evolution
Dataset to plot
data_var : string
Name of the data variable to use
percentiles: Tuple or 'default'
list or tuple with (min_percentile, max_percentile). If set to 'default' it will be [10, 90]
time_coord: str
Name of the coordinate representing time
frequency_coord: str
Name of the coordinate representing frequency
log : boolean
If set to True, y-axis in logarithmic scale
save_path : string or Path
Expand All @@ -186,20 +239,24 @@ def plot_spectrum_mean(ds, data_var, log=True, save_path=None, ax=None, show=Tru
"""
if ax is None:
fig, ax = plt.subplots()

sns.lineplot(x=ds[data_var].dims[1], y='value', ax=ax, data=ds[data_var].to_pandas().melt(), errorbar='sd')

if ('percentiles' in ds) and (len(ds['percentiles']) > 0):
# Add the percentiles values
ds['value_percentiles'].mean(dim='id').plot.line(hue='percentiles', ax=ax)
if percentiles == 'default':
percentiles = [10, 90]

pxx = ds[data_var].to_numpy().T
p = np.nanpercentile(a=pxx, q=np.array(percentiles), axis=1)
ax.plot(ds[frequency_coord].values, ds[data_var].mean(dim=time_coord).values, **kwargs)
if 'color' in kwargs.keys():
ax.fill_between(x=ds[frequency_coord].values, y1=p[0], y2=p[1], alpha=0.2, color=kwargs['color'])
else:
ax.fill_between(x=ds[frequency_coord].values, y1=p[0], y2=p[1], alpha=0.2)

ax.set_facecolor('white')
plt.title(data_var.replace('_', ' ').capitalize())
plt.xlabel('Frequency [Hz]')
plt.ylabel(r'%s [$%s$]' % (ds[data_var].standard_name, ds[data_var].units))
ax.set_title(data_var.replace('_', ' ').capitalize())
ax.set_xlabel('Frequency [Hz]')
ax.set_ylabel(r'%s [$%s$]' % (ds[data_var].standard_name, ds[data_var].units))

if log:
plt.xscale('log')
ax.set_xscale('log')
if save_path is not None:
plt.savefig(save_path)
if show:
Expand Down
7 changes: 7 additions & 0 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def test_plot_spectrum_mean(self):
psd = self.asa.evolution_freq_dom('psd')
pypam.plots.plot_spectrum_mean(ds=psd, data_var='band_density', show=True)

@skip_unless_with_plots()
def test_plot_multiple_spectrum_mean(self):
psd = self.asa.hybrid_millidecade_bands(band=[10,2000])
ds_dict = {'asa': psd, 'test_day': self.ds}
pypam.plots.plot_multiple_spectrum_mean(ds_dict=ds_dict, data_var='millidecade_bands', show=True,
frequency_coord='frequency_bins')

@skip_unless_with_plots()
def test_plot_ltsa(self):
pypam.plots.plot_ltsa(ds=self.ds, data_var='millidecade_bands', show=True)
Expand Down

0 comments on commit 5d4f6b5

Please sign in to comment.