diff --git a/pypam/plots.py b/pypam/plots.py index dae6e3b..5b001d5 100644 --- a/pypam/plots.py +++ b/pypam/plots.py @@ -160,16 +160,69 @@ def plot_spectrum_per_chunk(ds, data_var, log=True, save_path=None, show=True): plt.show() -def plot_spectrum_mean(ds, data_var, log=True, save_path=None, ax=None, show=True): +def plot_multiple_spectrum_mean(ds_dict, data_var, percentiles='default', frequency_coord='frequency', time_coord='id', + log=True, save_path=None, ax=None, show=True, **kwargs): + """ + Same than plot_spectrum_mean but instead of one ds you can pass a dictionary of label: ds so they are all plot + on one figure. + + Parameters + ---------- + ds_dict : xarray DataSet + Dataset to plot + data_var : string + Name of the data variable to use + percentiles: Tuple or 'default' + list or tuple with (min_percentile, max_percentile). If set to 'default' it will be [10, 90] + time_coord: str + Name of the coordinate representing time + frequency_coord: str + Name of the coordinate representing frequency + log : boolean + If set to True, y-axis in logarithmic scale + save_path : string or Path + Where to save the output graph. If None, it is not saved + ax : matplotlib.axes class or None + ax to plot on + show : bool + set to True to show the plot + + Returns + ------- + matplotlib.axes + """ + if ax is None: + fig, ax = plt.subplots() + for label, ds in ds_dict.items(): + kwargs.update({'label': label}) + plot_spectrum_mean(ds, data_var, percentiles=percentiles, frequency_coord=frequency_coord, + time_coord=time_coord, log=log, save_path=None, ax=ax, show=False, **kwargs) + + if save_path is not None: + plt.savefig(save_path) + if show: + plt.show() + + return ax + + +def plot_spectrum_mean(ds, data_var, percentiles='default', frequency_coord='frequency', time_coord='id', + log=True, save_path=None, ax=None, show=True, **kwargs): """ Plot the mean spectrum Parameters ---------- ds : xarray DataSet - Output of evolution + Dataset to plot data_var : string Name of the data variable to use + percentiles: Tuple or 'default' + list or tuple with (min_percentile, max_percentile). If set to 'default' it will be [10, 90] + time_coord: str + Name of the coordinate representing time + frequency_coord: str + Name of the coordinate representing frequency log : boolean If set to True, y-axis in logarithmic scale save_path : string or Path @@ -186,20 +239,24 @@ def plot_spectrum_mean(ds, data_var, log=True, save_path=None, ax=None, show=Tru """ if ax is None: fig, ax = plt.subplots() - - sns.lineplot(x=ds[data_var].dims[1], y='value', ax=ax, data=ds[data_var].to_pandas().melt(), errorbar='sd') - - if ('percentiles' in ds) and (len(ds['percentiles']) > 0): - # Add the percentiles values - ds['value_percentiles'].mean(dim='id').plot.line(hue='percentiles', ax=ax) + if percentiles == 'default': + percentiles = [10, 90] + + pxx = ds[data_var].to_numpy().T + p = np.nanpercentile(a=pxx, q=np.array(percentiles), axis=1) + ax.plot(ds[frequency_coord].values, ds[data_var].mean(dim=time_coord).values, **kwargs) + if 'color' in kwargs.keys(): + ax.fill_between(x=ds[frequency_coord].values, y1=p[0], y2=p[1], alpha=0.2, color=kwargs['color']) + else: + ax.fill_between(x=ds[frequency_coord].values, y1=p[0], y2=p[1], alpha=0.2) ax.set_facecolor('white') - plt.title(data_var.replace('_', ' ').capitalize()) - plt.xlabel('Frequency [Hz]') - plt.ylabel(r'%s [$%s$]' % (ds[data_var].standard_name, ds[data_var].units)) + ax.set_title(data_var.replace('_', ' ').capitalize()) + ax.set_xlabel('Frequency [Hz]') + ax.set_ylabel(r'%s [$%s$]' % (ds[data_var].standard_name, ds[data_var].units)) if log: - plt.xscale('log') + ax.set_xscale('log') if save_path is not None: plt.savefig(save_path) if show: diff --git a/tests/test_plots.py b/tests/test_plots.py index 47303e2..f8c80a9 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -61,6 +61,13 @@ def test_plot_spectrum_mean(self): psd = self.asa.evolution_freq_dom('psd') pypam.plots.plot_spectrum_mean(ds=psd, data_var='band_density', show=True) + @skip_unless_with_plots() + def test_plot_multiple_spectrum_mean(self): + psd = self.asa.hybrid_millidecade_bands(band=[10,2000]) + ds_dict = {'asa': psd, 'test_day': self.ds} + pypam.plots.plot_multiple_spectrum_mean(ds_dict=ds_dict, data_var='millidecade_bands', show=True, + frequency_coord='frequency_bins') + @skip_unless_with_plots() def test_plot_ltsa(self): pypam.plots.plot_ltsa(ds=self.ds, data_var='millidecade_bands', show=True)